entari-plugin-hyw 4.0.0rc11__py3-none-any.whl → 4.0.0rc13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of entari-plugin-hyw might be problematic.

hyw_core/agent.py ADDED
@@ -0,0 +1,705 @@
+ """
+ Agent Pipeline
+
+ Tool-calling agent that can autonomously use web_tool to search/screenshot.
+ Maximum 2 rounds of tool calls, up to 3 parallel calls per round.
+ """
+
+ import asyncio
+ import json
+ import re
+ import time
+ from dataclasses import dataclass, field
+ from typing import Any, Callable, Awaitable, Dict, List, Optional
+
+ from loguru import logger
+ from openai import AsyncOpenAI
+
+ from .definitions import get_web_tool, get_refuse_answer_tool, AGENT_SYSTEM_PROMPT
+ from .stages.base import StageContext, StageResult
+ from .search import SearchService
+
+
+ @dataclass
+ class AgentSession:
+     """Agent session with tool call tracking."""
+     session_id: str
+     user_query: str
+     tool_calls: List[Dict[str, Any]] = field(default_factory=list)
+     tool_results: List[Dict[str, Any]] = field(default_factory=list)
+     conversation_history: List[Dict] = field(default_factory=list)
+     messages: List[Dict] = field(default_factory=list)  # LLM conversation
+     created_at: float = field(default_factory=time.time)
+
+     # Round tracking (each round can have up to 3 parallel tool calls)
+     round_count: int = 0
+
+     # Image tracking
+     user_image_count: int = 0  # Number of images from user input
+     total_image_count: int = 0  # Total images including web screenshots
+
+     # Time tracking
+     search_time: float = 0.0  # Total time spent on search/screenshot
+     llm_time: float = 0.0  # Total time spent on LLM calls
+     first_llm_time: float = 0.0  # Time for first LLM call (understanding intent)
+
+     # Usage tracking
+     usage_totals: Dict[str, int] = field(default_factory=lambda: {"input_tokens": 0, "output_tokens": 0})
+
+     @property
+     def call_count(self) -> int:
+         """Total number of individual tool calls."""
+         return len(self.tool_calls)
+
+     @property
+     def should_force_summary(self) -> bool:
+         """Force summary after 2 rounds of tool calls."""
+         return self.round_count >= 2
+
+
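A quick illustration of the session bookkeeping above (editorial sketch, not part of the package diff; the import path assumes the wheel exposes hyw_core as a top-level package):

from hyw_core.agent import AgentSession

session = AgentSession(session_id="demo", user_query="hello")
session.tool_calls.append({"name": "web_tool", "args": {"query": "entari plugin"}})
session.round_count = 2                       # two rounds of tool calls completed

assert session.call_count == 1                # one individual tool call recorded
assert session.should_force_summary is True   # 2 rounds reached, so the next call summarizes
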
+ def parse_filter_syntax(query: str, max_count: int = 3):
+     """
+     Parse enhanced filter syntax supporting:
+     - Chinese/English colons (：/:) and commas (，/、/,)
+     - Multiple filters: "mcmod=2, github=1 : xxx"
+     - Index lists: "1, 2, 3 : xxx"
+     - Max total selections
+
+     Returns:
+         filters: list of (filter_type, filter_value, count) tuples
+             filter_type: 'index' or 'link'
+             filter_value: int (for index) or str (for link match term)
+             count: how many to get (default 1)
+         search_query: the actual search query
+         error_msg: error message if the maximum was exceeded
+     """
+     # Skip filter parsing if the query contains a URL (has a :// pattern)
+     if re.search(r'https?://', query):
+         return [], query.strip(), None
+
+     # Normalize full-width colons
+     query = query.replace('：', ':')
+
+     if ':' not in query:
+         return [], query.strip(), None
+
+     parts = query.split(':', 1)
+     if len(parts) != 2:
+         return [], query.strip(), None
+
+     filter_part = parts[0].strip()
+     search_query = parts[1].strip()
+
+     if not filter_part or not search_query:
+         return [], query.strip(), None
+
+     # Parse filter expressions
+     filters = []
+     total_count = 0
+
+     # Normalize full-width commas
+     filter_part = filter_part.replace('，', ',').replace('、', ',')
+     filter_items = [f.strip() for f in filter_part.split(',') if f.strip()]
+
+     for item in filter_items:
+         # Check for "term=count" format (link filter)
+         if '=' in item:
+             term, count_str = item.split('=', 1)
+             term = term.strip().lower()
+             try:
+                 count = int(count_str.strip())
+             except ValueError:
+                 count = 1
+             if term and count > 0:
+                 filters.append(('link', term, count))
+                 total_count += count
+         # Check for pure number (index filter)
+         elif item.isdigit():
+             idx = int(item)
+             if 1 <= idx <= 10:
+                 filters.append(('index', idx, 1))
+                 total_count += 1
+
+     if total_count > max_count:
+         return None, search_query, f"⚠️ 最多选择{max_count}个结果"
+
+     return filters, search_query, None
+
+
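For reference, a minimal sketch of how the filter prefix is expected to parse, based on the function above (editorial illustration, not part of the package diff; the import path assumes hyw_core is installed as a top-level package):

from hyw_core.agent import parse_filter_syntax

# A link filter with a count plus an index filter, then the actual search query
filters, search_query, error = parse_filter_syntax("github=2, 1 : entari plugin")
# filters == [('link', 'github', 2), ('index', 1, 1)]; search_query == "entari plugin"; error is None

# Queries containing a URL skip filter parsing entirely
filters, search_query, error = parse_filter_syntax("summarize https://example.com/page")
# filters == []; the query comes back unchanged (stripped); error is None

# Selecting more than max_count results yields an error message instead of filters
filters, search_query, error = parse_filter_syntax("1, 2, 3, 4 : anything")
# filters is None; error contains the "最多选择3个结果" warning
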
+ class AgentPipeline:
+     """
+     Tool-calling agent pipeline.
+
+     Flow:
+     1. User input → LLM (with tools)
+     2. If tool_call: execute all tools in parallel → notify user with a batched message → loop
+     3. If round_count >= MAX_TOOL_ROUNDS (2): force summary on the next call
+     4. Return final content
+     """
+
+     MAX_TOOL_ROUNDS = 2  # Maximum rounds of tool calls
+     MAX_PARALLEL_TOOLS = 3  # Maximum parallel tool calls per round
+
+     def __init__(
+         self,
+         config: Any,
+         search_service: SearchService,
+         send_func: Optional[Callable[[str], Awaitable[None]]] = None
+     ):
+         self.config = config
+         self.search_service = search_service
+         self.send_func = send_func
+         self.client = AsyncOpenAI(base_url=config.base_url, api_key=config.api_key)
+
+     async def execute(
+         self,
+         user_input: str,
+         conversation_history: List[Dict],
+         images: List[str] = None,
+         model_name: str = None,
+     ) -> Dict[str, Any]:
+         """Execute agent with tool-calling loop."""
+         start_time = time.time()
+
+         # Get model config
+         model_cfg = self.config.get_model_config("main")
+         model = model_name or model_cfg.model_name or self.config.model_name
+
+         client = AsyncOpenAI(
+             base_url=model_cfg.base_url or self.config.base_url,
+             api_key=model_cfg.api_key or self.config.api_key
+         )
+
+         # Create session
+         session = AgentSession(
+             session_id=str(time.time()),
+             user_query=user_input,
+             conversation_history=conversation_history.copy()
+         )
+
+         # Create context for results
+         context = StageContext(
+             user_input=user_input,
+             images=images or [],
+             conversation_history=conversation_history,
+         )
+
+         # Build initial messages
+         language = getattr(self.config, "language", "Simplified Chinese")
+         from datetime import datetime
+         current_time = datetime.now().strftime("%Y-%m-%d %H:%M")
+         system_prompt = AGENT_SYSTEM_PROMPT + f"\n\n用户要求的语言: {language}\n当前时间: {current_time}"
+
+         # Build user content with images if provided
+         user_image_count = len(images) if images else 0
+         session.user_image_count = user_image_count
+         session.total_image_count = user_image_count
+
+         if images:
+             user_content: List[Dict[str, Any]] = [{"type": "text", "text": user_input}]
+             for img_b64 in images:
+                 url = f"data:image/jpeg;base64,{img_b64}" if not img_b64.startswith("data:") else img_b64
+                 user_content.append({"type": "image_url", "image_url": {"url": url}})
+         else:
+             user_content = user_input
+
+         session.messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_content}
+         ]
+
+         # Add image source hint for user images
+         if user_image_count > 0:
+             if user_image_count == 1:
+                 hint = "第1张图片来自用户输入,请将这张图片作为用户输入的参考"
+             else:
+                 hint = f"第1-{user_image_count}张图片来自用户输入,请将这{user_image_count}张图片作为用户输入的参考"
+             session.messages.append({"role": "system", "content": hint})
+
+         # Tool definitions
+         web_tool = get_web_tool()
+         refuse_tool = get_refuse_answer_tool()
+         tools = [web_tool, refuse_tool]
+
+         usage_totals = {"input_tokens": 0, "output_tokens": 0}
+         final_content = ""
+
+         # Send initial status notification
+         if self.send_func:
+             try:
+                 await self.send_func("💭 正在理解用户意图...")
+             except Exception as e:
+                 logger.warning(f"AgentPipeline: Failed to send initial notification: {e}")
+
+         # Agent loop
+         while True:
+             # Check if we need to force summary (no tools)
+             if session.should_force_summary:
+                 logger.info(f"AgentPipeline: Max tool rounds ({self.MAX_TOOL_ROUNDS}) reached, forcing summary")
+                 # Add context message about collected info
+                 if context.web_results:
+                     context_msg = self._format_web_context(context)
+                     session.messages.append({
+                         "role": "system",
+                         "content": f"你已经完成了{session.call_count}次工具调用。请基于已收集的信息给出最终回答。\n\n{context_msg}"
+                     })
+
+                 # Final call without tools
+                 response = await client.chat.completions.create(
+                     model=model,
+                     messages=session.messages,
+                     temperature=self.config.temperature,
+                 )
+
+                 if response.usage:
+                     usage_totals["input_tokens"] += response.usage.prompt_tokens or 0
+                     usage_totals["output_tokens"] += response.usage.completion_tokens or 0
+
+                 final_content = response.choices[0].message.content or ""
+                 break
+
+             # Normal call with tools
+             llm_start = time.time()
+             try:
+                 response = await client.chat.completions.create(
+                     model=model,
+                     messages=session.messages,
+                     temperature=self.config.temperature,
+                     tools=tools,
+                     tool_choice="auto",
+                 )
+             except Exception as e:
+                 logger.error(f"AgentPipeline: LLM error: {e}")
+                 return {
+                     "llm_response": f"Error: {e}",
+                     "success": False,
+                     "error": str(e),
+                     "stats": {"total_time": time.time() - start_time}
+                 }
+
+             llm_duration = time.time() - llm_start
+             session.llm_time += llm_duration
+
+             # Track first LLM call time (understanding user intent)
+             if session.call_count == 0 and session.first_llm_time == 0:
+                 session.first_llm_time = llm_duration
+
+             if response.usage:
+                 usage_totals["input_tokens"] += response.usage.prompt_tokens or 0
+                 usage_totals["output_tokens"] += response.usage.completion_tokens or 0
+
+             message = response.choices[0].message
+
+             # Check for tool calls
+             if not message.tool_calls:
+                 # Model chose to answer directly
+                 final_content = message.content or ""
+                 logger.info(f"AgentPipeline: Model answered directly after {session.call_count} tool calls")
+                 break
+
+             # Add assistant message with tool calls
+             session.messages.append({
+                 "role": "assistant",
+                 "content": message.content,
+                 "tool_calls": [
+                     {
+                         "id": tc.id,
+                         "type": "function",
+                         "function": {"name": tc.function.name, "arguments": tc.function.arguments}
+                     }
+                     for tc in message.tool_calls
+                 ]
+             })
+
+             # Execute all tool calls in parallel
+             tool_tasks = []
+             tool_call_ids = []
+             tool_call_names = []
+             tool_call_args_list = []
+
+             for tool_call in message.tool_calls:
+                 tc_id = tool_call.id
+                 func_name = tool_call.function.name
+
+                 try:
+                     args = json.loads(tool_call.function.arguments)
+                 except json.JSONDecodeError:
+                     args = {}
+
+                 tool_call_ids.append(tc_id)
+                 tool_call_names.append(func_name)
+                 tool_call_args_list.append(args)
+                 logger.info(f"AgentPipeline: Queueing tool '{func_name}' with args: {args}")
+
+             # Check for refuse_answer first (handle immediately)
+             for idx, func_name in enumerate(tool_call_names):
+                 if func_name == "refuse_answer":
+                     args = tool_call_args_list[idx]
+                     reason = args.get("reason", "Refused")
+                     context.should_refuse = True
+                     context.refuse_reason = reason
+
+                     session.messages.append({
+                         "role": "tool",
+                         "tool_call_id": tool_call_ids[idx],
+                         "content": f"已拒绝回答: {reason}"
+                     })
+
+                     return {
+                         "llm_response": "",
+                         "success": True,
+                         "refuse_answer": True,
+                         "refuse_reason": reason,
+                         "stats": {"total_time": time.time() - start_time},
+                         "usage": usage_totals,
+                     }
+
+             # Execute web_tool calls in parallel
+             search_start = time.time()
+             tasks_to_run = []
+             task_indices = []
+
+             for idx, func_name in enumerate(tool_call_names):
+                 if func_name == "web_tool":
+                     tasks_to_run.append(self._execute_web_tool(tool_call_args_list[idx], context))
+                     task_indices.append(idx)
+
+             # Run all web_tool calls in parallel
+             if tasks_to_run:
+                 results = await asyncio.gather(*tasks_to_run, return_exceptions=True)
+             else:
+                 results = []
+
+             session.search_time += time.time() - search_start
+
+             # Process results and collect notifications
+             notifications = []
+             result_map = {}  # Map task index to result
+
+             for i, result in enumerate(results):
+                 task_idx = task_indices[i]
+                 if isinstance(result, Exception):
+                     result_map[task_idx] = {"summary": f"执行失败: {result}", "results": []}
+                 else:
+                     result_map[task_idx] = result
+
+             # Add all tool results to messages and collect notifications
+             for idx, func_name in enumerate(tool_call_names):
+                 tc_id = tool_call_ids[idx]
+                 args = tool_call_args_list[idx]
+
+                 if func_name == "web_tool":
+                     result = result_map.get(idx, {"summary": "未执行", "results": []})
+
+                     # Track tool call
+                     session.tool_calls.append({"name": func_name, "args": args})
+                     session.tool_results.append(result)
+
+                     # Collect notification
+                     notifications.append(f"🔍 {result['summary']}")
+
+                     # Add tool result to messages
+                     result_content = f"搜索完成: {result['summary']}\n\n找到 {len(result.get('results', []))} 个结果"
+                     session.messages.append({
+                         "role": "tool",
+                         "tool_call_id": tc_id,
+                         "content": result_content
+                     })
+
+                     # Add image source hint for web screenshots
+                     screenshot_count = result.get("screenshot_count", 0)
+                     if screenshot_count > 0:
+                         start_idx_img = session.total_image_count + 1
+                         end_idx_img = session.total_image_count + screenshot_count
+                         session.total_image_count = end_idx_img
+
+                         source_desc = result.get("source_desc", "网页截图")
+                         if start_idx_img == end_idx_img:
+                             hint = f"第{start_idx_img}张图片来自{source_desc},作为查询的参考资料"
+                         else:
+                             hint = f"第{start_idx_img}-{end_idx_img}张图片来自{source_desc},作为查询的参考资料"
+                         session.messages.append({"role": "system", "content": hint})
+                 else:
+                     # Unknown tool
+                     session.messages.append({
+                         "role": "tool",
+                         "tool_call_id": tc_id,
+                         "content": f"Unknown tool: {func_name}"
+                     })
+
+             # Send batched notification (up to 3 lines)
+             if self.send_func and notifications:
+                 try:
+                     # Join notifications with newlines, max 3 lines
+                     notification_msg = "\n".join(notifications[:3])
+                     await self.send_func(notification_msg)
+                 except Exception as e:
+                     logger.warning(f"AgentPipeline: Failed to send notification: {e}")
+
+             # Increment round count after processing all tool calls in this round
+             if tasks_to_run:
+                 session.round_count += 1
+
+         # Build final response
+         total_time = time.time() - start_time
+         stats = {"total_time": total_time}
+
+         # Update conversation history
+         conversation_history.append({"role": "user", "content": user_input})
+         conversation_history.append({"role": "assistant", "content": final_content})
+
+         stages_used = self._build_stages_ui(session, context, usage_totals, total_time)
+         logger.info(f"AgentPipeline: Built stages_used = {stages_used}")
+
+         return {
+             "llm_response": final_content,
+             "success": True,
+             "stats": stats,
+             "model_used": model,
+             "conversation_history": conversation_history,
+             "usage": usage_totals,
+             "web_results": context.web_results,
+             "tool_calls_count": session.call_count,
+             "stages_used": stages_used,
+         }
+
+     async def _execute_web_tool(self, args: Dict, context: StageContext) -> Dict[str, Any]:
+         """Execute web_tool - reuses the /w logic and supports filter syntax."""
+         query = args.get("query", "")
+
+         # 1. URL screenshot mode - check whether the query contains a URL
+         url_match = re.search(r'https?://\S+', query)
+         if url_match:
+             url = url_match.group(0)
+             # Send URL screenshot notification
+             if self.send_func:
+                 try:
+                     short_url = url[:40] + "..." if len(url) > 40 else url
+                     await self.send_func(f"📸 正在截图: {short_url}")
+                 except Exception:
+                     pass
+
+             logger.info(f"AgentPipeline: Screenshot URL with content: {url}")
+             # Use screenshot_with_content to get both screenshot and text
+             result = await self.search_service.screenshot_with_content(url)
+             screenshot_b64 = result.get("screenshot_b64")
+             content = result.get("content", "")
+             title = result.get("title", "")
+
+             if screenshot_b64:
+                 context.web_results.append({
+                     "_id": context.next_id(),
+                     "_type": "page",
+                     "url": url,
+                     "title": title or "Screenshot",
+                     "screenshot_b64": screenshot_b64,
+                     "content": content,  # Text content for LLM
+                 })
+                 return {
+                     "summary": f"已截图: {url[:50]}{'...' if len(url) > 50 else ''}",
+                     "results": [{"_type": "screenshot", "url": url}],
+                     "screenshot_count": 1,
+                     "source_desc": f"URL截图 ({url[:30]}...)"
+                 }
+             return {
+                 "summary": f"截图失败: {url[:50]}",
+                 "results": [],
+                 "screenshot_count": 0
+             }
+
513
+ # 2. 解析过滤器语法
514
+ filters, search_query, error = parse_filter_syntax(query, max_count=3)
515
+
516
+ if error:
517
+ return {"summary": error, "results": []}
518
+
519
+ # 3. 如果有过滤器,发送搜索+截图预告
520
+ if filters and self.send_func:
521
+ try:
522
+ # Build filter description
523
+ filter_desc_parts = []
524
+ for f_type, f_val, f_count in filters:
525
+ if f_type == 'index':
526
+ filter_desc_parts.append(f"第{f_val}个")
527
+ else:
528
+ filter_desc_parts.append(f"{f_val}={f_count}")
529
+ filter_desc = ", ".join(filter_desc_parts)
530
+ await self.send_func(f"🔍 正在搜索 \"{search_query}\" 并匹配 [{filter_desc}]...")
531
+ except Exception:
532
+ pass
533
+
534
+ logger.info(f"AgentPipeline: Searching for: {search_query}")
535
+ results = await self.search_service.search(search_query)
536
+ visible = [r for r in results if not r.get("_hidden")]
537
+
538
+ # Add search results to context
539
+ for r in results:
540
+ r["_id"] = context.next_id()
541
+ if "_type" not in r:
542
+ r["_type"] = "search"
543
+ r["query"] = search_query
544
+ context.web_results.append(r)
545
+
546
+ # 4. 如果有过滤器,截图匹配的链接
547
+ if filters:
548
+ urls = self._collect_filter_urls(filters, visible)
549
+ if urls:
550
+ logger.info(f"AgentPipeline: Taking screenshots with content of {len(urls)} URLs")
551
+ # Use screenshot_with_content to get both screenshot and text
552
+ screenshot_tasks = [self.search_service.screenshot_with_content(u) for u in urls]
553
+ results = await asyncio.gather(*screenshot_tasks)
554
+
555
+ # Add screenshots and content to context
556
+ successful_count = 0
557
+ for url, result in zip(urls, results):
558
+ screenshot_b64 = result.get("screenshot_b64") if isinstance(result, dict) else None
559
+ content = result.get("content", "") if isinstance(result, dict) else ""
560
+ title = result.get("title", "") if isinstance(result, dict) else ""
561
+
562
+ if screenshot_b64:
563
+ successful_count += 1
564
+ # Find and update the matching result
565
+ for r in context.web_results:
566
+ if r.get("url") == url:
567
+ r["screenshot_b64"] = screenshot_b64
568
+ r["content"] = content # Text content for LLM
569
+ r["title"] = title or r.get("title", "")
570
+ r["_type"] = "page"
571
+ break
572
+
573
+ return {
574
+ "summary": f"搜索 \"{search_query}\" 并截图 {successful_count} 个匹配结果",
575
+ "results": [{"url": u, "_type": "page"} for u in urls],
576
+ "screenshot_count": successful_count,
577
+ "source_desc": f"搜索 \"{search_query}\" 的网页截图"
578
+ }
579
+
580
+ # 5. 普通搜索模式 (无截图)
581
+ return {
582
+ "summary": f"搜索 \"{search_query}\" 找到 {len(visible)} 条结果",
583
+ "results": visible,
584
+ "screenshot_count": 0
585
+ }
586
+
+     def _collect_filter_urls(self, filters: List, visible: List[Dict]) -> List[str]:
+         """Collect URLs based on filter specifications."""
+         urls = []
+
+         for filter_type, filter_value, count in filters:
+             if filter_type == 'index':
+                 idx = filter_value - 1  # Convert to 0-based
+                 if 0 <= idx < len(visible):
+                     url = visible[idx].get("url", "")
+                     if url and url not in urls:
+                         urls.append(url)
+             else:
+                 # Link filter
+                 found_count = 0
+                 for res in visible:
+                     url = res.get("url", "")
+                     title = res.get("title", "")
+                     # Match filter against both URL and title
+                     if (filter_value in url.lower() or filter_value in title.lower()) and url not in urls:
+                         urls.append(url)
+                         found_count += 1
+                         if found_count >= count:
+                             break
+
+         return urls
+
613
+ def _format_web_context(self, context: StageContext) -> str:
614
+ """Format web results for summary context."""
615
+ if not context.web_results:
616
+ return ""
617
+
618
+ lines = ["## 已收集的信息\n"]
619
+ for r in context.web_results:
620
+ idx = r.get("_id", "?")
621
+ title = r.get("title", "Untitled")
622
+ url = r.get("url", "")
623
+ content = r.get("content", "")[:500] if r.get("content") else ""
624
+ has_screenshot = "有截图" if r.get("screenshot_b64") else ""
625
+
626
+ lines.append(f"[{idx}] {title}")
627
+ if url:
628
+ lines.append(f" URL: {url}")
629
+ if has_screenshot:
630
+ lines.append(f" {has_screenshot}")
631
+ if content:
632
+ lines.append(f" 摘要: {content[:200]}...")
633
+ lines.append("")
634
+
635
+ return "\n".join(lines)
636
+
637
+ def _build_stages_ui(self, session: AgentSession, context: StageContext, usage_totals: Dict, total_time: float) -> List[Dict[str, Any]]:
638
+ """Build stages UI for rendering - compatible with App.vue flow section.
639
+
640
+ Flow: Instruct (意图) → Search (搜索) → Summary (总结)
641
+ """
642
+ stages = []
643
+
644
+ # Get model config for pricing
645
+ model_cfg = self.config.get_model_config("main")
646
+ model_name = model_cfg.model_name or self.config.model_name
647
+ input_price = getattr(model_cfg, "input_price", 0) or 0
648
+ output_price = getattr(model_cfg, "output_price", 0) or 0
649
+
650
+ # 1. Instruct Stage (理解用户意图 - 第一次LLM调用)
651
+ if session.first_llm_time > 0:
652
+ # Estimate tokens for first call (rough split based on proportion)
653
+ # Since we track total usage, we approximate first call as ~40% of total
654
+ first_call_ratio = 0.4 if session.call_count > 0 else 1.0
655
+ instruct_input = int(usage_totals.get("input_tokens", 0) * first_call_ratio)
656
+ instruct_output = int(usage_totals.get("output_tokens", 0) * first_call_ratio)
657
+ instruct_cost = (instruct_input * input_price + instruct_output * output_price) / 1_000_000
658
+
659
+ stages.append({
660
+ "name": "Instruct",
661
+ "model": model_name,
662
+ "provider": model_cfg.model_provider or "OpenRouter",
663
+ "description": "理解用户意图",
664
+ "time": session.first_llm_time,
665
+ "usage": {"input_tokens": instruct_input, "output_tokens": instruct_output},
666
+ "cost": instruct_cost,
667
+ })
668
+
669
+ # 2. Search Stage (搜索)
670
+ if session.tool_calls:
671
+ # Collect all search descriptions
672
+ search_descriptions = []
673
+ for tc, result in zip(session.tool_calls, session.tool_results):
674
+ desc = result.get("summary", "")
675
+ if desc:
676
+ search_descriptions.append(desc)
677
+
678
+ stages.append({
679
+ "name": "Search",
680
+ "model": "",
681
+ "provider": "Web",
682
+ "description": " → ".join(search_descriptions) if search_descriptions else "Web Search",
683
+ "time": session.search_time,
684
+ })
685
+
686
+ # 3. Summary Stage (总结)
687
+ # Calculate remaining tokens after instruct
688
+ summary_ratio = 0.6 if session.call_count > 0 else 0.0
689
+ summary_input = int(usage_totals.get("input_tokens", 0) * summary_ratio)
690
+ summary_output = int(usage_totals.get("output_tokens", 0) * summary_ratio)
691
+ summary_cost = (summary_input * input_price + summary_output * output_price) / 1_000_000
692
+ summary_time = session.llm_time - session.first_llm_time
693
+
694
+ if summary_time > 0 or session.call_count > 0:
695
+ stages.append({
696
+ "name": "Summary",
697
+ "model": model_name,
698
+ "provider": model_cfg.model_provider or "OpenRouter",
699
+ "description": f"生成回答 ({session.call_count} 次工具调用)",
700
+ "time": max(0, summary_time),
701
+ "usage": {"input_tokens": summary_input, "output_tokens": summary_output},
702
+ "cost": summary_cost,
703
+ })
704
+
705
+ return stages
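
To close, a hypothetical sketch of how a caller might wire this pipeline together (editorial illustration, not part of the package diff; cfg, search, and notify are stand-ins, and building the real config object and SearchService happens elsewhere in the plugin):

import asyncio

from hyw_core.agent import AgentPipeline


async def demo(cfg, search) -> None:
    # 'cfg' is assumed to expose what the pipeline reads (base_url, api_key,
    # temperature, model_name, language, get_model_config); 'search' is an
    # already-constructed SearchService instance.
    async def notify(text: str) -> None:
        print(text)  # stand-in for pushing status messages to the chat

    pipeline = AgentPipeline(config=cfg, search_service=search, send_func=notify)
    result = await pipeline.execute("What's new in Minecraft 1.21?", conversation_history=[])
    print(result["llm_response"])
    print(result["stats"]["total_time"], result.get("tool_calls_count"))

# asyncio.run(demo(cfg, search))  # once cfg and search are available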