352 lines
13 KiB
TypeScript
352 lines
13 KiB
TypeScript
import { useMemo, useState } from 'react';
|
|
import type { Agent, BranchInfo, ChatAttachment, ChatMessage, ModelOverrides, RetrievedSnippet, ToolCallTrace } from '../../../api';
|
|
import { ChatAPI, ChatAttachmentsAPI, ImageAPI, regenerateMessage, streamChat } from '../../../api';
|
|
import { buildAttachmentsText } from '../utils/attachments';
|
|
import { parseAgentModels } from '../utils/agentModels';
|
|
|
|
/**
 * Live view of the assistant reply currently being streamed (or of the request that just ended),
 * consumed by the chat UI while a send or regenerate request is running.
 */
export interface StreamingState {
  /** True while a streamed send/regenerate request is in progress. */
  active: boolean;

  /** Answer text accumulated from answer deltas so far. */
  answerText: string;
  /** Reasoning text accumulated from reasoning deltas so far. */
  reasoningText: string;

  /** Tool calls seen so far; a call still awaiting its result carries `result: { pending: true }`. */
  toolCalls: ToolCallTrace[];
  /** Snippets reported by the stream's meta event. */
  retrieved: RetrievedSnippet[];

  /** Payload of the most recent retry event; its shape is defined by the streaming API (untyped here). */
  retryInfo: any | null;
  /** Message describing why the last request failed, or null when it did not fail. */
  errorMessage: string | null;
}
|
|
|
|
/** Builds an idle streaming state, optionally carrying the error that ended the last request. */
function idleStreaming(errorMessage: string | null = null): StreamingState {
  return { active: false, reasoningText: '', answerText: '', errorMessage, retryInfo: null, retrieved: [], toolCalls: [] };
}

/** Returns true when a thrown value is the rejection produced by aborting a request. */
function isAbortError(e: unknown): boolean {
  return typeof e === 'object' && e !== null && (e as { name?: unknown }).name === 'AbortError';
}

/** Extracts a human-readable message from any thrown value (its `message` if present, else `String(e)`). */
function describeError(e: unknown): string {
  if (typeof e === 'object' && e !== null && 'message' in e) {
    const message = (e as { message?: unknown }).message;
    if (message !== undefined && message !== null) return String(message);
  }
  return String(e);
}

/** Builds the inline assistant bubble appended to the conversation when a send request fails. */
function buildErrorMessage(detail: string): ChatMessage {
  return {
    id: 'error-' + Date.now(),
    role: 'assistant',
    content: `❌ 请求失败:${detail}`,
    createdAt: Date.now(),
    meta: { error: detail }
  };
}

/** Returns true when a tool-call result is the `{ pending: true }` placeholder set on tool-call start. */
function isPendingResult(result: unknown): boolean {
  return typeof result === 'object' && result !== null && Boolean((result as { pending?: unknown }).pending);
}

/**
 * Accumulates one streamed turn in a local variable and mirrors every change into React state.
 *
 * Keeping the accumulator outside `setStreaming` updaters keeps those updaters pure (React may
 * invoke updater functions twice in StrictMode) and lets terminal callbacks read the final
 * accumulated values directly instead of from inside an updater.
 *
 * @param setStreaming - receives every new snapshot
 * @param onGrow - invoked after reasoning/answer text grows (used to keep the view scrolled)
 */
function createStreamTracker(setStreaming: (state: StreamingState) => void, onGrow: () => void) {
  let state: StreamingState = idleStreaming();
  const publish = (next: StreamingState) => {
    state = next;
    setStreaming(next);
  };
  return {
    /** Latest accumulated state of this turn. */
    snapshot: (): StreamingState => state,
    /** Begins the turn with a fresh, active state. */
    start: () => publish({ ...idleStreaming(), active: true }),
    setRetrieved: (retrieved: RetrievedSnippet[]) => publish({ ...state, retrieved }),
    setRetryInfo: (retryInfo: StreamingState['retryInfo']) => publish({ ...state, retryInfo }),
    appendReasoning: (chunk: string) => {
      publish({ ...state, reasoningText: state.reasoningText + chunk });
      onGrow();
    },
    appendAnswer: (chunk: string) => {
      publish({ ...state, answerText: state.answerText + chunk });
      onGrow();
    },
    addToolCall: (call: ToolCallTrace) => publish({ ...state, toolCalls: [...state.toolCalls, call] }),
    /** Fills in the result of the most recent still-pending call with the given name. */
    resolveToolCall: (name: ToolCallTrace['name'], result: ToolCallTrace['result']) => {
      const toolCalls = [...state.toolCalls];
      for (let i = toolCalls.length - 1; i >= 0; i--) {
        const call = toolCalls[i];
        if (call && call.name === name && isPendingResult(call.result)) {
          toolCalls[i] = { ...call, result };
          break;
        }
      }
      publish({ ...state, toolCalls });
    },
    /** Ends the turn, returning to idle; pass a message when the turn failed. */
    finish: (errorMessage: string | null = null) => publish(idleStreaming(errorMessage))
  };
}

type StreamTracker = ReturnType<typeof createStreamTracker>;

/**
 * Chat-composer state and actions for one agent session: input text, attachments, model options,
 * and the send / stop / clear / regenerate / branch-switch handlers.
 *
 * Streamed replies are surfaced through `streaming`; completed turns are written back through
 * `args.setMessages`. Request failures are reported through `args.notify` instead of being thrown.
 */
export function useChatSender(args: {
  agentId?: string;
  agent: Agent | null;
  sessionId: string | null;
  overrides: ModelOverrides;
  setOverrides: (updater: (prev: ModelOverrides) => ModelOverrides) => void;
  messages: ChatMessage[];
  setMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void;
  setBranches: (v: Record<string, BranchInfo>) => void;
  loadMessages: () => Promise<void>;
  scrollBottom: (force?: boolean) => void;
  notify: { success: (t: string) => void; error: (t: string) => void };
  abortRef: { current: AbortController | null };
}) {
  const { agentId, agent, sessionId, overrides, setOverrides, setBranches, loadMessages, scrollBottom, notify, abortRef } = args;
  const [input, setInput] = useState('');
  const [sending, setSending] = useState(false);
  const [useStream, setUseStream] = useState(true);
  const [attachments, setAttachments] = useState<ChatAttachment[]>([]);
  const [imageUrls, setImageUrls] = useState<string[]>([]);
  const [uploadingAtt, setUploadingAtt] = useState(false);
  const [streaming, setStreaming] = useState<StreamingState>(() => idleStreaming());
  const [sessionRefresh, setSessionRefresh] = useState(0);

  const agentModels = useMemo(() => parseAgentModels(agent?.model), [agent?.model]);
  const modelOptions = useMemo(() => agentModels.map((model) => ({ value: model.id, label: model.name })), [agentModels]);
  const activeModelValue = overrides.model_id || '';

  /** Aborts any in-flight request and installs a fresh controller for the next one. */
  const beginRequest = (): AbortController => {
    abortRef.current?.abort();
    const ctrl = new AbortController();
    abortRef.current = ctrl;
    return ctrl;
  };

  /** Records a retry event; on a model fallback, remembers the fallback model name in the overrides. */
  const applyRetry = (tracker: StreamTracker, data: StreamingState['retryInfo']) => {
    tracker.setRetryInfo(data);
    // NOTE(review): only `model` is updated here; `overrides.model_id` keeps its previous value.
    if (data?.stage === 'fallback_model' && data?.toModel) {
      setOverrides((o) => ({ ...o, model: String(data.toModel) }));
    }
  };

  const handleSendStream = async (text: string) => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    // Optimistic user bubble; replaced by the server copy on completion.
    const tempUser: ChatMessage = { id: 'tmp-' + Date.now(), role: 'user', content: text, createdAt: Date.now() };
    args.setMessages((m) => [...(m || []), tempUser]);
    const tracker = createStreamTracker(setStreaming, () => scrollBottom());
    tracker.start();
    scrollBottom(true);

    const ctrl = beginRequest();

    const model = overrides.model || agentModels[0]?.name || '';
    const modelId = overrides.model_id || agentModels[0]?.id || '';
    const attText = buildAttachmentsText(attachments);
    const content = attText ? `${text}\n\n${attText}` : text;

    try {
      await streamChat(
        agentId,
        content,
        {
          onMeta: (m) => tracker.setRetrieved(m.retrieved || []),
          onRetry: (data) => applyRetry(tracker, data),
          onReasoningDelta: (chunk) => tracker.appendReasoning(chunk),
          onDelta: (chunk) => tracker.appendAnswer(chunk),
          onToolCall: (data) => tracker.addToolCall({ name: data.name, args: data.args, result: { pending: true } }),
          onToolResult: (data) => tracker.resolveToolCall(data.name, data.result),
          onDone: (data) => {
            const live = tracker.snapshot();
            // Reasoning may be reported in several places; fall back to what was streamed.
            const assistant: any = data.assistant;
            const reasoningText = assistant?.reasoning || assistant?.meta?.reasoning || assistant?.meta?.reasoningText || live.reasoningText;
            const nextAssistant: ChatMessage = {
              ...data.assistant,
              meta: {
                ...(data.assistant.meta || {}),
                retrieved: data.assistant.meta?.retrieved || live.retrieved,
                toolCalls: data.assistant.meta?.toolCalls || live.toolCalls,
                reasoning: reasoningText || undefined
              }
            };
            // Called outside any state updater so it runs exactly once per completed turn.
            args.setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), data.user, nextAssistant]);
            tracker.finish();
            setSessionRefresh((t) => t + 1);
            setAttachments([]);
            setImageUrls([]);
            scrollBottom();
          },
          onAborted: (data) => {
            const assistant: ChatMessage = { ...data.assistant, meta: { ...(data.assistant.meta || {}), aborted: true } };
            args.setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), { ...tempUser, id: 'u-' + data.assistant.id }, assistant]);
            tracker.finish();
            setSessionRefresh((t) => t + 1);
            void loadMessages();
          },
          onError: (errMsg) => {
            notify.error('流式失败:' + errMsg);
            args.setMessages((m) => [...(m || []), buildErrorMessage(errMsg)]);
            tracker.finish(errMsg);
          }
        },
        ctrl.signal,
        sessionId,
        model,
        modelId,
        imageUrls
      );
    } catch (e: unknown) {
      if (isAbortError(e)) {
        // A user-initiated stop is not a failure: return to idle without an error message.
        tracker.finish();
      } else {
        const detail = describeError(e);
        notify.error('请求失败:' + detail);
        args.setMessages((m) => [...(m || []), buildErrorMessage(detail)]);
        tracker.finish(detail);
      }
    }
  };

  const handleSendNonStream = async (text: string) => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    const tempUser: ChatMessage = { id: 'tmp-' + Date.now(), role: 'user', content: text, createdAt: Date.now() };
    args.setMessages((m) => [...(m || []), tempUser]);
    scrollBottom(true);
    const attText = buildAttachmentsText(attachments);
    const content = attText ? `${text}\n\n${attText}` : text;
    const model = overrides.model || agentModels[0]?.name || '';
    try {
      const res = await ChatAPI.send(agentId, content, sessionId, model, imageUrls);
      args.setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), res.user, res.assistant]);
      setSessionRefresh((t) => t + 1);
      setAttachments([]);
      setImageUrls([]);
      scrollBottom();
    } catch (e: unknown) {
      notify.error('发送失败:' + describeError(e));
      // Drop the optimistic bubble; nothing was persisted.
      args.setMessages((m) => (m || []).filter((x) => x.id !== tempUser.id));
    }
  };

  /** Sends the trimmed input via the streaming or non-streaming path; no-op while a send is in flight. */
  const handleSend = async () => {
    const text = input.trim();
    if (!text || !agentId || sending) return;
    setInput('');
    setSending(true);
    try {
      if (useStream) await handleSendStream(text);
      else await handleSendNonStream(text);
    } finally {
      setSending(false);
    }
  };

  /** Aborts the in-flight request; the stream's own abort handling resets `streaming`. */
  const handleStop = () => {
    abortRef.current?.abort();
    setSending(false);
  };

  const handleClear = async () => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    try {
      await ChatAPI.clear(agentId, sessionId);
    } catch (e: unknown) {
      // Keep local messages when the server-side clear failed.
      notify.error('清空失败:' + describeError(e));
      return;
    }
    args.setMessages(() => []);
    setBranches({});
    notify.success('对话已清空');
  };

  const handleRegenerate = async (assistantId: string) => {
    if (!agentId || sending) return;
    setSending(true);
    const tracker = createStreamTracker(setStreaming, () => scrollBottom());
    tracker.start();
    const ctrl = beginRequest();
    try {
      await regenerateMessage(
        agentId,
        assistantId,
        {
          onMeta: (m) => tracker.setRetrieved(m.retrieved || []),
          onRetry: (data) => applyRetry(tracker, data),
          onReasoningDelta: (chunk) => tracker.appendReasoning(chunk),
          onDelta: (chunk) => tracker.appendAnswer(chunk),
          onToolCall: (data) => tracker.addToolCall({ name: data.name, args: data.args, result: { pending: true } }),
          onToolResult: (data) => tracker.resolveToolCall(data.name, data.result),
          onDone: () => {
            tracker.finish();
            void loadMessages();
          },
          onAborted: () => {
            tracker.finish();
            void loadMessages();
          },
          onError: (errMsg) => {
            notify.error('重新生成失败:' + errMsg);
            tracker.finish(errMsg);
            void loadMessages();
          }
        },
        ctrl.signal,
        overrides
      );
    } catch (e: unknown) {
      // Without this, a rejection before any terminal callback would leave `streaming.active` stuck at true.
      if (isAbortError(e)) {
        tracker.finish();
      } else {
        const detail = describeError(e);
        notify.error('重新生成失败:' + detail);
        tracker.finish(detail);
      }
      void loadMessages();
    } finally {
      setSending(false);
    }
  };

  const handleSwitchBranch = async (userMsgId: string, branchId: string) => {
    if (!agentId) return;
    try {
      await ChatAPI.switchBranch(agentId, userMsgId, branchId);
      await loadMessages();
    } catch (e: unknown) {
      notify.error('切换分支失败:' + describeError(e));
    }
  };

  /** Uploads images and documents separately and appends the results to the pending attachments. */
  const handleAttach = async (files: File[]) => {
    if (!files.length) return;
    setUploadingAtt(true);
    try {
      const images = files.filter((f) => f.type.startsWith('image/'));
      const docs = files.filter((f) => !f.type.startsWith('image/'));
      let imgN = 0;
      let docN = 0;
      if (images.length) {
        const r = await ImageAPI.upload(images);
        setImageUrls((arr) => [...arr, ...r.files.map((f) => f.url)]);
        imgN = r.files.length;
      }
      if (docs.length) {
        const r = await ChatAttachmentsAPI.upload(docs);
        setAttachments((a) => [...a, ...r.files]);
        docN = r.files.length;
      }
      const parts: string[] = [];
      if (imgN) parts.push(`${imgN} 张图片`);
      if (docN) parts.push(`${docN} 个文档`);
      if (parts.length) notify.success(`已附加 ${parts.join(' + ')}`);
    } catch (e: unknown) {
      notify.error('附件上传失败:' + describeError(e));
    } finally {
      setUploadingAtt(false);
    }
  };

  return {
    input,
    setInput,
    sending,
    useStream,
    setUseStream,
    sessionRefresh,
    overrides,
    setOverrides,
    attachments,
    setAttachments,
    imageUrls,
    setImageUrls,
    uploadingAtt,
    streaming,
    handleSend,
    handleStop,
    handleClear,
    handleRegenerate,
    handleSwitchBranch,
    handleAttach,
    modelOptions,
    activeModelValue
  };
}
|