aura-web/src/pages/chat/hooks/useChatSender.ts

447 lines
16 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import { useMemo, useState } from 'react';
import type { Agent, BranchInfo, ChatAttachment, ChatMessage, ModelOverrides, RetrievedSnippet, ToolCallTrace } from '../../../api';
import { ChatAPI, ChatAttachmentsAPI, ImageAPI, regenerateMessage, streamChat } from '../../../api';
import { buildAttachmentsText } from '../utils/attachments';
import { parseAgentModels } from '../utils/agentModels';
/**
 * Live view of the reply that is currently streaming (send or regenerate).
 * The hook resets it to an inactive, empty state when a stream completes,
 * is aborted, or errors.
 */
export interface StreamingState {
  /** True while a send/regenerate stream is in flight. */
  active: boolean;
  /** Agent the current send was routed to (may be absent, e.g. for regenerate). */
  targetAgentId?: string;

  /** Reasoning chunks received so far, concatenated in arrival order. */
  reasoningText: string;
  /** Answer chunks received so far, concatenated in arrival order. */
  answerText: string;

  /** Snippets reported by the stream's meta event. */
  retrieved: RetrievedSnippet[];
  /** Tool invocations seen so far; a result of `{ pending: true }` means it has not returned yet. */
  toolCalls: ToolCallTrace[];

  /** Payload of the most recent retry event, or null when none occurred. */
  retryInfo: any | null;
  /** Error text of the last failed request, or null. */
  errorMessage: string | null;
}
/** Fuzzy (non-exact) matches must score strictly above this to be accepted. */
const MIN_FUZZY_MENTION_SCORE = 0.3;
/** Weight applied to a substring match relative to a prefix match. */
const CONTAINS_MATCH_WEIGHT = 0.8;
/**
 * Regex source capturing a mentioned name. Besides whitespace, the name stops at
 * ASCII , ! ? ; : and the full-width/CJK forms of comma, full stop, exclamation
 * and question marks, semicolon, colon and the enumeration comma. As a result,
 * "@Bob, hi" yields "Bob" and "@助手\uFF0C你好" yields "助手". ASCII "." is NOT a
 * terminator, so names like "gpt-4.1" survive; a trailing "." is stripped later.
 */
const MENTION_NAME_PATTERN = '([^\\s,!?;:\\uFF0C\\u3002\\uFF01\\uFF1F\\uFF1B\\uFF1A\\u3001]+)';

/**
 * Extracts the agent name referenced by an `@mention` in `text`.
 *
 * A mention at the very start of the text wins (the common "@Agent do X" form).
 * Otherwise the last mention anywhere in the text is used.
 *
 * @returns the mentioned name without the "@", or null when there is none.
 */
function extractMentionName(text: string): string | null {
  const leading = new RegExp('^@' + MENTION_NAME_PATTERN).exec(text);
  let raw: string | undefined = leading?.[1];
  if (raw === undefined) {
    // The "@" must open the text or follow a character that is neither an ASCII
    // word character nor "."; this skips e-mail addresses such as "alice@example.com".
    const anywhere = new RegExp('(?:^|[^\\w.])@' + MENTION_NAME_PATTERN, 'g');
    for (let m = anywhere.exec(text); m !== null; m = anywhere.exec(text)) {
      raw = m[1];
    }
  }
  // A trailing "." is sentence punctuation ("ask @Bob."), not part of the name.
  const name = raw?.replace(/\.+$/, '') ?? '';
  return name === '' ? null : name;
}

/**
 * Resolves the agent addressed by an `@AgentName` mention in `text`.
 *
 * Matching is case-insensitive:
 * - An exact name match is returned immediately.
 * - Otherwise a prefix match scores `candidate.length / name.length`.
 * - A substring match scores the same value times 0.8.
 * - The best score strictly above 0.3 wins.
 *
 * @param text - raw user input.
 * @param agentList - agents that can be mentioned.
 * @returns the matched agent, or null when nothing is mentioned or no agent matches well enough.
 */
export function parseMentionedAgent(text: string, agentList: Agent[]): Agent | null {
  const candidate = extractMentionName(text);
  if (candidate === null) return null;

  const wanted = candidate.toLowerCase();
  let best: Agent | null = null;
  let bestScore = 0;
  for (const agent of agentList) {
    const name = agent.name.toLowerCase();
    if (name === wanted) return agent; // exact match wins outright

    let score = 0;
    if (name.startsWith(wanted)) {
      score = wanted.length / name.length;
    } else if (name.includes(wanted)) {
      score = (wanted.length / name.length) * CONTAINS_MATCH_WEIGHT;
    }
    if (score > bestScore) {
      bestScore = score;
      best = agent;
    }
  }
  return bestScore > MIN_FUZZY_MENTION_SCORE ? best : null;
}
/**
 * Fresh, inactive streaming state. `errorMessage` carries the failure text when
 * the request that just ended failed, otherwise null.
 */
function idleStreaming(errorMessage: string | null = null): StreamingState {
  return { active: false, reasoningText: '', answerText: '', errorMessage, retryInfo: null, retrieved: [], toolCalls: [] };
}

/**
 * Returns a copy of `calls` in which the most recent call named `name` whose
 * result is still the `{ pending: true }` placeholder receives `result`.
 * When no such call exists, an unchanged copy is returned.
 */
function resolveToolResult(calls: ToolCallTrace[], name: ToolCallTrace['name'], result: ToolCallTrace['result']): ToolCallTrace[] {
  const next = [...calls];
  for (let i = next.length - 1; i >= 0; i--) {
    const call = next[i];
    if (call && call.name === name && (call.result as any)?.pending) {
      next[i] = { ...call, result };
      break;
    }
  }
  return next;
}

/** Inline assistant bubble appended to the conversation when a send fails. */
function errorBubble(message: string): ChatMessage {
  return {
    id: 'error-' + Date.now(),
    role: 'assistant',
    content: `❌ 请求失败:${message}`,
    createdAt: Date.now(),
    meta: { error: message }
  };
}

/** True when `e` looks like the rejection raised after an AbortController fires. */
function isAbortError(e: unknown): boolean {
  return (e as { name?: unknown } | null | undefined)?.name === 'AbortError';
}

/** `e.message` when present, otherwise `String(e)`. */
function errorText(e: unknown): string {
  return (e as { message?: string } | null | undefined)?.message ?? String(e);
}

/**
 * Composer state and actions for one agent conversation. It covers input text,
 * attachments and images, streaming and non-streaming send, stop, clear,
 * regenerate, branch switching and attachment upload.
 *
 * An `@AgentName` mention (resolved by `parseMentionedAgent`) can route a
 * message to a different agent. That agent's first configured model is then used.
 * Failures are reported through `args.notify` (and, for sends, an inline error
 * bubble) rather than rethrown.
 */
export function useChatSender(args: {
  agentId?: string;
  agent: Agent | null;
  agentList: Agent[];
  sessionId: string | null;
  overrides: ModelOverrides;
  setOverrides: (updater: (prev: ModelOverrides) => ModelOverrides) => void;
  messages: ChatMessage[];
  setMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void;
  setBranches: (v: Record<string, BranchInfo>) => void;
  loadMessages: () => Promise<void>;
  scrollBottom: (force?: boolean) => void;
  notify: { success: (t: string) => void; error: (t: string) => void };
  abortRef: { current: AbortController | null };
}) {
  const { agentId, agent, agentList, sessionId, overrides, setOverrides, setMessages, setBranches, loadMessages, scrollBottom, notify, abortRef } = args;
  const [input, setInput] = useState('');
  const [sending, setSending] = useState(false);
  const [useStream, setUseStream] = useState(true);
  const [attachments, setAttachments] = useState<ChatAttachment[]>([]);
  const [imageUrls, setImageUrls] = useState<string[]>([]);
  const [uploadingAtt, setUploadingAtt] = useState(false);
  const [streaming, setStreaming] = useState<StreamingState>(() => idleStreaming());
  const [sessionRefresh, setSessionRefresh] = useState(0);
  const agentModels = useMemo(() => parseAgentModels(agent?.model), [agent?.model]);
  const modelOptions = useMemo(() => agentModels.map((model) => ({ value: model.id, label: model.name })), [agentModels]);
  const activeModelValue = overrides.model_id || '';

  /**
   * Decides which agent and model a message goes to. Mentioning a *different*
   * agent selects that agent's first model. Otherwise the current agent is used
   * with the user's override, falling back to the agent's first model.
   */
  const resolveTarget = (text: string, currentAgentId: string) => {
    const mentioned = parseMentionedAgent(text, agentList);
    if (mentioned && mentioned.id !== currentAgentId) {
      const [first] = parseAgentModels(mentioned.model);
      return { agentId: mentioned.id || currentAgentId, model: first?.name || '', modelId: first?.id || '' };
    }
    return {
      agentId: currentAgentId,
      model: overrides.model || agentModels[0]?.name || '',
      modelId: overrides.model_id || agentModels[0]?.id || ''
    };
  };

  /** Message text sent to the backend: the user's text plus any attachment text. */
  const composeContent = (text: string) => {
    const attText = buildAttachmentsText(attachments);
    return attText ? `${text}\n\n${attText}` : text;
  };

  /** Drops pending attachments/images once a message has been delivered. */
  const clearComposerExtras = () => {
    setAttachments([]);
    setImageUrls([]);
  };

  const handleSendStream = async (text: string) => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    const target = resolveTarget(text, agentId);
    // Optimistic user bubble; replaced by the server copy on completion.
    const tempUser: ChatMessage = { id: 'tmp-' + Date.now(), role: 'user', content: text, createdAt: Date.now() };
    setMessages((m) => [...(m || []), tempUser]);
    setStreaming({ ...idleStreaming(), active: true, targetAgentId: target.agentId });
    scrollBottom(true);
    abortRef.current?.abort();
    const ctrl = new AbortController();
    abortRef.current = ctrl;
    // Local mirror of what this request has streamed so far. onDone reads it to
    // build the final assistant message, so no side effects run inside a state
    // updater (React requires updaters to be pure and may call them twice).
    const trace: { reasoning: string; retrieved: RetrievedSnippet[]; toolCalls: ToolCallTrace[] } = {
      reasoning: '',
      retrieved: [],
      toolCalls: []
    };
    try {
      await streamChat(
        target.agentId,
        composeContent(text),
        {
          onMeta: (m) => {
            const retrieved: RetrievedSnippet[] = m.retrieved || [];
            trace.retrieved = retrieved;
            setStreaming((s) => ({ ...s, retrieved }));
          },
          onRetry: (data) => {
            setStreaming((s) => ({ ...s, retryInfo: data }));
            if (data?.stage === 'fallback_model' && data?.toModel) {
              setOverrides((o) => ({ ...o, model: String(data.toModel) }));
            }
          },
          onReasoningDelta: (chunk) => {
            trace.reasoning += chunk;
            setStreaming((s) => ({ ...s, reasoningText: s.reasoningText + chunk }));
            scrollBottom();
          },
          onDelta: (chunk) => {
            setStreaming((s) => ({ ...s, answerText: s.answerText + chunk }));
            scrollBottom();
          },
          onToolCall: (data) => {
            const toolCalls: ToolCallTrace[] = [...trace.toolCalls, { name: data.name, args: data.args, result: { pending: true } }];
            trace.toolCalls = toolCalls;
            setStreaming((s) => ({ ...s, toolCalls }));
          },
          onToolResult: (data) => {
            const toolCalls = resolveToolResult(trace.toolCalls, data.name, data.result);
            trace.toolCalls = toolCalls;
            setStreaming((s) => ({ ...s, toolCalls }));
          },
          onDone: (data) => {
            const assistant: any = data.assistant;
            // Prefer reasoning persisted by the server; fall back to what was streamed.
            const reasoningText = assistant?.reasoning || assistant?.meta?.reasoning || assistant?.meta?.reasoningText || trace.reasoning;
            const nextAssistant: ChatMessage = {
              ...data.assistant,
              meta: {
                ...(data.assistant.meta || {}),
                retrieved: data.assistant.meta?.retrieved || trace.retrieved,
                toolCalls: data.assistant.meta?.toolCalls || trace.toolCalls,
                reasoning: reasoningText || undefined
              }
            };
            setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), data.user, nextAssistant]);
            setStreaming(idleStreaming());
            setSessionRefresh((t) => t + 1);
            clearComposerExtras();
            scrollBottom();
          },
          onAborted: (data) => {
            const assistant: ChatMessage = { ...data.assistant, meta: { ...(data.assistant.meta || {}), aborted: true } };
            setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), { ...tempUser, id: 'u-' + data.assistant.id }, assistant]);
            setStreaming(idleStreaming());
            setSessionRefresh((t) => t + 1);
            void loadMessages();
          },
          onError: (errMsg) => {
            notify.error('流式失败:' + errMsg);
            setMessages((m) => [...(m || []), errorBubble(errMsg)]);
            setStreaming(idleStreaming(errMsg));
          }
        },
        ctrl.signal,
        sessionId,
        target.model,
        target.modelId,
        imageUrls
      );
    } catch (e: unknown) {
      if (isAbortError(e)) {
        // The user stopped the stream: end it quietly instead of surfacing the abort as an error.
        setStreaming(idleStreaming());
        return;
      }
      const message = errorText(e);
      notify.error('请求失败:' + message);
      setMessages((m) => [...(m || []), errorBubble(message)]);
      setStreaming(idleStreaming(message));
    }
  };

  const handleSendNonStream = async (text: string) => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    const tempUser: ChatMessage = { id: 'tmp-' + Date.now(), role: 'user', content: text, createdAt: Date.now() };
    setMessages((m) => [...(m || []), tempUser]);
    scrollBottom(true);
    const target = resolveTarget(text, agentId);
    try {
      const res = await ChatAPI.send(target.agentId, composeContent(text), sessionId, target.model, imageUrls);
      setMessages((m) => [...(m || []).filter((x) => x.id !== tempUser.id), res.user, res.assistant]);
      setSessionRefresh((t) => t + 1);
      clearComposerExtras();
      scrollBottom();
    } catch (e: unknown) {
      notify.error('发送失败:' + errorText(e));
      // Roll back the optimistic bubble.
      setMessages((m) => (m || []).filter((x) => x.id !== tempUser.id));
    }
  };

  const handleSend = async () => {
    const text = input.trim();
    if (!text || !agentId || sending) return;
    setInput('');
    setSending(true);
    try {
      if (useStream) await handleSendStream(text);
      else await handleSendNonStream(text);
    } finally {
      setSending(false);
    }
  };

  const handleStop = () => {
    abortRef.current?.abort();
    setSending(false);
  };

  const handleClear = async () => {
    if (!agentId) return;
    if (!sessionId) {
      notify.error('会话未初始化,请稍后重试');
      return;
    }
    try {
      await ChatAPI.clear(agentId, sessionId);
    } catch (e: unknown) {
      notify.error('清空失败:' + errorText(e));
      return;
    }
    setMessages(() => []);
    setBranches({});
    notify.success('对话已清空');
  };

  const handleRegenerate = async (assistantId: string) => {
    if (!agentId || sending) return;
    setSending(true);
    setStreaming({ ...idleStreaming(), active: true });
    abortRef.current?.abort();
    const ctrl = new AbortController();
    abortRef.current = ctrl;
    try {
      await regenerateMessage(
        agentId,
        assistantId,
        {
          onMeta: (m) => setStreaming((s) => ({ ...s, retrieved: m.retrieved || [] })),
          onRetry: (data) => {
            setStreaming((s) => ({ ...s, retryInfo: data }));
            if (data?.stage === 'fallback_model' && data?.toModel) {
              setOverrides((o) => ({ ...o, model: String(data.toModel) }));
            }
          },
          onReasoningDelta: (chunk) => {
            setStreaming((s) => ({ ...s, reasoningText: s.reasoningText + chunk }));
            scrollBottom();
          },
          onDelta: (chunk) => {
            setStreaming((s) => ({ ...s, answerText: s.answerText + chunk }));
            scrollBottom();
          },
          onToolCall: (data) => setStreaming((s) => ({ ...s, toolCalls: [...s.toolCalls, { name: data.name, args: data.args, result: { pending: true } }] })),
          onToolResult: (data) => setStreaming((s) => ({ ...s, toolCalls: resolveToolResult(s.toolCalls, data.name, data.result) })),
          onDone: () => {
            setStreaming(idleStreaming());
            void loadMessages();
          },
          onAborted: () => {
            setStreaming(idleStreaming());
            void loadMessages();
          },
          onError: (errMsg) => {
            notify.error('重新生成失败:' + errMsg);
            setStreaming(idleStreaming(errMsg));
            void loadMessages();
          }
        },
        ctrl.signal,
        overrides
      );
    } catch (e: unknown) {
      // Without this, a thrown error would leave `streaming.active` stuck at true.
      if (isAbortError(e)) {
        setStreaming(idleStreaming());
      } else {
        const message = errorText(e);
        notify.error('重新生成失败:' + message);
        setStreaming(idleStreaming(message));
      }
      void loadMessages();
    } finally {
      setSending(false);
    }
  };

  const handleSwitchBranch = async (userMsgId: string, branchId: string) => {
    if (!agentId) return;
    try {
      await ChatAPI.switchBranch(agentId, userMsgId, branchId);
      await loadMessages();
    } catch (e: unknown) {
      notify.error('切换分支失败:' + errorText(e));
    }
  };

  const handleAttach = async (files: File[]) => {
    if (!files.length) return;
    setUploadingAtt(true);
    try {
      // Images are uploaded as vision inputs; everything else becomes a text attachment.
      const images = files.filter((f) => f.type.startsWith('image/'));
      const docs = files.filter((f) => !f.type.startsWith('image/'));
      let imgN = 0;
      let docN = 0;
      if (images.length) {
        const r = await ImageAPI.upload(images);
        setImageUrls((arr) => [...arr, ...r.files.map((f) => f.url)]);
        imgN = r.files.length;
      }
      if (docs.length) {
        const r = await ChatAttachmentsAPI.upload(docs);
        setAttachments((a) => [...a, ...r.files]);
        docN = r.files.length;
      }
      const parts: string[] = [];
      if (imgN) parts.push(`${imgN} 张图片`);
      if (docN) parts.push(`${docN} 个文档`);
      if (parts.length) notify.success(`已附加 ${parts.join(' + ')}`);
    } catch (e: unknown) {
      notify.error('附件上传失败:' + errorText(e));
    } finally {
      setUploadingAtt(false);
    }
  };

  return {
    input,
    setInput,
    sending,
    useStream,
    setUseStream,
    sessionRefresh,
    overrides,
    setOverrides,
    attachments,
    setAttachments,
    imageUrls,
    setImageUrls,
    uploadingAtt,
    streaming,
    handleSend,
    handleStop,
    handleClear,
    handleRegenerate,
    handleSwitchBranch,
    handleAttach,
    modelOptions,
    activeModelValue
  };
}