aria-code 4.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. agents/__init__.py +32 -0
  2. agents/base.py +190 -0
  3. agents/deep/__init__.py +37 -0
  4. agents/deep/calibration_loop.py +144 -0
  5. agents/deep/critic.py +125 -0
  6. agents/deep/deepen.py +193 -0
  7. agents/deep/models.py +149 -0
  8. agents/deep/pipeline.py +164 -0
  9. agents/deep/quant_fusion.py +192 -0
  10. agents/deep/themes.py +95 -0
  11. agents/deep/tiers.py +106 -0
  12. agents/financial/__init__.py +10 -0
  13. agents/financial/catalyst.py +279 -0
  14. agents/financial/debate.py +145 -0
  15. agents/financial/earnings.py +303 -0
  16. agents/financial/fundamental.py +159 -0
  17. agents/financial/macro.py +99 -0
  18. agents/financial/news.py +207 -0
  19. agents/financial/risk.py +132 -0
  20. agents/financial/sector.py +279 -0
  21. agents/financial/synthesis.py +274 -0
  22. agents/financial/technical.py +258 -0
  23. agents/portfolio_agent.py +333 -0
  24. agents/realty/__init__.py +62 -0
  25. agents/realty/asset_diagnosis.py +150 -0
  26. agents/realty/business_match.py +165 -0
  27. agents/realty/cashflow_verify.py +208 -0
  28. agents/realty/contract_rules.py +209 -0
  29. agents/realty/energy_anomaly.py +188 -0
  30. agents/realty/exit_settlement.py +207 -0
  31. agents/realty/fulfillment_risk.py +205 -0
  32. agents/realty/ops_optimize.py +159 -0
  33. agents/realty/revenue_share.py +214 -0
  34. agents/registry.py +144 -0
  35. agents/sports/__init__.py +0 -0
  36. agents/sports/football_agent.py +169 -0
  37. agents/team.py +289 -0
  38. aliyun_data_client.py +660 -0
  39. apps/README.md +12 -0
  40. apps/__init__.py +2 -0
  41. apps/channels/README.md +15 -0
  42. apps/cli/README.md +13 -0
  43. apps/cli/__init__.py +2 -0
  44. apps/cli/bootstrap.py +99 -0
  45. apps/cli/codegen_paths.py +29 -0
  46. apps/cli/commands/__init__.py +16 -0
  47. apps/cli/commands/analysis_cmds.py +288 -0
  48. apps/cli/commands/backtest_cmds.py +1887 -0
  49. apps/cli/commands/broker_cmds.py +1154 -0
  50. apps/cli/commands/business_workflow_cmds.py +289 -0
  51. apps/cli/commands/catalog.py +84 -0
  52. apps/cli/commands/data_cmds.py +405 -0
  53. apps/cli/commands/diagnostic_cmds.py +179 -0
  54. apps/cli/commands/diagnostic_ops_cmds.py +696 -0
  55. apps/cli/commands/finance_render.py +12 -0
  56. apps/cli/commands/market.py +399 -0
  57. apps/cli/commands/market_cmds.py +1276 -0
  58. apps/cli/commands/market_context.py +425 -0
  59. apps/cli/commands/market_render.py +7 -0
  60. apps/cli/commands/model_cmds.py +1579 -0
  61. apps/cli/commands/ops_cmds.py +668 -0
  62. apps/cli/commands/portfolio_cmds.py +962 -0
  63. apps/cli/commands/report.py +377 -0
  64. apps/cli/commands/scaffold_templates.py +617 -0
  65. apps/cli/commands/session_cmds.py +179 -0
  66. apps/cli/commands/session_ux_cmds.py +280 -0
  67. apps/cli/commands/team.py +588 -0
  68. apps/cli/commands/team_render.py +8 -0
  69. apps/cli/commands/ui_cmds.py +358 -0
  70. apps/cli/commands/workflow_cmds.py +279 -0
  71. apps/cli/commands/workspace_cmds.py +1414 -0
  72. apps/cli/config_paths.py +70 -0
  73. apps/cli/config_store.py +61 -0
  74. apps/cli/deterministic.py +122 -0
  75. apps/cli/direct.py +48 -0
  76. apps/cli/github_app_auth.py +135 -0
  77. apps/cli/handlers/__init__.py +11 -0
  78. apps/cli/handlers/broker_handlers.py +122 -0
  79. apps/cli/handlers/chart_handlers.py +1309 -0
  80. apps/cli/handlers/market_handlers.py +2509 -0
  81. apps/cli/handlers/realty_handlers.py +114 -0
  82. apps/cli/handlers/strategy_advice.py +82 -0
  83. apps/cli/hooks.py +180 -0
  84. apps/cli/i18n.py +284 -0
  85. apps/cli/intent.py +136 -0
  86. apps/cli/intent_router.py +217 -0
  87. apps/cli/lifecycle_hooks.py +48 -0
  88. apps/cli/main.py +29 -0
  89. apps/cli/market_metadata.py +135 -0
  90. apps/cli/market_universe.py +265 -0
  91. apps/cli/message_processing.py +257 -0
  92. apps/cli/plan_mode.py +139 -0
  93. apps/cli/plotly_html.py +15 -0
  94. apps/cli/prediction_feedback.py +202 -0
  95. apps/cli/preflight.py +497 -0
  96. apps/cli/project_aria.py +60 -0
  97. apps/cli/prompts/__init__.py +0 -0
  98. apps/cli/prompts/coding.py +658 -0
  99. apps/cli/prompts/system_prompts.py +531 -0
  100. apps/cli/prompts/ui.py +434 -0
  101. apps/cli/providers/__init__.py +1 -0
  102. apps/cli/providers/base.py +271 -0
  103. apps/cli/providers/chat_routing.py +80 -0
  104. apps/cli/providers/llm/__init__.py +1 -0
  105. apps/cli/providers/llm/ollama_stream.py +1170 -0
  106. apps/cli/providers/llm/sse_stream.py +216 -0
  107. apps/cli/providers/runtime_bridge.py +185 -0
  108. apps/cli/runtime_consumer.py +489 -0
  109. apps/cli/session_export.py +87 -0
  110. apps/cli/session_jsonl.py +207 -0
  111. apps/cli/session_store.py +112 -0
  112. apps/cli/todo_tracker.py +190 -0
  113. apps/cli/tools/__init__.py +40 -0
  114. apps/cli/tools/context.py +46 -0
  115. apps/cli/tools/file_tools.py +112 -0
  116. apps/cli/tools/market_tools.py +549 -0
  117. apps/cli/tools/notebook_tools.py +111 -0
  118. apps/cli/tools/system_tools.py +669 -0
  119. apps/cli/tools/write_tools.py +715 -0
  120. apps/cli/tradingview_bridge.py +434 -0
  121. apps/cli/update_check.py +152 -0
  122. apps/cli/utils/__init__.py +0 -0
  123. apps/cli/utils/market_detect.py +1578 -0
  124. apps/daemon/README.md +14 -0
  125. apps/vscode/README.md +115 -0
  126. apps/vscode/package.json +70 -0
  127. aria_cli.py +11636 -0
  128. aria_code-4.1.3.dist-info/METADATA +952 -0
  129. aria_code-4.1.3.dist-info/RECORD +284 -0
  130. aria_code-4.1.3.dist-info/WHEEL +5 -0
  131. aria_code-4.1.3.dist-info/entry_points.txt +2 -0
  132. aria_code-4.1.3.dist-info/licenses/LICENSE +121 -0
  133. aria_code-4.1.3.dist-info/top_level.txt +50 -0
  134. aria_daemon.py +1295 -0
  135. aria_feishu_bot.py +1359 -0
  136. aria_relay_client.py +182 -0
  137. aria_relay_server.py +405 -0
  138. aria_telegram_bot.py +202 -0
  139. ariarc.py +328 -0
  140. artifacts.py +491 -0
  141. backtest_report.py +472 -0
  142. brokers/__init__.py +72 -0
  143. brokers/base.py +207 -0
  144. brokers/capabilities.py +264 -0
  145. brokers/cn/__init__.py +10 -0
  146. brokers/cn/easytrader_broker.py +193 -0
  147. brokers/cn/futu_broker.py +194 -0
  148. brokers/cn/longbridge_broker.py +190 -0
  149. brokers/cn/tiger_broker.py +196 -0
  150. brokers/cn/xtquant_broker.py +175 -0
  151. brokers/config.py +364 -0
  152. brokers/intl/__init__.py +5 -0
  153. brokers/intl/alpaca_broker.py +183 -0
  154. brokers/intl/ibkr_broker.py +215 -0
  155. brokers/intl/webull_broker.py +156 -0
  156. brokers/paper_broker.py +259 -0
  157. brokers/planning.py +296 -0
  158. brokers/registry.py +181 -0
  159. brokers/trading.py +237 -0
  160. change_store.py +127 -0
  161. command_safety.py +19 -0
  162. computer_use_tools.py +504 -0
  163. dashboard_generator.py +578 -0
  164. data_analysis_tools.py +808 -0
  165. data_cleaner.py +483 -0
  166. data_service.py +481 -0
  167. datasources/__init__.py +23 -0
  168. datasources/base.py +166 -0
  169. datasources/router.py +221 -0
  170. datasources/sources/__init__.py +15 -0
  171. datasources/sources/akshare_source.py +269 -0
  172. datasources/sources/alpha_vantage_source.py +202 -0
  173. datasources/sources/edgar_source.py +218 -0
  174. datasources/sources/finnhub_source.py +197 -0
  175. datasources/sources/fred_source.py +219 -0
  176. datasources/sources/tushare_source.py +141 -0
  177. datasources/sources/web_scraper_source.py +278 -0
  178. datasources/sources/world_bank_source.py +205 -0
  179. datasources/sources/yfinance_source.py +152 -0
  180. demo_player.py +204 -0
  181. doctor.py +508 -0
  182. file_analysis_tools.py +734 -0
  183. finance_formulas.py +389 -0
  184. football_data_client.py +1670 -0
  185. intent_classifier.py +358 -0
  186. local_finance_tools.py +3221 -0
  187. local_llm_provider.py +552 -0
  188. macro_tools.py +368 -0
  189. market_data_client.py +1899 -0
  190. mcp_client.py +506 -0
  191. memory_manager.py +245 -0
  192. model_capability.py +416 -0
  193. notification_tools.py +248 -0
  194. packages/__init__.py +23 -0
  195. packages/aria_agents/__init__.py +5 -0
  196. packages/aria_agents/manifest.py +69 -0
  197. packages/aria_core/__init__.py +34 -0
  198. packages/aria_core/architecture.py +192 -0
  199. packages/aria_core/export.py +124 -0
  200. packages/aria_core/manifest.py +65 -0
  201. packages/aria_infra/__init__.py +15 -0
  202. packages/aria_infra/arthera.py +52 -0
  203. packages/aria_infra/doctor.py +246 -0
  204. packages/aria_infra/product.py +37 -0
  205. packages/aria_mcp/__init__.py +25 -0
  206. packages/aria_mcp/bridge.py +38 -0
  207. packages/aria_mcp/config.py +97 -0
  208. packages/aria_mcp/tools.py +61 -0
  209. packages/aria_sdk/__init__.py +19 -0
  210. packages/aria_sdk/client.py +396 -0
  211. packages/aria_sdk/providers.py +70 -0
  212. packages/aria_sdk/streaming.py +73 -0
  213. packages/aria_sdk/types.py +86 -0
  214. packages/aria_services/__init__.py +55 -0
  215. packages/aria_services/context.py +258 -0
  216. packages/aria_services/data.py +11 -0
  217. packages/aria_services/provider_health.py +189 -0
  218. packages/aria_services/registry.py +213 -0
  219. packages/aria_services/usage.py +138 -0
  220. packages/aria_skills/__init__.py +5 -0
  221. packages/aria_skills/registry.py +59 -0
  222. packages/aria_tools/__init__.py +5 -0
  223. packages/aria_tools/registry.py +128 -0
  224. packages/quant_engine/__init__.py +6 -0
  225. packages/quant_engine/sports/__init__.py +72 -0
  226. packages/quant_engine/sports/calibrator.py +353 -0
  227. packages/quant_engine/sports/dixon_coles.py +234 -0
  228. packages/quant_engine/sports/elo.py +299 -0
  229. packages/quant_engine/sports/form.py +188 -0
  230. packages/quant_engine/sports/h2h.py +195 -0
  231. packages/quant_engine/sports/ml_model.py +354 -0
  232. packages/quant_engine/sports/predictor.py +311 -0
  233. packages/quant_engine/sports/tracker.py +664 -0
  234. packages/quant_engine/stochastic/__init__.py +27 -0
  235. packages/quant_engine/stochastic/gbm_enhanced.py +195 -0
  236. packages/quant_engine/stochastic/ito_calculus.py +477 -0
  237. packages/quant_engine/stochastic/kelly_criterion.py +181 -0
  238. packages/quant_engine/stochastic/monte_carlo_advanced.py +95 -0
  239. packages/quant_engine/stochastic/options_pricing.py +573 -0
  240. packages/quant_engine/stochastic/stochastic_processes.py +90 -0
  241. plan_utils.py +194 -0
  242. plugin_loader.py +328 -0
  243. portfolio_ledger.py +262 -0
  244. privacy/__init__.py +5 -0
  245. privacy/feedback.py +123 -0
  246. project_tools.py +525 -0
  247. providers/__init__.py +30 -0
  248. providers/llm/__init__.py +19 -0
  249. providers/llm/anthropic.py +184 -0
  250. providers/llm/base.py +139 -0
  251. providers/llm/ollama.py +128 -0
  252. providers/llm/openai_compat.py +282 -0
  253. providers/llm/registry.py +358 -0
  254. realty_data_tools.py +659 -0
  255. report_generator.py +1314 -0
  256. runtime/__init__.py +103 -0
  257. runtime/agent_loop.py +1183 -0
  258. runtime/approval.py +51 -0
  259. runtime/events.py +102 -0
  260. runtime/gateway.py +128 -0
  261. runtime/lsp.py +346 -0
  262. runtime/subagent.py +258 -0
  263. runtime/tool_executor.py +104 -0
  264. runtime/tool_policy.py +106 -0
  265. safety/__init__.py +21 -0
  266. safety/permissions.py +275 -0
  267. setup_wizard.py +653 -0
  268. strategy_vault.py +420 -0
  269. ui/__init__.py +100 -0
  270. ui/banner.py +310 -0
  271. ui/completer.py +391 -0
  272. ui/console.py +271 -0
  273. ui/image_render.py +243 -0
  274. ui/input_box.py +376 -0
  275. ui/picker.py +195 -0
  276. ui/render/__init__.py +11 -0
  277. ui/render/finance.py +1480 -0
  278. ui/render/market.py +225 -0
  279. ui/render/output.py +681 -0
  280. ui/render/team.py +346 -0
  281. ui/robot.py +235 -0
  282. workspace/__init__.py +6 -0
  283. workspace/files.py +170 -0
  284. workspace/verify.py +113 -0
runtime/agent_loop.py ADDED
@@ -0,0 +1,1183 @@
1
+ """Agent-loop orchestration helpers for Aria Code.
2
+
3
+ This module intentionally starts with pure, easily-tested primitives. The CLI
4
+ still owns UI prompts and provider calls, while the runtime owns the mechanical
5
+ shape of tool batching and follow-up construction.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import inspect
12
+ import time
13
+ from dataclasses import dataclass, field
14
+ from typing import AsyncGenerator, Awaitable, Callable, Dict, FrozenSet, Iterable, List, Optional, Sequence, Tuple, Union
15
+
16
+ from .approval import ApprovalDecision, apply_approval_decision
17
+ from .tool_executor import ToolExecutor
18
+
19
+
20
# Default set of tool names that `split_tool_calls` routes to the serial batch
# instead of the parallel one.
DEFAULT_SERIAL_TOOLS = {
    "write_file",
    "edit_file",
    "multi_edit",
    "run_command",
}
21
+
22
# Completion markers (English and Chinese) that a model reply may contain when
# it considers the task finished. Every entry is lowercase because matching is
# done against the lowercased reply.
_DONE_PHRASES = frozenset({
    "task complete", "task is complete", "all done", "completed successfully",
    "here is the final", "here's the final", "analysis complete",
    "任务完成", "已完成", "分析完成", "操作完成", "已经完成", "以下是最终",
    "i have completed", "i've completed", "the task has been completed",
})


def detect_task_complete(response_text: str) -> bool:
    """Return True when the reply contains one of the known completion phrases.

    This is a plain case-insensitive substring heuristic over
    ``_DONE_PHRASES``; an empty or missing reply never counts as complete.
    """
    if response_text:
        haystack = response_text.lower()
        for marker in _DONE_PHRASES:
            if marker in haystack:
                return True
    return False
37
+
38
+
39
class LoopGuard:
    """Detect repeated identical *failing* tool calls and break the agent loop.

    Weaker local models (and sometimes cloud models) get stuck calling the
    exact same tool with the exact same arguments after it has already failed
    — e.g. reading a file that does not exist, or calling an unknown tool. This
    guard tracks failure signatures across rounds:

    * After ``soft_threshold`` identical failures it returns a directive
      telling the model to STOP repeating that call and try another approach.
    * After ``hard_threshold`` it sets ``should_break`` so the caller can
      stop the loop instead of burning the remaining tool rounds.

    A successful call clears that signature's counter.
    """

    # Substrings that mark a result text as a failure. They are matched against
    # the lowercased first 200 characters of the text.
    _FAILURE_MARKERS = (
        "error", "unknown local tool", "unknown tool", "not found",
        "no such file", "failed", "traceback", "exception",
        "missing", "未找到", "失败", "错误",
    )

    def __init__(self, *, soft_threshold: int = 2, hard_threshold: int = 4) -> None:
        """Create a guard.

        Args:
            soft_threshold: number of identical failures after which a warning
                directive is returned (once per signature).
            hard_threshold: number of identical failures after which
                ``should_break`` is set and a stop directive is returned.
        """
        self.soft_threshold = soft_threshold
        self.hard_threshold = hard_threshold
        # signature -> consecutive failure count
        self._fail_counts: Dict[str, int] = {}
        # signatures that already received the soft warning
        self._warned: set = set()
        self.should_break: bool = False

    @staticmethod
    def _signature(tool_name: str, params: dict) -> str:
        """Return a short stable key for one (tool, params) combination."""
        import hashlib
        import json
        try:
            key = json.dumps(params or {}, sort_keys=True, ensure_ascii=False, default=str)
        except Exception:
            # Deliberate best-effort: anything json cannot serialise still gets
            # a signature, derived from its str() form instead.
            key = str(params)
        # md5 is used only as a compact fingerprint, not for security.
        digest = hashlib.md5(key.encode("utf-8", "ignore")).hexdigest()[:12]
        return f"{tool_name}::{digest}"

    @staticmethod
    def is_failure(result) -> bool:
        """Best-effort detection of a failed tool result (dict or summary string).

        For dicts an explicit boolean ``success`` flag wins. Without one, a
        non-empty ``error`` field marks the result as failed; otherwise the
        ``error``/``output`` text is scanned for failure markers. Any other
        value is converted with ``str()`` and scanned the same way.
        """
        if isinstance(result, dict):
            flag = result.get("success")
            if flag is True:
                return False
            if flag is False:
                return True
            # An error message without a success flag is still a failure, even
            # when its wording contains none of the marker substrings.
            if result.get("error"):
                return True
            text = f"{result.get('error', '')} {result.get('output', '')}"
        else:
            text = str(result)
        low = text[:200].lower()
        return any(marker in low for marker in LoopGuard._FAILURE_MARKERS)

    def record(self, tool_name: str, params: dict, result) -> "str | None":
        """Record one tool result. Return a directive string if a loop is detected.

        Returns ``None`` for successful results (which also reset the counter
        for that signature) and for failures that do not cross a threshold.
        """
        sig = self._signature(tool_name, params)
        if not self.is_failure(result):
            self._fail_counts.pop(sig, None)
            self._warned.discard(sig)
            return None

        count = self._fail_counts.get(sig, 0) + 1
        self._fail_counts[sig] = count

        if count >= self.hard_threshold:
            self.should_break = True
            return (
                f"⛔ 已连续 {count} 次用相同参数调用 `{tool_name}` 且全部失败。"
                "立即停止调用工具,基于现有信息直接回答用户,或说明卡在哪里、需要什么。"
            )
        if count >= self.soft_threshold and sig not in self._warned:
            # Warn only once per signature; later repeats stay silent until
            # the hard threshold is reached.
            self._warned.add(sig)
            return (
                f"⚠ 你已经用完全相同的参数调用了 `{tool_name}` {count} 次,每次都失败。"
                "不要再用相同参数重试。请改变策略:换参数、换工具(如先用 list_files/search_code 定位),"
                "或基于已有结果继续。"
            )
        return None
116
+
117
+
118
def split_tool_calls(
    pending: Sequence[dict],
    serial_tools: Iterable[str] = DEFAULT_SERIAL_TOOLS,
) -> Tuple[List[dict], List[dict]]:
    """Split tool calls into parallel-safe and serial batches.

    Beyond the static ``serial_tools`` list, a simple data dependency is
    detected at runtime: if a later ``run_command`` mentions a path that an
    earlier serial call in the same batch targets, it is moved to the serial
    queue so it runs after the write completes. That check only applies when
    ``run_command`` is not itself listed in ``serial_tools``.

    Args:
        pending: tool-call dicts with a ``"tool"`` name and optional
            ``"params"`` dict. A missing, ``None`` or non-dict ``params`` is
            treated as empty.
        serial_tools: names of tools that must never run in parallel.

    Returns:
        ``(parallel_batch, serial_batch)``, each preserving the relative
        order of ``pending``.
    """
    serial_names = set(serial_tools)
    parallel_batch: List[dict] = []
    serial_batch: List[dict] = []
    written_paths: set = set()

    for call in pending:
        tool = call.get("tool", "")
        params = call.get("params")
        if not isinstance(params, dict):
            # Models sometimes emit `"params": null`; treat that as no params
            # instead of failing on `.get`.
            params = {}

        if tool in serial_names:
            # Remember every path-like argument so that dependent commands
            # later in the batch can be detected.
            for key in ("path", "file_path", "filename", "target"):
                value = params.get(key)
                if value:
                    written_paths.add(str(value))
            serial_batch.append(call)
            continue

        if tool == "run_command":
            command = str(params.get("command", ""))
            # A command that references a path being written must wait.
            if any(path in command for path in written_paths):
                serial_batch.append(call)
                continue

        parallel_batch.append(call)

    return parallel_batch, serial_batch
156
+
157
+
158
def collect_parallel_done(
    pending: Sequence[dict],
    parallel_results: Sequence[tuple],
    serial_tools: Iterable[str] = DEFAULT_SERIAL_TOOLS,
) -> Dict[int, dict]:
    """Map positions in ``pending`` to results that were already produced.

    ``parallel_results`` holds ``(tool_call, result)`` pairs. A pending call is
    matched to a pair by object identity (``is``), not equality, and the first
    matching pair wins. Calls whose tool name is in ``serial_tools`` are never
    matched.
    """
    skipped_names = set(serial_tools)
    no_match = object()  # sentinel: a result may legitimately be any value
    finished: Dict[int, dict] = {}

    for position, call in enumerate(pending):
        if call.get("tool") in skipped_names:
            continue
        outcome = next(
            (res for executed, res in parallel_results if executed is call),
            no_match,
        )
        if outcome is not no_match:
            finished[position] = outcome

    return finished
174
+
175
+
176
# Awaitable runner invoked as `remote_runner(tool_name, params)`; resolves to a
# result dict.
RemoteToolRunner = Callable[[str, dict], Awaitable[dict]]
# Hook invoked as `hook(event, tool_name, params, result)`; `result` is None for
# the "pre_tool" event. `Optional[dict]` is used instead of `dict | None`
# because this expression is evaluated at import time, where the `|` form
# fails on Python versions older than 3.10.
Hook = Callable[[str, str, dict, Optional[dict]], None]
# Formatter taking (str, dict) and returning a string.
# NOTE(review): presumably (tool_name, result) -> summary text; confirm at call site.
SummaryFormatter = Callable[[str, dict], str]
179
# Approval callback taking (str, dict) and producing an ApprovalDecision either
# directly or through an awaitable.
ApprovalCallback = Callable[
    [str, dict],
    Union[ApprovalDecision, Awaitable[ApprovalDecision]],
]
180
+
181
+
182
@dataclass
class AgentTurnState:
    """Mutable state accumulated across one agent response turn.

    Collects response text, tool names, sources, token usage and tool time as
    model results arrive, and turns them into ``AgentTurnMetadata`` /
    ``AgentTurnResult`` objects at the end of the turn.
    """

    provider: str = "aws"
    total_response: str = ""
    tools_used: List[str] = field(default_factory=list)
    sources: List[dict] = field(default_factory=list)
    usage: Dict[str, int] = field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "thinking_tokens": 0,
    })
    tool_time_total: float = 0.0

    def append_response(self, text: str | None) -> None:
        """Append ``text`` to the accumulated response; empty/None is ignored."""
        if text:
            self.total_response += text

    def apply_model_result(self, result: dict, fallback_response: str = "") -> None:
        """Merge one model result dict into the turn state.

        Missing or ``None`` values for ``tools_used``, ``sources``, ``provider``
        and ``usage`` leave the corresponding state untouched.
        """
        self.append_response(result.get("response", fallback_response))
        # `or []` guards against an explicit None, which `.get(key, [])` would
        # pass straight to extend().
        self.tools_used.extend(result.get("tools_used") or [])
        self.sources.extend(result.get("sources") or [])
        provider = result.get("provider")
        if provider is not None:
            # Keep the current provider when the result carries none; a None
            # provider would later break string joins in the metadata parts.
            self.provider = provider
        self.add_usage(result.get("usage"))

    def add_usage(self, usage: dict | None) -> None:
        """Add token counts from ``usage`` to the running totals."""
        if not usage:
            return
        for key in ("prompt_tokens", "completion_tokens", "thinking_tokens"):
            # `or 0` tolerates explicit None values in the usage payload.
            self.usage[key] += int(usage.get(key, 0) or 0)

    def add_tool_time(self, elapsed: float) -> None:
        """Add ``elapsed`` seconds to the time spent in tools."""
        self.tool_time_total += elapsed

    def reset_response(self) -> None:
        """Discard the accumulated response text."""
        self.total_response = ""

    def final_text(self, fallback_response: str = "") -> str:
        """Return the accumulated response, or the fallback when it is empty."""
        return self.total_response or fallback_response

    def token_counts(self, *, token_count: int = 0, thinking_tokens: int = 0) -> Tuple[int, int, int, int]:
        """Return ``(prompt, completion, thinking, total)`` token counts.

        ``token_count`` and ``thinking_tokens`` are used only when the
        accumulated usage has no (non-zero) value for the respective field.
        """
        prompt_t = self.usage.get("prompt_tokens", 0)
        completion_t = self.usage.get("completion_tokens", 0) or token_count
        think_t = self.usage.get("thinking_tokens", 0) or thinking_tokens
        return prompt_t, completion_t, think_t, prompt_t + completion_t + think_t

    def generation_time(self, elapsed: float) -> float:
        """Return ``elapsed`` minus the total time spent in tools."""
        return elapsed - self.tool_time_total

    def unique_tools(self) -> List[str]:
        """Return used tool names de-duplicated in first-use order."""
        return list(dict.fromkeys(self.tools_used))

    def build_metadata(
        self,
        *,
        elapsed: float,
        token_count: int = 0,
        thinking_tokens: int = 0,
    ) -> "AgentTurnMetadata":
        """Build display/accounting metadata for the turn.

        ``parts`` is an ordered list of short display fragments: elapsed time,
        token summary, throughput, tool time, provider, tool names and cost.
        """
        prompt_t, completion_t, think_t, total_t = self.token_counts(
            token_count=token_count,
            thinking_tokens=thinking_tokens,
        )
        parts = [f"{elapsed:.1f}s"]
        gen_time = self.generation_time(elapsed)

        if total_t > 0:
            token_parts = []
            if prompt_t > 0:
                token_parts.append(f"in: {prompt_t:,}")
            if completion_t > 0:
                token_parts.append(f"out: {completion_t:,}")
            if think_t > 0:
                token_parts.append(f"think: {think_t:,}")
            parts.append(f"{total_t:,} tokens ({', '.join(token_parts)})")
            # Throughput is shown only for generation times above 0.5s.
            if completion_t > 0 and gen_time > 0.5:
                parts.append(f"{completion_t / gen_time:.0f} t/s")
        elif token_count > 0:
            # Only reached when the accumulated usage nets to zero or less
            # while the caller still supplied a positive token_count.
            parts.append(f"{token_count:,} tokens")
            if gen_time > 0.5:
                parts.append(f"{token_count / gen_time:.0f} t/s")

        if self.tool_time_total > 0:
            parts.append(f"tools: {self.tool_time_total:.1f}s")
        if self.provider != "aws":
            parts.append(self.provider)
        unique_tools = self.unique_tools()
        if unique_tools:
            parts.append(" ".join(unique_tools))

        # Turn-level cost — only for providers outside the local list below,
        # using fixed per-million-token rates (0.14 in / 0.28 out / 1.10 think).
        _is_cloud = self.provider not in ("ollama", "ollama_cache", "local", "")
        if _is_cloud and total_t > 0:
            _cost = (prompt_t * 0.14 + completion_t * 0.28 + think_t * 1.10) / 1_000_000
            if _cost >= 0.0001:
                parts.append(f"${_cost:.4f}")

        return AgentTurnMetadata(
            parts=parts,
            prompt_tokens=prompt_t,
            completion_tokens=completion_t,
            thinking_tokens=think_t,
            total_tokens=total_t,
            generation_time=gen_time,
            provider=self.provider,
            tools=unique_tools,
        )

    def build_result(
        self,
        *,
        elapsed: float,
        fallback_response: str = "",
        token_count: int = 0,
        thinking_tokens: int = 0,
        success: bool = True,
        cancelled: bool = False,
        error: str = "",
    ) -> "AgentTurnResult":
        """Build the structured turn result, including freshly built metadata."""
        metadata = self.build_metadata(
            elapsed=elapsed,
            token_count=token_count,
            thinking_tokens=thinking_tokens,
        )
        return AgentTurnResult(
            success=success,
            cancelled=cancelled,
            error=error,
            final_text=self.final_text(fallback_response),
            metadata=metadata,
            provider=metadata.provider,
            tools=metadata.tools,
            sources=list(self.sources),
        )

    def build_cancelled_result(
        self,
        *,
        elapsed: float,
        fallback_response: str = "",
        token_count: int = 0,
        thinking_tokens: int = 0,
    ) -> "AgentTurnResult":
        """Build a result flagged as cancelled (``success`` stays True)."""
        return self.build_result(
            elapsed=elapsed,
            fallback_response=fallback_response,
            token_count=token_count,
            thinking_tokens=thinking_tokens,
            success=True,
            cancelled=True,
        )

    def build_error_result(
        self,
        error: str | None,
        *,
        elapsed: float,
        fallback_response: str = "",
        token_count: int = 0,
        thinking_tokens: int = 0,
    ) -> "AgentTurnResult":
        """Build a failed result; an empty error becomes ``"Unknown error"``."""
        return self.build_result(
            elapsed=elapsed,
            fallback_response=fallback_response,
            token_count=token_count,
            thinking_tokens=thinking_tokens,
            success=False,
            cancelled=False,
            error=error or "Unknown error",
        )
354
+
355
+
356
@dataclass(frozen=True)
class AgentTurnMetadata:
    """Immutable display and accounting figures for one finished agent turn."""

    # Ordered display fragments.
    parts: List[str]
    prompt_tokens: int = 0
    completion_tokens: int = 0
    thinking_tokens: int = 0
    total_tokens: int = 0
    generation_time: float = 0.0
    provider: str = "aws"
    tools: List[str] = field(default_factory=list)

    def system_prompt_estimate(self, message: str) -> int:
        """Estimate the prompt tokens not attributable to ``message``.

        The message is costed at one token per three characters; the result is
        never negative.
        """
        message_tokens = len(message) // 3
        remainder = self.prompt_tokens - message_tokens
        return remainder if remainder > 0 else 0
371
+
372
+
373
@dataclass(frozen=True)
class AgentTurnResult:
    """Structured result for a completed agent turn."""

    success: bool
    cancelled: bool
    error: str
    final_text: str
    metadata: AgentTurnMetadata
    provider: str = "aws"
    tools: List[str] = field(default_factory=list)
    sources: List[dict] = field(default_factory=list)

    @classmethod
    def cancelled_result(
        cls,
        *,
        metadata: AgentTurnMetadata | None = None,
        final_text: str = "",
    ) -> "AgentTurnResult":
        """Build a cancelled result (``success`` stays True).

        ``provider`` and ``tools`` are taken from ``metadata`` so the result
        and its metadata agree; without metadata an empty default is used.
        """
        meta = metadata or AgentTurnMetadata(parts=[])
        return cls(
            success=True,
            cancelled=True,
            error="",
            final_text=final_text,
            metadata=meta,
            provider=meta.provider,
            tools=list(meta.tools),
        )

    @classmethod
    def error_result(
        cls,
        error: str,
        *,
        metadata: AgentTurnMetadata | None = None,
        final_text: str = "",
    ) -> "AgentTurnResult":
        """Build a failed result carrying ``error``.

        ``provider`` and ``tools`` are taken from ``metadata`` so the result
        and its metadata agree; without metadata an empty default is used.
        """
        meta = metadata or AgentTurnMetadata(parts=[])
        return cls(
            success=False,
            cancelled=False,
            error=error,
            final_text=final_text,
            metadata=meta,
            provider=meta.provider,
            tools=list(meta.tools),
        )

    def to_dict(self) -> dict:
        """Return a plain-dict copy; list fields are shallow-copied."""
        meta = self.metadata
        return {
            "success": self.success,
            "cancelled": self.cancelled,
            "error": self.error,
            "final_text": self.final_text,
            "provider": self.provider,
            "tools": list(self.tools),
            "sources": list(self.sources),
            "metadata": {
                "parts": list(meta.parts),
                "prompt_tokens": meta.prompt_tokens,
                "completion_tokens": meta.completion_tokens,
                "thinking_tokens": meta.thinking_tokens,
                "total_tokens": meta.total_tokens,
                "generation_time": meta.generation_time,
                "provider": meta.provider,
                "tools": list(meta.tools),
            },
        }

    def to_envelope(self) -> "AgentTurnEnvelope":
        """Wrap this result in an ``AgentTurnEnvelope``."""
        return AgentTurnEnvelope.from_result(self)
440
+
441
+
442
@dataclass(frozen=True)
class AgentTurnEnvelope:
    """Flat, serialisable view of a turn result for CLI/API consumers."""

    status: str
    success: bool
    cancelled: bool
    error: str
    final_text: str
    provider: str
    tools: List[str]
    summary: str
    metadata: dict

    @classmethod
    def from_result(cls, result: AgentTurnResult) -> "AgentTurnEnvelope":
        """Derive an envelope from ``result``.

        ``status`` is ``"ok"`` whenever ``result.success`` is truthy and
        ``"error"`` otherwise; ``summary`` joins the metadata parts with
        ``" · "``.
        """
        if result.success:
            status_label = "ok"
        else:
            status_label = "error"
        summary_line = " · ".join(result.metadata.parts)
        metadata_payload = result.to_dict()["metadata"]
        return cls(
            status=status_label,
            success=result.success,
            cancelled=result.cancelled,
            error=result.error,
            final_text=result.final_text,
            provider=result.provider,
            tools=list(result.tools),
            summary=summary_line,
            metadata=metadata_payload,
        )

    def to_dict(self) -> dict:
        """Return the envelope as a plain dict (tools/metadata shallow-copied)."""
        payload = {
            name: getattr(self, name)
            for name in ("status", "success", "cancelled", "error", "final_text", "provider")
        }
        payload["tools"] = list(self.tools)
        payload["summary"] = self.summary
        payload["metadata"] = dict(self.metadata)
        return payload
482
+
483
+
484
@dataclass(frozen=True)
class AgentErrorPresentation:
    """How a model/agent failure should be shown to the user."""

    error: str
    lines: List[str]
    level: str = "error"
    use_generic_error_prefix: bool = False

    @classmethod
    def from_error(cls, error: str | None) -> "AgentErrorPresentation":
        """Map an error code or message to its presentation.

        Known codes (``no_cloud_provider``, ``no_provider``,
        ``all_providers_failed``) get a warning-level explanation. Anything
        else is shown as ``"Error: <message>"`` at error level; a missing or
        empty error becomes ``"Unknown error"``.
        """
        normalized = error if error else "Unknown error"

        if normalized == "all_providers_failed":
            return cls(
                error=normalized,
                level="warning",
                lines=["所有云端 Provider 均请求失败,请检查网络或 API Key 是否有效。"],
            )

        if normalized == "no_cloud_provider" or normalized == "no_provider":
            setup_help = [
                "没有可用的 AI 模型",
                " Ollama 未运行,且未配置云端 API Key。",
                " 解决方案(任选其一):",
                " • 启动 Ollama: ollama serve",
                " • 配置云端 Key: /apikey set deepseek <your-key>",
                " • 导出环境变量: export DEEPSEEK_API_KEY=sk-...",
            ]
            return cls(error=normalized, level="warning", lines=setup_help)

        return cls(
            error=normalized,
            level="error",
            lines=[f"Error: {normalized}"],
            use_generic_error_prefix=True,
        )
521
+
522
+
523
@dataclass
class ToolBatchState:
    """Accumulator for the results of one model-requested batch of tool calls."""

    tool_results: List[dict] = field(default_factory=list)
    elapsed_total: float = 0.0
    cancelled: bool = False

    def add_result(
        self,
        tool_name: str,
        result: dict,
        formatter: SummaryFormatter,
        *,
        elapsed: float = 0.0,
    ) -> dict:
        """Add ``elapsed`` to the batch total and record the result.

        Recording is delegated to ``record_tool_result`` with this batch's
        ``tool_results`` list; its return value is passed through.
        """
        self.elapsed_total = self.elapsed_total + elapsed
        recorded = record_tool_result(self.tool_results, tool_name, result, formatter)
        return recorded

    def cancel(self) -> None:
        """Mark the batch as cancelled."""
        self.cancelled = True

    def build_next_turn(self, total_response: str) -> Tuple[dict, dict, str]:
        """Delegate to ``build_next_turn_messages`` using the recorded results."""
        next_turn = build_next_turn_messages(total_response, self.tool_results)
        return next_turn
547
+
548
+
549
@dataclass(frozen=True)
class ToolCallTask:
    """A single tool call together with its position in the requested turn."""

    index: int
    tool_call: dict
    parallel_result: dict | None = None

    @property
    def tool_name(self) -> str:
        """Name under the ``"tool"`` key, or ``""`` when the key is absent."""
        call = self.tool_call
        return call.get("tool", "")

    @property
    def params(self) -> dict:
        """Value under the ``"params"`` key, or ``{}`` when the key is absent."""
        call = self.tool_call
        return call.get("params", {})

    @property
    def has_parallel_result(self) -> bool:
        """True when a parallel result is attached to this task."""
        return not (self.parallel_result is None)

    def progress_label(self, total: int) -> str:
        """Return the progress line; a ``[n/total]`` counter is shown only when ``total > 1``."""
        return (
            f" [{self.index + 1}/{total}] Running {self.tool_name}..."
            if total > 1
            else f" Running {self.tool_name}..."
        )
573
+
574
+
575
@dataclass
class ToolTurnPlan:
    """Execution plan for one turn's pending tool calls."""

    pending: Sequence[dict]
    # Results already available, keyed by position in `pending`.
    parallel_done: Dict[int, dict] = field(default_factory=dict)
    batch: ToolBatchState = field(default_factory=ToolBatchState)

    def tasks(self) -> List[ToolCallTask]:
        """Return one ``ToolCallTask`` per pending call, in order.

        Each task carries the entry of ``parallel_done`` for its index, or
        ``None`` when there is none.
        """
        ordered: List[ToolCallTask] = []
        for position, call in enumerate(self.pending):
            ordered.append(
                ToolCallTask(
                    index=position,
                    tool_call=call,
                    parallel_result=self.parallel_done.get(position),
                )
            )
        return ordered
592
+
593
+
594
@dataclass(frozen=True)
class ToolExecutionActivity:
    """Immutable record of one tool execution within a batch."""

    tool: str
    result: dict
    elapsed: float
    params: dict
    # NOTE(review): presumably True when the result came from the parallel
    # pre-execution step rather than a serial run — confirm against the producer.
    from_parallel: bool = False
603
+
604
+
605
@dataclass(frozen=True)
class ToolExecutionTurnResult:
    """Outcome of executing one turn's pending tool calls."""

    batch: ToolBatchState
    activities: List[ToolExecutionActivity]
    assistant_message: dict
    user_message: dict
    followup: str
    guard_directives: List[str] = field(default_factory=list)

    @property
    def cancelled(self) -> bool:
        """Mirror of the underlying batch's ``cancelled`` flag."""
        batch_state = self.batch
        return batch_state.cancelled
619
+
620
+
621
async def run_parallel_tools(
    pending: Sequence[dict],
    tool_executor: ToolExecutor,
    *,
    remote_runner: RemoteToolRunner | None = None,
    hook: Hook | None = None,
    serial_tools: Iterable[str] = DEFAULT_SERIAL_TOOLS,
) -> Dict[int, dict]:
    """Execute parallel-safe pending tools and return results by original index.

    ``split_tool_calls`` selects the calls that may run concurrently; they are
    awaited together with ``asyncio.gather``. Tools listed in
    ``tool_executor.local_tools`` run through ``tool_executor.execute``; any
    other tool goes to ``remote_runner`` (with ``hook`` fired before and
    after) or, without a runner, yields an "Unknown tool" failure.

    A call that raises is reported as ``{"success": False, "error": ...}``
    paired with the tool call that caused it, so the failure stays
    attributable to its position in ``pending``.
    """
    parallel_batch, _ = split_tool_calls(pending, serial_tools)
    # Materialise so the batch can be iterated for scheduling and pairing.
    parallel_batch = list(parallel_batch)

    def _failure(exc: BaseException) -> dict:
        # Some exceptions (e.g. asyncio.TimeoutError()) stringify to "";
        # fall back to the class name so the error is never blank.
        return {"success": False, "error": str(exc) or type(exc).__name__}

    async def _exec_one(tool_call: dict) -> tuple:
        tool_name = tool_call.get("tool", "")
        tool_params = tool_call.get("params", {})
        if tool_name in tool_executor.local_tools:
            result = await tool_executor.execute(tool_name, tool_params)
        elif remote_runner is not None:
            if hook is not None:
                hook("pre_tool", tool_name, tool_params, None)
            try:
                result = await remote_runner(tool_name, tool_params)
            except Exception as exc:
                result = _failure(exc)
            if hook is not None:
                hook("post_tool", tool_name, tool_params, result)
        else:
            result = {"success": False, "error": f"Unknown tool: {tool_name}"}
        return tool_call, result

    parallel_results: List[tuple] = []
    if parallel_batch:
        gathered = await asyncio.gather(
            *(_exec_one(tool_call) for tool_call in parallel_batch),
            return_exceptions=True,
        )
        # gather() returns results in the order the awaitables were passed,
        # so zipping with the batch recovers the call behind each failure.
        for tool_call, outcome in zip(parallel_batch, gathered):
            if isinstance(outcome, BaseException):
                parallel_results.append((tool_call, _failure(outcome)))
            else:
                parallel_results.append(outcome)
    return collect_parallel_done(pending, parallel_results, serial_tools)
662
+
663
+
664
async def run_serial_tool(
    tool_name: str,
    tool_params: dict,
    tool_executor: ToolExecutor,
    *,
    remote_runner: RemoteToolRunner | None = None,
    hook: Hook | None = None,
) -> Tuple[dict, float]:
    """Execute one tool call and return ``(result, elapsed_seconds)``.

    Tools listed in ``tool_executor.local_tools`` run through
    ``tool_executor.execute_local``. Any other tool is sent to
    ``remote_runner`` when one is given, with ``hook`` called as
    ``("pre_tool", name, params, None)`` before and
    ``("post_tool", name, params, result)`` after; an exception raised by the
    runner becomes a ``{"success": False, "error": ...}`` result. Without a
    runner the result is an "Unknown tool" failure.
    """
    # perf_counter is monotonic, so the duration cannot go negative or jump
    # when the system clock is adjusted mid-call.
    started = time.perf_counter()
    if tool_name in tool_executor.local_tools:
        result = tool_executor.execute_local(tool_name, tool_params)
    elif remote_runner is not None:
        if hook is not None:
            hook("pre_tool", tool_name, tool_params, None)
        try:
            result = await remote_runner(tool_name, tool_params)
        except Exception as exc:
            # Message-less exceptions stringify to ""; report the class name
            # instead so the failure is never blank.
            result = {"success": False, "error": str(exc) or type(exc).__name__}
        if hook is not None:
            hook("post_tool", tool_name, tool_params, result)
    else:
        result = {"success": False, "error": f"Unknown tool: {tool_name}"}
    return result, time.perf_counter() - started
688
+
689
+
690
async def _maybe_await(value):
    """Return ``value``, awaiting it first when it is awaitable."""
    if not inspect.isawaitable(value):
        return value
    resolved = await value
    return resolved
694
+
695
+
696
async def execute_tool_turn(
    pending: Sequence[dict],
    *,
    total_response: str,
    tool_executor: ToolExecutor,
    formatter: SummaryFormatter,
    remote_runner: RemoteToolRunner | None = None,
    hook: Hook | None = None,
    cancel_event: asyncio.Event | None = None,
    confirm_tools: Iterable[str] = (),
    approval_callback: ApprovalCallback | None = None,
    approval_applier: Callable[[dict, ApprovalDecision], dict] = apply_approval_decision,
    loop_guard: LoopGuard | None = None,
    serial_tools: Iterable[str] = DEFAULT_SERIAL_TOOLS,
) -> ToolExecutionTurnResult:
    """Execute one model-requested tool batch and prepare the next turn.

    This is the reusable runtime boundary for the agent tool loop. Callers
    provide UI-specific approval and rendering callbacks; this function owns
    batching, permission application, execution, loop-guard retry directives,
    and construction of the assistant/user follow-up messages.

    Flow:
      1. Parallel-safe calls run first via ``run_parallel_tools``.
      2. Every pending call is then visited in order. Calls that already have
         a parallel result are recorded as-is; the rest run one at a time via
         ``run_serial_tool``, after approval when the tool is in
         ``confirm_tools`` and ``approval_callback`` is provided.
      3. A set ``cancel_event`` or a rejected approval cancels the batch and
         stops processing the remaining calls.
      4. ``loop_guard`` (if any) sees every recorded activity; its directives
         are appended to the user message and to the follow-up text.

    When ``approval_callback`` is provided, tools in ``confirm_tools`` are
    always treated as serial so they cannot run before approval was requested.
    """
    confirm = set(confirm_tools)
    # Materialise once: the names are handed on to helpers that may iterate
    # them more than once.
    serial_names = set(serial_tools)
    if approval_callback is not None:
        # Calls with a parallel result skip the approval branch below, so a
        # tool needing confirmation must never be part of the parallel pass.
        serial_names |= confirm

    parallel_done = await run_parallel_tools(
        pending,
        tool_executor,
        remote_runner=remote_runner,
        hook=hook,
        serial_tools=serial_names,
    )
    tool_turn = ToolTurnPlan(pending=pending, parallel_done=parallel_done)
    tool_batch = tool_turn.batch
    activities: List[ToolExecutionActivity] = []

    for task in tool_turn.tasks():
        if cancel_event and cancel_event.is_set():
            tool_batch.cancel()
            break

        tool_name = task.tool_name
        tool_params = task.params

        if task.has_parallel_result:
            tr = task.parallel_result or {}
            tool_batch.add_result(tool_name, tr, formatter)
            activities.append(ToolExecutionActivity(
                tool=tool_name,
                result=tr,
                elapsed=0.0,
                params=tool_params,
                from_parallel=True,
            ))
            continue

        if tool_name in confirm and approval_callback is not None:
            decision = await _maybe_await(approval_callback(tool_name, tool_params))
            # A callback returning nothing counts as approval.
            if decision is None:
                decision = ApprovalDecision.allow()
            if not decision.approved:
                tool_batch.cancel()
                break
            # NOTE(review): the applier's return value is discarded, so it is
            # presumably expected to update tool_params in place — confirm.
            approval_applier(tool_params, decision)

        tr, tool_elapsed = await run_serial_tool(
            tool_name,
            tool_params,
            tool_executor,
            remote_runner=remote_runner,
            hook=hook,
        )
        tool_batch.add_result(tool_name, tr, formatter, elapsed=tool_elapsed)
        activities.append(ToolExecutionActivity(
            tool=tool_name,
            result=tr,
            elapsed=tool_elapsed,
            params=tool_params,
        ))

    guard_directives: List[str] = []
    if loop_guard is not None:
        for activity in activities:
            directive = loop_guard.record(activity.tool, activity.params, activity.result)
            if directive:
                guard_directives.append(directive)

    assistant_message, user_message, followup = tool_batch.build_next_turn(total_response)
    if guard_directives:
        guard_text = "\n\n".join(guard_directives)
        content = user_message.get("content")
        if isinstance(content, str):
            user_message["content"] = content + f"\n\n{guard_text}"
        elif isinstance(content, list):
            # Multipart content: add the directives as an extra text part.
            content.append({"type": "text", "text": guard_text})
        followup += f"\n\n{guard_text}"

    return ToolExecutionTurnResult(
        batch=tool_batch,
        activities=activities,
        assistant_message=assistant_message,
        user_message=user_message,
        followup=followup,
        guard_directives=guard_directives,
    )
799
+
800
+
801
# Cap each tool result so a single large output (a long pip-install log, a big
# data dump, a verbose traceback) cannot blow past the model's context window.
# Without this, one oversized result is appended verbatim, the next provider
# call exceeds num_ctx, the prompt is truncated, and the model loses the task
# mid-run. Head+tail keeps the actionable parts (what ran / the final error)
# and drops the noisy middle.
_MAX_TOOL_RESULT_CHARS = 6000


def _truncate_tool_result(text: str, limit: int = _MAX_TOOL_RESULT_CHARS) -> str:
    """Return ``text`` cut down to roughly ``limit`` characters of payload.

    Text no longer than ``limit`` is returned unchanged. Otherwise the first
    two thirds of the budget are kept from the start and the remainder from
    the end, joined by a marker stating how many characters were dropped. The
    marker itself is not counted against ``limit``. A ``limit`` of zero or
    less keeps no payload at all.
    """
    limit = max(limit, 0)
    if len(text) <= limit:
        return text
    head = limit * 2 // 3
    tail = limit - head
    omitted = len(text) - limit
    return "".join((
        text[:head],
        f"\n\n… [已截断 {omitted:,} 字符 — 输出过长,仅保留首尾以保护上下文] …\n\n",
        # Slice from an explicit start index: text[-0:] would be the whole
        # string when tail is 0.
        text[len(text) - tail:],
    ))
821
+
822
+
823
def build_tool_followup(tool_results: Sequence[dict]) -> str:
    """Render tool results as the structured follow-up text for the model.

    Every entry becomes a ``### [tool] status`` section whose body is the
    stringified result after ``_truncate_tool_result``. A result counts as an
    error when it starts with "Error" or "❌", mentions "error" in its first
    80 characters, or mentions "traceback" or "exception" in its first 200
    characters (case-insensitive). The closing paragraph either names the
    failing tools with recovery options or confirms success. An empty input
    yields a fixed "No tool results" sentence.
    """
    if not tool_results:
        return "No tool results. Continue with what you know or ask the user for clarification."

    sections: List[str] = []
    failed: List[str] = []

    for entry in tool_results:
        name = entry.get("tool", "unknown")
        body = _truncate_tool_result(str(entry.get("result", "")))

        if body.startswith("Error") or body.startswith("❌"):
            looks_failed = True
        elif "error" in body[:80].lower():
            looks_failed = True
        else:
            opening = body[:200].lower()
            looks_failed = "traceback" in opening or "exception" in opening

        if looks_failed:
            failed.append(name)
            status = "❌ Error"
        else:
            status = "✓ Success"
        sections.append(f"### [{name}] {status}\n{body}")

    pieces = ["## Tool Results\n\n", "\n\n---\n\n".join(sections)]

    if failed:
        pieces.append(
            f"\n\n⚠ Tool(s) returned errors: {', '.join(failed)}. "
            "Read the error carefully. Options: (1) use read_file / search_code to diagnose, "
            "(2) use edit_file to fix the issue and retry run_command, "
            "(3) try a different approach. "
            "Do NOT give up silently — explain what failed and what you tried."
        )
    else:
        pieces.append(
            "\n\nAll tools completed successfully. "
            "If the task is now complete, provide your final response. "
            "If additional steps are needed, continue using tools.\n\n"
            "Please continue your analysis using these results."
        )

    return "".join(pieces)
874
+
875
+
876
def record_tool_result(
    tool_results: List[dict],
    tool_name: str,
    result: dict,
    formatter: SummaryFormatter,
) -> dict:
    """Summarise ``result`` with ``formatter``, store it, and return the entry.

    The entry is ``{"tool": tool_name, "result": <summary>}``; the same dict
    object is appended to ``tool_results`` (mutated in place) and returned.
    """
    entry = {"tool": tool_name, "result": formatter(tool_name, result)}
    tool_results.append(entry)
    return entry
887
+
888
+
889
def build_next_turn_messages(total_response: str, tool_results: Sequence[dict]) -> Tuple[dict, dict, str]:
    """Return ``(assistant_message, user_message, followup)`` for the next turn.

    ``followup`` is the text produced by ``build_tool_followup``. The
    assistant message carries ``total_response``. The user message carries
    the follow-up, either as a plain string or, when
    ``computer_use_tools.pop_pending_vision_image()`` returns a truthy value,
    as a two-part list: that value embedded in a base64 PNG data URL,
    followed by the follow-up text.
    """
    followup = build_tool_followup(tool_results)
    assistant_message = {"role": "assistant", "content": total_response}

    # computer_use_tools is optional; without it there is never an image.
    pending_image: "str | None" = None
    try:
        from computer_use_tools import pop_pending_vision_image
        pending_image = pop_pending_vision_image()
    except ImportError:
        pass

    content: "str | list"
    if not pending_image:
        content = followup
    else:
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{pending_image}"},
        }
        text_part = {"type": "text", "text": followup}
        content = [image_part, text_part]

    return assistant_message, {"role": "user", "content": content}, followup
919
+
920
+
921
+ # ── AgentEvent typed union ────────────────────────────────────────────────────
922
+
923
@dataclass(frozen=True)
class AgentEventToken:
    """Event carrying one piece of text streamed from the model."""

    # The streamed text fragment.
    text: str
927
+
928
+
929
@dataclass(frozen=True)
class AgentEventThinking:
    """Event carrying one thinking/reasoning fragment from the model."""

    # The reasoning text fragment.
    content: str
933
+
934
+
935
@dataclass(frozen=True)
class AgentEventToolCall:
    """Event announcing a tool call the model requested, before it runs."""

    # Name of the requested tool.
    tool: str
    # Parameters supplied with the request.
    params: dict
940
+
941
+
942
@dataclass(frozen=True)
class AgentEventToolResult:
    """Event reporting that one tool finished executing."""

    # Name of the tool that ran.
    tool: str
    # Result mapping the tool produced.
    result: dict
    # Elapsed time reported for the execution.
    elapsed: float
948
+
949
+
950
@dataclass(frozen=True)
class AgentEventStatus:
    """Informational status notice (for example a provider fallback)."""

    # Short machine-readable state label.
    state: str
    # Human-readable description of the status change.
    message: str
955
+
956
+
957
@dataclass(frozen=True)
class AgentEventComplete:
    """Final event of a normally finished agent loop."""

    # The full result of the turn.
    result: "AgentTurnResult"
961
+
962
+
963
@dataclass(frozen=True)
class AgentEventCancelled:
    """Final event of an agent loop that was cancelled."""

    # Response text accumulated before the cancellation.
    partial_text: str
967
+
968
+
969
@dataclass(frozen=True)
class AgentEventError:
    """Final event of an agent loop that hit an unrecoverable error."""

    # Description of what went wrong.
    error: str
973
+
974
+
975
# Any event type the agent loop can emit: streaming events first, then
# tool events, status notices, and the three terminal events.
AgentEvent = Union[
    AgentEventToken, AgentEventThinking,
    AgentEventToolCall, AgentEventToolResult,
    AgentEventStatus,
    AgentEventComplete, AgentEventCancelled, AgentEventError,
]
985
+
986
+
987
+ # ── AgentOptions ──────────────────────────────────────────────────────────────
988
+
989
@dataclass
class AgentOptions:
    """Per-invocation settings for ``run_agent()``."""

    # Upper bound on provider rounds in a single invocation.
    max_rounds: int = 30
    # Tool names treated as serial; defaults to a frozen copy of
    # DEFAULT_SERIAL_TOOLS taken when the options object is created.
    serial_tools: FrozenSet[str] = field(default_factory=lambda: frozenset(DEFAULT_SERIAL_TOOLS))
    # Tool schema dicts; empty unless the caller supplies some.
    tool_schemas: List[dict] = field(default_factory=list)
998
+
999
+
1000
+ # ── run_agent() ───────────────────────────────────────────────────────────────
1001
+
1002
async def run_agent(
    prompt: str,
    history: list,
    *,
    provider_fn: Callable,
    tool_executor: ToolExecutor,
    options: Optional["AgentOptions"] = None,
    remote_runner: Optional[RemoteToolRunner] = None,
    on_token: Optional[Callable[[str], None]] = None,
    on_thinking: Optional[Callable[[str], None]] = None,
    on_tool_call: Optional[Callable[[str, dict], None]] = None,
    on_tool_result: Optional[Callable[[str, dict], None]] = None,
    on_status: Optional[Callable[[str, str], None]] = None,
    hook: Optional[Hook] = None,
    cancel_event: Optional[asyncio.Event] = None,
    tool_result_formatter: Optional[SummaryFormatter] = None,
) -> AsyncGenerator["AgentEvent", None]:
    """Provider-agnostic multi-round agent loop.

    Yields ``AgentEvent`` objects so every caller (REPL, bot, API) can
    handle UI in its own way without duplicating round-management logic.
    The generator always ends with exactly one terminal event:
    ``AgentEventComplete``, ``AgentEventCancelled`` or ``AgentEventError``.

    Parameters
    ----------
    prompt:
        The user's message for this turn.
    history:
        Conversation history **before** the current prompt. It is passed to
        ``provider_fn`` on the first round only; later rounds pass ``[]``.
    provider_fn:
        Async callable invoked as ``provider_fn(message, history,
        on_token=..., on_thinking=..., on_tool_call=..., on_tool_result=...,
        on_status=..., cancel_event=...) -> dict``. The returned dict is read
        for ``cancelled``, ``success``, ``error``, ``tool_calls_pending`` and
        ``response``.
    tool_executor:
        Local tool registry.
    options:
        Tunable loop parameters (max_rounds, serial_tools, …).
    remote_runner:
        Optional async callable for tools not in ``tool_executor``.
    on_token / on_thinking / on_tool_call / on_tool_result / on_status:
        Pass-through callbacks forwarded to ``provider_fn`` so callers
        that already set up streaming callbacks don't need to change.
    hook:
        Pre/post-tool hook fired around each tool execution.
    cancel_event:
        asyncio.Event; when set the loop exits at the next safe point.
    tool_result_formatter:
        Formats a tool result dict into a summary string. Defaults to
        ``str(result.get('output', result))``.
    """
    opts = options or AgentOptions()
    _formatter: SummaryFormatter = tool_result_formatter or (
        lambda _tool, res: str(res.get("output", res))
    )
    _serial = set(opts.serial_tools)

    turn_state = AgentTurnState(provider="unknown")
    # Monotonic clock: the reported duration cannot be skewed by system
    # clock adjustments during a long run.
    start_time = time.perf_counter()
    current_message = prompt
    token_count = 0
    thinking_tokens = 0
    result: dict = {}
    loop_guard = LoopGuard()

    for round_num in range(opts.max_rounds):
        # ── Provider call ────────────────────────────────────────────────────
        response_text = ""

        def _wrap_on_token(tok: str) -> None:
            # Accumulate this round's text and count tokens before forwarding.
            nonlocal response_text, token_count
            response_text += tok
            token_count += 1
            if on_token is not None:
                on_token(tok)

        def _wrap_on_thinking(content: str) -> None:
            nonlocal thinking_tokens
            thinking_tokens += 1
            if on_thinking is not None:
                on_thinking(content)

        def _wrap_on_tool_call(tool: str, params: dict) -> None:
            if on_tool_call is not None:
                on_tool_call(tool, params)

        try:
            # NOTE(review): history is only sent on round 0; the list rebuilt
            # at the end of each round is never passed to the provider again.
            # Presumably the provider keeps its own conversation state —
            # confirm against the provider implementations.
            result = await provider_fn(
                current_message,
                history if round_num == 0 else [],
                on_token=_wrap_on_token,
                on_thinking=_wrap_on_thinking,
                on_tool_call=_wrap_on_tool_call,
                on_tool_result=on_tool_result,
                on_status=on_status,
                cancel_event=cancel_event,
            )
        except Exception as exc:
            # Message-less exceptions stringify to ""; report the class name
            # instead so the error event is never blank.
            yield AgentEventError(error=str(exc) or type(exc).__name__)
            return

        if result.get("cancelled"):
            turn_state.append_response(response_text)
            yield AgentEventCancelled(partial_text=turn_state.total_response)
            return

        if not result.get("success"):
            # Also covers an "error" key that is present but empty or None.
            yield AgentEventError(error=result.get("error") or "Unknown error")
            return

        turn_state.apply_model_result(result, response_text)

        pending = result.get("tool_calls_pending", [])
        if not pending:
            break
        for tool_call in pending:
            yield AgentEventToolCall(
                tool=tool_call.get("tool", ""),
                params=tool_call.get("params", {}),
            )

        # Warn caller on final round; the pending calls are not executed.
        if round_num == opts.max_rounds - 1:
            yield AgentEventStatus(
                state="max_rounds",
                message=f"Max rounds ({opts.max_rounds}) reached",
            )
            break

        # ── Tool execution ───────────────────────────────────────────────────
        tool_turn_result = await execute_tool_turn(
            pending,
            total_response=turn_state.total_response,
            tool_executor=tool_executor,
            formatter=_formatter,
            remote_runner=remote_runner,
            hook=hook,
            cancel_event=cancel_event,
            loop_guard=loop_guard,
            serial_tools=_serial,
        )

        for activity in tool_turn_result.activities:
            turn_state.tools_used.append(activity.tool)
            yield AgentEventToolResult(
                tool=activity.tool,
                result=activity.result,
                elapsed=activity.elapsed,
            )

        turn_state.add_tool_time(tool_turn_result.batch.elapsed_total)
        if tool_turn_result.cancelled:
            yield AgentEventCancelled(partial_text=turn_state.total_response)
            return
        if tool_turn_result.guard_directives:
            yield AgentEventStatus(
                state="loop_guard",
                message="Repeated failing tool call detected",
            )

        # Copy rather than extend so the caller's history list is not mutated.
        history = list(history) + [
            tool_turn_result.assistant_message,
            tool_turn_result.user_message,
        ]
        current_message = tool_turn_result.followup
        turn_state.reset_response()

        if loop_guard.should_break:
            turn_state.append_response(
                "\n\nRepeated failing tool calls were detected and the agent stopped retrying."
            )
            break

    # ── Build final result ───────────────────────────────────────────────────
    elapsed = time.perf_counter() - start_time
    turn_result = turn_state.build_result(
        elapsed=elapsed,
        fallback_response=result.get("response", ""),
        token_count=token_count,
        thinking_tokens=thinking_tokens,
    )
    yield AgentEventComplete(result=turn_result)