loom-agent 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of loom-agent might be problematic. Click here for more details.
- loom/__init__.py +77 -0
- loom/agent.py +217 -0
- loom/agents/__init__.py +10 -0
- loom/agents/refs.py +28 -0
- loom/agents/registry.py +50 -0
- loom/builtin/compression/__init__.py +4 -0
- loom/builtin/compression/structured.py +79 -0
- loom/builtin/embeddings/__init__.py +9 -0
- loom/builtin/embeddings/openai_embedding.py +135 -0
- loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
- loom/builtin/llms/__init__.py +8 -0
- loom/builtin/llms/mock.py +34 -0
- loom/builtin/llms/openai.py +168 -0
- loom/builtin/llms/rule.py +102 -0
- loom/builtin/memory/__init__.py +5 -0
- loom/builtin/memory/in_memory.py +21 -0
- loom/builtin/memory/persistent_memory.py +278 -0
- loom/builtin/retriever/__init__.py +9 -0
- loom/builtin/retriever/chroma_store.py +265 -0
- loom/builtin/retriever/in_memory.py +106 -0
- loom/builtin/retriever/milvus_store.py +307 -0
- loom/builtin/retriever/pinecone_store.py +237 -0
- loom/builtin/retriever/qdrant_store.py +274 -0
- loom/builtin/retriever/vector_store.py +128 -0
- loom/builtin/retriever/vector_store_config.py +217 -0
- loom/builtin/tools/__init__.py +32 -0
- loom/builtin/tools/calculator.py +49 -0
- loom/builtin/tools/document_search.py +111 -0
- loom/builtin/tools/glob.py +27 -0
- loom/builtin/tools/grep.py +56 -0
- loom/builtin/tools/http_request.py +86 -0
- loom/builtin/tools/python_repl.py +73 -0
- loom/builtin/tools/read_file.py +32 -0
- loom/builtin/tools/task.py +158 -0
- loom/builtin/tools/web_search.py +64 -0
- loom/builtin/tools/write_file.py +31 -0
- loom/callbacks/base.py +9 -0
- loom/callbacks/logging.py +12 -0
- loom/callbacks/metrics.py +27 -0
- loom/callbacks/observability.py +248 -0
- loom/components/agent.py +107 -0
- loom/core/agent_executor.py +450 -0
- loom/core/circuit_breaker.py +178 -0
- loom/core/compression_manager.py +329 -0
- loom/core/context_retriever.py +185 -0
- loom/core/error_classifier.py +193 -0
- loom/core/errors.py +66 -0
- loom/core/message_queue.py +167 -0
- loom/core/permission_store.py +62 -0
- loom/core/permissions.py +69 -0
- loom/core/scheduler.py +125 -0
- loom/core/steering_control.py +47 -0
- loom/core/structured_logger.py +279 -0
- loom/core/subagent_pool.py +232 -0
- loom/core/system_prompt.py +141 -0
- loom/core/system_reminders.py +283 -0
- loom/core/tool_pipeline.py +113 -0
- loom/core/types.py +269 -0
- loom/interfaces/compressor.py +59 -0
- loom/interfaces/embedding.py +51 -0
- loom/interfaces/llm.py +33 -0
- loom/interfaces/memory.py +29 -0
- loom/interfaces/retriever.py +179 -0
- loom/interfaces/tool.py +27 -0
- loom/interfaces/vector_store.py +80 -0
- loom/llm/__init__.py +14 -0
- loom/llm/config.py +228 -0
- loom/llm/factory.py +111 -0
- loom/llm/model_health.py +235 -0
- loom/llm/model_pool_advanced.py +305 -0
- loom/llm/pool.py +170 -0
- loom/llm/registry.py +201 -0
- loom/mcp/__init__.py +4 -0
- loom/mcp/client.py +86 -0
- loom/mcp/registry.py +58 -0
- loom/mcp/tool_adapter.py +48 -0
- loom/observability/__init__.py +5 -0
- loom/patterns/__init__.py +5 -0
- loom/patterns/multi_agent.py +123 -0
- loom/patterns/rag.py +262 -0
- loom/plugins/registry.py +55 -0
- loom/resilience/__init__.py +5 -0
- loom/tooling.py +72 -0
- loom/utils/agent_loader.py +218 -0
- loom/utils/token_counter.py +19 -0
- loom_agent-0.0.1.dist-info/METADATA +457 -0
- loom_agent-0.0.1.dist-info/RECORD +89 -0
- loom_agent-0.0.1.dist-info/WHEEL +4 -0
- loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import AsyncGenerator, Dict, List, Optional
|
|
5
|
+
from uuid import uuid4
|
|
6
|
+
|
|
7
|
+
from loom.core.steering_control import SteeringControl
|
|
8
|
+
from loom.core.tool_pipeline import ToolExecutionPipeline
|
|
9
|
+
from loom.core.types import Message, StreamEvent, ToolCall
|
|
10
|
+
from loom.core.system_prompt import build_system_prompt
|
|
11
|
+
from loom.interfaces.compressor import BaseCompressor
|
|
12
|
+
from loom.interfaces.llm import BaseLLM
|
|
13
|
+
from loom.interfaces.memory import BaseMemory
|
|
14
|
+
from loom.interfaces.tool import BaseTool
|
|
15
|
+
from loom.core.permissions import PermissionManager
|
|
16
|
+
from loom.callbacks.metrics import MetricsCollector
|
|
17
|
+
import time
|
|
18
|
+
from loom.utils.token_counter import count_messages_tokens
|
|
19
|
+
from loom.callbacks.base import BaseCallback
|
|
20
|
+
from loom.core.errors import ExecutionAbortedError
|
|
21
|
+
|
|
22
|
+
# RAG support
|
|
23
|
+
try:
|
|
24
|
+
from loom.core.context_retriever import ContextRetriever
|
|
25
|
+
except ImportError:
|
|
26
|
+
ContextRetriever = None # type: ignore
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AgentExecutor:
    """Agent executor: runs the main loop connecting the LLM, memory, the tool
    pipeline and the event stream.

    Both entry points prepare the conversation the same way: load history from
    memory, optionally run RAG retrieval, append the user message, run the
    compression check and inject the dynamically built system prompt.

    * :meth:`execute` runs to completion and returns the final text.
    * :meth:`stream` yields :class:`StreamEvent` objects while it runs.

    When the LLM does not support tools (``llm.supports_tools`` is falsy) or no
    tools are registered, a single plain LLM call is made. Otherwise a ReAct
    loop alternates ``generate_with_tools`` calls and tool execution for at
    most ``max_iterations`` tool-calling rounds.
    """

    def __init__(
        self,
        llm: BaseLLM,
        tools: Dict[str, BaseTool] | None = None,
        memory: BaseMemory | None = None,
        compressor: BaseCompressor | None = None,
        context_retriever: Optional["ContextRetriever"] = None,  # RAG support
        steering_control: SteeringControl | None = None,
        max_iterations: int = 50,
        max_context_tokens: int = 16000,
        permission_manager: PermissionManager | None = None,
        metrics: MetricsCollector | None = None,
        system_instructions: Optional[str] = None,
        callbacks: Optional[List[BaseCallback]] = None,
        enable_steering: bool = False,  # US1: real-time steering
    ) -> None:
        """Wire the executor's collaborators together.

        Args:
            llm: Backend used for every generation call.
            tools: Tool name -> tool mapping; an empty dict when omitted.
            memory: Optional store. History is loaded from it, and assistant
                and tool messages are written back to it.
            compressor: Optional compressor consulted by ``_maybe_compress``.
            context_retriever: Optional RAG retriever queried with the user input.
            steering_control: Defaults to a fresh ``SteeringControl()``.
            max_iterations: Upper bound on tool-calling rounds.
            max_context_tokens: Token budget handed to ``compressor.should_compress``.
            permission_manager: Defaults to an allow-all policy.
            metrics: Defaults to a fresh ``MetricsCollector()``.
            system_instructions: Extra instructions passed to ``build_system_prompt``.
            callbacks: Receive every event through ``on_event``.
            enable_steering: Stored on the instance; not read by this class itself.
        """
        self.llm = llm
        self.tools = tools or {}
        self.memory = memory
        self.compressor = compressor
        self.context_retriever = context_retriever  # RAG support
        self.steering_control = steering_control or SteeringControl()
        self.max_iterations = max_iterations
        self.max_context_tokens = max_context_tokens
        self.metrics = metrics or MetricsCollector()
        self.permission_manager = permission_manager or PermissionManager(policy={"default": "allow"})
        self.system_instructions = system_instructions
        self.callbacks = callbacks or []
        self.enable_steering = enable_steering  # US1
        self.tool_pipeline = ToolExecutionPipeline(
            self.tools, permission_manager=self.permission_manager, metrics=self.metrics
        )

    async def _emit(self, event_type: str, payload: Dict) -> None:
        """Deliver an event to every registered callback, best effort.

        The payload is copied and enriched with ``ts`` (``time.time()``) and
        ``type``, unless the caller already supplied those keys. A failing
        callback never aborts agent execution. The failure is logged at DEBUG
        level and the remaining callbacks still run.
        """
        if not self.callbacks:
            return
        enriched = dict(payload)
        enriched.setdefault("ts", time.time())
        enriched.setdefault("type", event_type)
        for cb in self.callbacks:
            try:
                await cb.on_event(event_type, enriched)
            except Exception:
                # Best effort: callback errors must not fail agent execution,
                # but leave a trace instead of swallowing them silently.
                import logging

                logging.getLogger(__name__).debug(
                    "callback %r failed while handling %r", cb, event_type, exc_info=True
                )

    async def execute(
        self,
        user_input: str,
        cancel_token: Optional[asyncio.Event] = None,  # US1: cancellation support
        correlation_id: Optional[str] = None,  # US1: request tracing
    ) -> str:
        """Run the non-streaming ReAct loop (with tool calls) to completion.

        Args:
            user_input: User query/instruction.
            cancel_token: Optional event that signals cancellation (US1). It is
                checked before starting, while waiting on a plain (tool-less)
                LLM call, and before every tool-calling iteration.
            correlation_id: Optional ID attached to emitted events for tracing
                (US1). A UUID4 string is generated when omitted.

        Returns:
            The final agent response, or a short "interrupted" message when
            cancelled.

        Raises:
            Exception: Errors from the LLM or from tool execution are re-raised
                after an ``error`` event has been emitted.
        """
        if correlation_id is None:
            correlation_id = str(uuid4())

        await self._emit("request_start", {
            "input": user_input,
            "source": "execute",
            "iteration": 0,
            "correlation_id": correlation_id,  # US1
        })

        # Honour a cancellation that was requested before we even started.
        if cancel_token and cancel_token.is_set():
            await self._emit("agent_finish", {
                "content": "Request cancelled before execution",
                "source": "execute",
                "correlation_id": correlation_id,
            })
            return "Request cancelled before execution"

        history = await self._load_history()

        # Step 1: RAG - retrieve related documents when a retriever is configured.
        retrieved_docs = await self._retrieve_context(history, user_input)
        if retrieved_docs:
            await self._emit("retrieval_complete", {
                "doc_count": len(retrieved_docs),
                "source": "execute",
                "correlation_id": correlation_id,
            })

        # Step 2: add the user message to the working history.
        # NOTE(review): the user message is not written to memory here (only
        # assistant/tool messages are) - confirm the caller persists it.
        history.append(Message(role="user", content=user_input))

        # Step 3: compression check.
        history = await self._maybe_compress(history)

        # Step 4: build the system prompt dynamically and inject it.
        context = {"retrieved_docs_count": len(retrieved_docs)} if retrieved_docs else None
        system_prompt = build_system_prompt(self.tools, self.system_instructions, context)
        history = self._inject_system_prompt(history, system_prompt)

        if not self.llm.supports_tools or not self.tools:
            return await self._execute_without_tools(history, cancel_token, correlation_id)
        return await self._execute_react_loop(history, cancel_token, correlation_id)

    async def _execute_without_tools(
        self,
        history: List[Message],
        cancel_token: Optional[asyncio.Event],
        correlation_id: str,
    ) -> str:
        """Make one plain ``llm.generate`` call (helper for :meth:`execute`).

        The LLM call is abandoned, and its task cancelled, as soon as
        *cancel_token* is set or this coroutine itself is cancelled. In both
        cases an interrupted ``agent_finish`` is emitted and a short message
        is returned.
        """
        interrupted_msg = "Execution interrupted during LLM call"
        try:
            text, cancelled = await self._await_unless_cancelled(
                self.llm.generate([m.__dict__ for m in history]), cancel_token
            )
        except asyncio.CancelledError:
            # Either the LLM task or this request was cancelled. The contract
            # here is to report an interrupted result rather than raise.
            text, cancelled = None, True
        except Exception as e:
            self.metrics.metrics.total_errors += 1
            await self._emit("error", {
                "stage": "llm_generate",
                "message": str(e),
                "source": "execute",
                "correlation_id": correlation_id,
            })
            raise

        if cancelled:
            await self._emit("agent_finish", {
                "content": interrupted_msg,
                "source": "execute",
                "correlation_id": correlation_id,
                "interrupted": True,
            })
            return interrupted_msg

        self.metrics.metrics.llm_calls += 1
        if self.memory:
            await self.memory.add_message(Message(role="assistant", content=text))
        await self._emit("agent_finish", {
            "content": text,
            "source": "execute",
            "correlation_id": correlation_id,
        })
        return text

    async def _execute_react_loop(
        self,
        history: List[Message],
        cancel_token: Optional[asyncio.Event],
        correlation_id: str,
    ) -> str:
        """Alternate ``generate_with_tools`` calls and tool execution (helper for :meth:`execute`).

        The loop ends when the LLM answers without tool calls. It also ends when
        *cancel_token* is found set at the start of an iteration, or after
        ``max_iterations`` tool rounds, in which case ``""`` is returned.
        """
        tools_spec = self._serialize_tools()
        iterations = 0
        final_text = ""
        while iterations < self.max_iterations:
            # US1: check cancellation before each iteration.
            if cancel_token and cancel_token.is_set():
                partial_result = f"Execution interrupted after {iterations} iterations. Partial progress: {final_text or '(in progress)'}"
                await self._emit("agent_finish", {
                    "content": partial_result,
                    "source": "execute",
                    "correlation_id": correlation_id,
                    "interrupted": True,
                    "iterations_completed": iterations,
                })
                return partial_result

            try:
                resp = await self.llm.generate_with_tools([m.__dict__ for m in history], tools_spec)
            except Exception as e:
                self.metrics.metrics.total_errors += 1
                await self._emit("error", {
                    "stage": "llm_generate_with_tools",
                    "message": str(e),
                    "source": "execute",
                    "iteration": iterations,
                    "correlation_id": correlation_id,  # US1
                })
                raise
            self.metrics.metrics.llm_calls += 1
            tool_calls = resp.get("tool_calls") or []
            content = resp.get("content") or ""

            if tool_calls:
                # Broadcast the start of the tool calls (non-streaming path).
                # Building the metadata assumes dict-like tool calls. That is a
                # best-effort courtesy for observers, so it must never break
                # the loop.
                try:
                    meta = [
                        {"id": str(tc.get("id", "")), "name": str(tc.get("name", ""))}
                        for tc in tool_calls
                    ]
                    await self._emit("tool_calls_start", {
                        "tool_calls": meta,
                        "source": "execute",
                        "iteration": iterations,
                        "correlation_id": correlation_id,
                    })
                except Exception:
                    pass
                # Execute the tools and write their results back as messages.
                # NOTE(review): the assistant turn that requested these tool
                # calls is not appended to history before the tool results;
                # OpenAI-style APIs require it - confirm against Message and
                # the LLM adapters.
                try:
                    for tr in await self._execute_tool_batch(tool_calls):
                        tool_msg = Message(role="tool", content=tr.content, tool_call_id=tr.tool_call_id)
                        history.append(tool_msg)
                        if self.memory:
                            await self.memory.add_message(tool_msg)
                        await self._emit("tool_result", {
                            "tool_call_id": tr.tool_call_id,
                            "content": tr.content,
                            "source": "execute",
                            "iteration": iterations,
                            "correlation_id": correlation_id,
                        })
                except Exception as e:
                    self.metrics.metrics.total_errors += 1
                    await self._emit("error", {
                        "stage": "tool_execute",
                        "message": str(e),
                        "source": "execute",
                        "iteration": iterations,
                        "correlation_id": correlation_id,
                    })
                    raise
                iterations += 1
                self.metrics.metrics.total_iterations += 1
                history = await self._maybe_compress(history)
                continue

            # No tool calls: treat the content as the final answer.
            final_text = content
            if self.memory:
                await self.memory.add_message(Message(role="assistant", content=final_text))
            await self._emit("agent_finish", {
                "content": final_text,
                "source": "execute",
                "correlation_id": correlation_id,  # US1
            })
            break

        # NOTE(review): when max_iterations is exhausted no agent_finish event
        # is emitted and the empty string is returned.
        return final_text

    @staticmethod
    async def _await_unless_cancelled(awaitable, cancel_token: Optional[asyncio.Event]):
        """Await *awaitable* unless *cancel_token* is set first.

        Returns:
            ``(result, False)`` when the awaitable finished first (its own
            exception, if any, propagates). Returns ``(None, True)`` when the
            token was set first; the underlying task is then cancelled and
            awaited.

        Without a token the awaitable is awaited directly, with no polling
        delay. If the calling coroutine is cancelled while waiting, the
        underlying task is cancelled too before ``CancelledError`` propagates,
        so it is never left running in the background.
        """
        if cancel_token is None:
            result = await awaitable
            return result, False

        task = asyncio.ensure_future(awaitable)
        waiter = asyncio.ensure_future(cancel_token.wait())
        try:
            done, _pending = await asyncio.wait(
                {task, waiter}, return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            task.cancel()
            raise
        finally:
            waiter.cancel()

        # Prefer a finished result over a simultaneous cancellation request.
        if task in done:
            return task.result(), False

        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        return None, True

    async def stream(self, user_input: str) -> AsyncGenerator[StreamEvent, None]:
        """Streaming execution that yields ``StreamEvent`` objects.

        Event types yielded: ``request_start``, optionally ``retrieval_complete``
        and ``compression_applied``, then either ``text_delta`` chunks (plain
        LLM) or ``tool_calls_start``/``tool_result`` rounds followed by a final
        ``text_delta``, and finally ``agent_finish``.

        Raises:
            Exception: LLM or tool errors are re-raised after an ``error``
                event has been emitted to the callbacks.
        """
        yield StreamEvent(type="request_start")
        await self._emit("request_start", {"input": user_input, "source": "stream", "iteration": 0})
        history = await self._load_history()

        # RAG - retrieve documents automatically.
        retrieved_docs = await self._retrieve_context(history, user_input)
        if retrieved_docs:
            # Broadcast the retrieval event.
            yield StreamEvent(type="retrieval_complete", metadata={"doc_count": len(retrieved_docs)})
            await self._emit("retrieval_complete", {"doc_count": len(retrieved_docs), "source": "stream"})

        history.append(Message(role="user", content=user_input))

        # Compression check. _maybe_compress returns the same list object
        # when nothing was compressed.
        compressed = await self._maybe_compress(history)
        if compressed is not history:
            history = compressed
            yield StreamEvent(type="compression_applied")

        # Build the system prompt dynamically.
        context = {"retrieved_docs_count": len(retrieved_docs)} if retrieved_docs else None
        system_prompt = build_system_prompt(self.tools, self.system_instructions, context)
        history = self._inject_system_prompt(history, system_prompt)

        if not self.llm.supports_tools or not self.tools:
            chunks: List[str] = []
            try:
                async for delta in self.llm.stream([m.__dict__ for m in history]):
                    if isinstance(delta, str):
                        chunks.append(delta)
                    yield StreamEvent(type="text_delta", content=delta)
            except Exception as e:
                self.metrics.metrics.total_errors += 1
                await self._emit("error", {"stage": "llm_stream", "message": str(e), "source": "stream"})
                raise
            # Report and persist the streamed answer, like every other path.
            text = "".join(chunks)
            yield StreamEvent(type="agent_finish")
            await self._emit("agent_finish", {"content": text, "source": "stream"})
            if self.memory and text:
                await self.memory.add_message(Message(role="assistant", content=text))
            return

        tools_spec = self._serialize_tools()
        iterations = 0
        while iterations < self.max_iterations:
            try:
                resp = await self.llm.generate_with_tools([m.__dict__ for m in history], tools_spec)
            except Exception as e:
                self.metrics.metrics.total_errors += 1
                await self._emit("error", {"stage": "llm_generate_with_tools", "message": str(e), "source": "stream", "iteration": iterations})
                raise
            self.metrics.metrics.llm_calls += 1
            tool_calls = resp.get("tool_calls") or []
            content = resp.get("content") or ""

            if tool_calls:
                # Broadcast the start of the tool calls.
                tc_models = [self._to_tool_call(tc) for tc in tool_calls]
                yield StreamEvent(type="tool_calls_start", tool_calls=tc_models)
                await self._emit(
                    "tool_calls_start",
                    {"tool_calls": [{"id": t.id, "name": t.name} for t in tc_models], "source": "stream", "iteration": iterations},
                )
                # Execute the tools.
                # NOTE(review): as in execute(), no assistant tool-call turn is
                # appended before the tool results - confirm this is intended.
                try:
                    async for tr in self._execute_tool_calls_async(tc_models):
                        yield StreamEvent(type="tool_result", result=tr)
                        await self._emit("tool_result", {"tool_call_id": tr.tool_call_id, "content": tr.content, "source": "stream", "iteration": iterations})
                        tool_msg = Message(role="tool", content=tr.content, tool_call_id=tr.tool_call_id)
                        history.append(tool_msg)
                        if self.memory:
                            await self.memory.add_message(tool_msg)
                except Exception as e:
                    self.metrics.metrics.total_errors += 1
                    await self._emit("error", {"stage": "tool_execute", "message": str(e), "source": "stream", "iteration": iterations})
                    raise
                iterations += 1
                self.metrics.metrics.total_iterations += 1
                # Compression check after every round.
                history = await self._maybe_compress(history)
                continue

            # No tool calls: emit the final text and finish.
            if content:
                yield StreamEvent(type="text_delta", content=content)
            yield StreamEvent(type="agent_finish")
            await self._emit("agent_finish", {"content": content, "source": "stream"})
            if self.memory and content:
                await self.memory.add_message(Message(role="assistant", content=content))
            break

    async def _load_history(self) -> List[Message]:
        """Return the messages stored in memory, or ``[]`` when no memory is configured."""
        if not self.memory:
            return []
        return await self.memory.get_messages()

    async def _retrieve_context(self, history: List[Message], user_input: str) -> List:
        """Run RAG retrieval for *user_input* and inject the result if configured.

        When ``context_retriever.inject_as == "system"``, the formatted
        documents are appended to *history* (in place) as a system message.
        That message's metadata is ``{"type": "retrieved_context", ...}``. The
        ``retrievals`` metric is bumped whenever documents come back.

        Returns:
            The retrieved documents, or an empty list when no retriever is
            configured or nothing was found.
        """
        if not self.context_retriever:
            return []
        retrieved_docs = await self.context_retriever.retrieve_for_query(user_input)
        if not retrieved_docs:
            return []
        if self.context_retriever.inject_as == "system":
            doc_context = self.context_retriever.format_documents(retrieved_docs)
            history.append(Message(
                role="system",
                content=doc_context,
                metadata={"type": "retrieved_context", "doc_count": len(retrieved_docs)},
            ))
        self.metrics.metrics.retrievals = getattr(self.metrics.metrics, "retrievals", 0) + 1
        return retrieved_docs

    async def _maybe_compress(self, history: List[Message]) -> List[Message]:
        """Compress *history* when the compressor says the budget is exceeded.

        US2: the trigger policy lives in ``compressor.should_compress``, which
        the original design documents as a 92% threshold.

        Returns:
            The compressed message list, or the very same *history* object
            when no compressor is configured, compression is not needed, or
            compression fails. A failure is counted and reported as an
            ``error`` event instead of being raised.
        """
        if not self.compressor:
            return history

        tokens_before = count_messages_tokens(history)
        if not self.compressor.should_compress(tokens_before, self.max_context_tokens):
            return history

        try:
            compressed_messages, metadata = await self.compressor.compress(history)

            self.metrics.metrics.compressions = getattr(self.metrics.metrics, "compressions", 0) + 1
            fallback_used = metadata.key_topics == ["fallback"]
            if fallback_used:
                self.metrics.metrics.compression_fallbacks = getattr(self.metrics.metrics, "compression_fallbacks", 0) + 1

            await self._emit(
                "compression_applied",
                {
                    "before_tokens": metadata.original_tokens,
                    "after_tokens": metadata.compressed_tokens,
                    "compression_ratio": metadata.compression_ratio,
                    "original_message_count": metadata.original_message_count,
                    "compressed_message_count": metadata.compressed_message_count,
                    "key_topics": metadata.key_topics,
                    "fallback_used": fallback_used,
                },
            )
            return compressed_messages
        except Exception as e:
            # Compression failed - continue with the uncompressed history.
            self.metrics.metrics.total_errors += 1
            await self._emit("error", {
                "stage": "compression",
                "message": str(e),
            })
            return history

    def _serialize_tools(self) -> List[Dict]:
        """Describe every registered tool as ``{"type": "function", "function": {...}}``.

        Parameters come from ``tool.args_schema.model_json_schema()``. When that
        attribute is missing or the call fails, an empty object schema is used.
        A missing ``description`` becomes ``""``.
        """
        tools_spec: List[Dict] = []
        for t in self.tools.values():
            try:
                schema = t.args_schema.model_json_schema()  # type: ignore[attr-defined]
            except Exception:
                schema = {"type": "object", "properties": {}}
            tools_spec.append(
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": getattr(t, "description", ""),
                        "parameters": schema,
                    },
                }
            )
        return tools_spec

    def _to_tool_call(self, raw: Dict) -> ToolCall:
        """Convert a raw tool-call dict into a ``ToolCall``.

        Plain dicts are accepted so Rule/Mock LLMs can emit simple output. The
        id defaults to ``"call_0"`` and the arguments to ``{}``. A missing
        ``name`` raises ``KeyError``.
        """
        return ToolCall(id=str(raw.get("id", "call_0")), name=raw["name"], arguments=raw.get("arguments", {}))

    async def _execute_tool_batch(self, tool_calls_raw: List[Dict]) -> List[ToolResult]:
        """Execute raw tool-call dicts through the pipeline and collect all results in order."""
        # NOTE(review): ToolResult is not imported in this module; the
        # annotations only work because of `from __future__ import annotations`.
        tc_models = [self._to_tool_call(tc) for tc in tool_calls_raw]
        return [tr async for tr in self._execute_tool_calls_async(tc_models)]

    async def _execute_tool_calls_async(self, tool_calls: List[ToolCall]):
        """Yield each result produced by ``tool_pipeline.execute_calls`` for *tool_calls*."""
        async for tr in self.tool_pipeline.execute_calls(tool_calls):
            yield tr

    @staticmethod
    def _is_retrieved_context(message: Message) -> bool:
        """Return True if *message* is a RAG context message added by ``_retrieve_context``."""
        metadata = getattr(message, "metadata", None)
        return isinstance(metadata, dict) and metadata.get("type") == "retrieved_context"

    def _inject_system_prompt(self, history: List[Message], system_prompt: str) -> List[Message]:
        """Insert or update the system prompt at the start of *history* (in place).

        A leading system message is replaced, unless it is a RAG context
        message. With an empty memory that message is the first entry, and
        replacing it would silently drop the retrieved documents; in that case
        the prompt is inserted in front of it instead.
        """
        prompt_msg = Message(role="system", content=system_prompt)
        if history and history[0].role == "system" and not self._is_retrieved_context(history[0]):
            history[0] = prompt_msg
        else:
            history.insert(0, prompt_msg)
        return history
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""US5: Circuit Breaker Pattern
|
|
2
|
+
|
|
3
|
+
Implements the circuit breaker pattern to prevent cascading failures.
|
|
4
|
+
|
|
5
|
+
States:
|
|
6
|
+
- CLOSED: Normal operation, requests pass through
|
|
7
|
+
- OPEN: Failures exceeded threshold, requests fail fast
|
|
8
|
+
- HALF_OPEN: Testing if service recovered; trial requests pass through (their number is not limited), enough successes close the circuit and any failure reopens it
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import asyncio
|
|
14
|
+
import time
|
|
15
|
+
from enum import Enum
|
|
16
|
+
from typing import Callable, Any, Optional
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CircuitState(str, Enum):
    """Lifecycle states of a circuit breaker.

    Because of the ``str`` mixin, every member compares equal to its value.
    """

    # Requests flow through normally.
    CLOSED = "closed"
    # Requests are rejected immediately.
    OPEN = "open"
    # Trial requests probe whether the dependency has recovered.
    HALF_OPEN = "half_open"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class CircuitBreakerConfig:
    """Tunable thresholds for a circuit breaker.

    Attributes:
        failure_threshold: Consecutive failures (while CLOSED) that open the circuit.
        success_threshold: Successes needed while HALF_OPEN to close it again.
        timeout_seconds: Seconds after the last failure before an OPEN circuit
            lets a trial request through.
        exclude_exceptions: Exception classes that never count as failures.
    """

    failure_threshold: int = 5
    success_threshold: int = 2
    timeout_seconds: float = 60.0
    exclude_exceptions: tuple = ()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class CircuitBreakerOpenError(Exception):
    """Signals that the circuit is OPEN and the request was rejected without being attempted."""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CircuitBreaker:
    """Async circuit breaker.

    Use it either as an ``async with`` context manager around the guarded
    operation, or through :meth:`call`. The outcome is recorded on exit:

    * CLOSED: a success resets ``failure_count``. Reaching
      ``config.failure_threshold`` consecutive failures opens the circuit.
    * OPEN: requests raise :class:`CircuitBreakerOpenError` until
      ``config.timeout_seconds`` have elapsed since the last failure. The next
      request then moves the circuit to HALF_OPEN.
    * HALF_OPEN: requests pass through, and their number is not limited.
      ``config.success_threshold`` successes close the circuit; any counted
      failure re-opens it.

    Any exception leaving the guarded block counts as a failure, unless it is
    an instance of ``config.exclude_exceptions``. Excluded exceptions are
    counted as neither a success nor a failure.

    Example:
        breaker = CircuitBreaker()

        async def call_external_service():
            async with breaker:
                return await some_external_api_call()
    """

    def __init__(self, config: Optional[CircuitBreakerConfig] = None):
        """Create a breaker in the CLOSED state.

        Args:
            config: Thresholds and timeout. A default ``CircuitBreakerConfig()``
                is used when omitted.
        """
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time: Optional[float] = None
        # Guards the state transitions only; it is not held while the
        # protected operation runs.
        self._lock = asyncio.Lock()

    async def __aenter__(self):
        """Admit the request, or raise if the circuit is OPEN.

        Raises:
            CircuitBreakerOpenError: The circuit is OPEN and the retry timeout
                has not elapsed yet.
        """
        async with self._lock:
            # An OPEN circuit whose timeout has elapsed moves to HALF_OPEN.
            if self.state == CircuitState.OPEN:
                if self._should_attempt_reset():
                    self.state = CircuitState.HALF_OPEN
                    self.success_count = 0
                else:
                    raise CircuitBreakerOpenError(
                        f"Circuit breaker is OPEN. "
                        f"Retry after {self._time_until_retry():.1f}s"
                    )

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Record the outcome of the guarded block; never suppresses exceptions."""
        async with self._lock:
            if exc_type is None:
                await self._on_success()
            elif not isinstance(exc_val, self.config.exclude_exceptions):
                # Failure, unless the exception type is explicitly excluded.
                await self._on_failure()

        return False  # Don't suppress exceptions.

    async def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call *func* through the circuit breaker.

        *func* may be a coroutine function or a plain callable. If calling it
        returns an awaitable (e.g. an object with ``async def __call__``, or a
        lambda wrapping a coroutine function), that awaitable is awaited inside
        the breaker. This way the outcome of the real work is what gets recorded.

        Args:
            func: Callable to invoke.
            *args: Positional arguments for *func*.
            **kwargs: Keyword arguments for *func*.

        Returns:
            The (awaited) result of *func*.

        Raises:
            CircuitBreakerOpenError: If the circuit is open.
            Exception: Any exception raised by *func*.
        """
        import inspect  # only needed here

        async with self:
            result = func(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result

    async def _on_success(self) -> None:
        """Handle a successful call (caller holds ``_lock``)."""
        if self.state == CircuitState.HALF_OPEN:
            self.success_count += 1
            if self.success_count >= self.config.success_threshold:
                # Recovered: close the circuit.
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                self.success_count = 0
        elif self.state == CircuitState.CLOSED:
            # Only consecutive failures count towards the threshold.
            self.failure_count = 0

    async def _on_failure(self) -> None:
        """Handle a failed call (caller holds ``_lock``).

        A failure reported while already OPEN (from a request admitted before
        the circuit opened) only refreshes ``last_failure_time``, which
        postpones the retry window.
        """
        self.last_failure_time = time.time()

        if self.state == CircuitState.HALF_OPEN:
            # Failed during recovery: back to OPEN.
            self.state = CircuitState.OPEN
            self.success_count = 0
        elif self.state == CircuitState.CLOSED:
            self.failure_count += 1
            if self.failure_count >= self.config.failure_threshold:
                # Too many consecutive failures: open the circuit.
                self.state = CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Return True once ``timeout_seconds`` have passed since the last failure (or none was recorded)."""
        if self.last_failure_time is None:
            return True

        elapsed = time.time() - self.last_failure_time
        return elapsed >= self.config.timeout_seconds

    def _time_until_retry(self) -> float:
        """Return the seconds left before a retry is allowed (never negative)."""
        if self.last_failure_time is None:
            return 0.0

        elapsed = time.time() - self.last_failure_time
        remaining = self.config.timeout_seconds - elapsed
        return max(0.0, remaining)

    def get_state(self) -> dict:
        """Return a snapshot of the breaker for monitoring.

        Returns:
            Dictionary with ``state`` (string value), ``failure_count``,
            ``success_count``, ``last_failure_time`` (``time.time()`` value or
            None) and ``time_until_retry``. The last one is 0.0 unless OPEN.
        """
        return {
            "state": self.state.value,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "last_failure_time": self.last_failure_time,
            "time_until_retry": self._time_until_retry() if self.state == CircuitState.OPEN else 0.0,
        }

    async def reset(self) -> None:
        """Manually force the breaker back to a pristine CLOSED state."""
        async with self._lock:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            self.success_count = 0
            self.last_failure_time = None
|