loom_agent-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of loom-agent might be problematic.

Files changed (89)
  1. loom/__init__.py +77 -0
  2. loom/agent.py +217 -0
  3. loom/agents/__init__.py +10 -0
  4. loom/agents/refs.py +28 -0
  5. loom/agents/registry.py +50 -0
  6. loom/builtin/compression/__init__.py +4 -0
  7. loom/builtin/compression/structured.py +79 -0
  8. loom/builtin/embeddings/__init__.py +9 -0
  9. loom/builtin/embeddings/openai_embedding.py +135 -0
  10. loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
  11. loom/builtin/llms/__init__.py +8 -0
  12. loom/builtin/llms/mock.py +34 -0
  13. loom/builtin/llms/openai.py +168 -0
  14. loom/builtin/llms/rule.py +102 -0
  15. loom/builtin/memory/__init__.py +5 -0
  16. loom/builtin/memory/in_memory.py +21 -0
  17. loom/builtin/memory/persistent_memory.py +278 -0
  18. loom/builtin/retriever/__init__.py +9 -0
  19. loom/builtin/retriever/chroma_store.py +265 -0
  20. loom/builtin/retriever/in_memory.py +106 -0
  21. loom/builtin/retriever/milvus_store.py +307 -0
  22. loom/builtin/retriever/pinecone_store.py +237 -0
  23. loom/builtin/retriever/qdrant_store.py +274 -0
  24. loom/builtin/retriever/vector_store.py +128 -0
  25. loom/builtin/retriever/vector_store_config.py +217 -0
  26. loom/builtin/tools/__init__.py +32 -0
  27. loom/builtin/tools/calculator.py +49 -0
  28. loom/builtin/tools/document_search.py +111 -0
  29. loom/builtin/tools/glob.py +27 -0
  30. loom/builtin/tools/grep.py +56 -0
  31. loom/builtin/tools/http_request.py +86 -0
  32. loom/builtin/tools/python_repl.py +73 -0
  33. loom/builtin/tools/read_file.py +32 -0
  34. loom/builtin/tools/task.py +158 -0
  35. loom/builtin/tools/web_search.py +64 -0
  36. loom/builtin/tools/write_file.py +31 -0
  37. loom/callbacks/base.py +9 -0
  38. loom/callbacks/logging.py +12 -0
  39. loom/callbacks/metrics.py +27 -0
  40. loom/callbacks/observability.py +248 -0
  41. loom/components/agent.py +107 -0
  42. loom/core/agent_executor.py +450 -0
  43. loom/core/circuit_breaker.py +178 -0
  44. loom/core/compression_manager.py +329 -0
  45. loom/core/context_retriever.py +185 -0
  46. loom/core/error_classifier.py +193 -0
  47. loom/core/errors.py +66 -0
  48. loom/core/message_queue.py +167 -0
  49. loom/core/permission_store.py +62 -0
  50. loom/core/permissions.py +69 -0
  51. loom/core/scheduler.py +125 -0
  52. loom/core/steering_control.py +47 -0
  53. loom/core/structured_logger.py +279 -0
  54. loom/core/subagent_pool.py +232 -0
  55. loom/core/system_prompt.py +141 -0
  56. loom/core/system_reminders.py +283 -0
  57. loom/core/tool_pipeline.py +113 -0
  58. loom/core/types.py +269 -0
  59. loom/interfaces/compressor.py +59 -0
  60. loom/interfaces/embedding.py +51 -0
  61. loom/interfaces/llm.py +33 -0
  62. loom/interfaces/memory.py +29 -0
  63. loom/interfaces/retriever.py +179 -0
  64. loom/interfaces/tool.py +27 -0
  65. loom/interfaces/vector_store.py +80 -0
  66. loom/llm/__init__.py +14 -0
  67. loom/llm/config.py +228 -0
  68. loom/llm/factory.py +111 -0
  69. loom/llm/model_health.py +235 -0
  70. loom/llm/model_pool_advanced.py +305 -0
  71. loom/llm/pool.py +170 -0
  72. loom/llm/registry.py +201 -0
  73. loom/mcp/__init__.py +4 -0
  74. loom/mcp/client.py +86 -0
  75. loom/mcp/registry.py +58 -0
  76. loom/mcp/tool_adapter.py +48 -0
  77. loom/observability/__init__.py +5 -0
  78. loom/patterns/__init__.py +5 -0
  79. loom/patterns/multi_agent.py +123 -0
  80. loom/patterns/rag.py +262 -0
  81. loom/plugins/registry.py +55 -0
  82. loom/resilience/__init__.py +5 -0
  83. loom/tooling.py +72 -0
  84. loom/utils/agent_loader.py +218 -0
  85. loom/utils/token_counter.py +19 -0
  86. loom_agent-0.0.1.dist-info/METADATA +457 -0
  87. loom_agent-0.0.1.dist-info/RECORD +89 -0
  88. loom_agent-0.0.1.dist-info/WHEEL +4 -0
  89. loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/core/agent_executor.py
@@ -0,0 +1,450 @@
+ from __future__ import annotations
+
+ import asyncio
+ from typing import AsyncGenerator, Dict, List, Optional
+ from uuid import uuid4
+
+ from loom.core.steering_control import SteeringControl
+ from loom.core.tool_pipeline import ToolExecutionPipeline
+ from loom.core.types import Message, StreamEvent, ToolCall, ToolResult
+ from loom.core.system_prompt import build_system_prompt
+ from loom.interfaces.compressor import BaseCompressor
+ from loom.interfaces.llm import BaseLLM
+ from loom.interfaces.memory import BaseMemory
+ from loom.interfaces.tool import BaseTool
+ from loom.core.permissions import PermissionManager
+ from loom.callbacks.metrics import MetricsCollector
+ import time
+ from loom.utils.token_counter import count_messages_tokens
+ from loom.callbacks.base import BaseCallback
+ from loom.core.errors import ExecutionAbortedError
+
+ # RAG support
+ try:
+     from loom.core.context_retriever import ContextRetriever
+ except ImportError:
+     ContextRetriever = None  # type: ignore
+
+
+ class AgentExecutor:
+     """Agent executor: wraps the main loop, connecting the LLM, memory, tool pipeline, and event stream."""
+
+     def __init__(
+         self,
+         llm: BaseLLM,
+         tools: Dict[str, BaseTool] | None = None,
+         memory: BaseMemory | None = None,
+         compressor: BaseCompressor | None = None,
+         context_retriever: Optional["ContextRetriever"] = None,  # 🆕 RAG support
+         steering_control: SteeringControl | None = None,
+         max_iterations: int = 50,
+         max_context_tokens: int = 16000,
+         permission_manager: PermissionManager | None = None,
+         metrics: MetricsCollector | None = None,
+         system_instructions: Optional[str] = None,
+         callbacks: Optional[List[BaseCallback]] = None,
+         enable_steering: bool = False,  # 🆕 US1: Real-time steering
+     ) -> None:
+         self.llm = llm
+         self.tools = tools or {}
+         self.memory = memory
+         self.compressor = compressor
+         self.context_retriever = context_retriever  # 🆕 RAG support
+         self.steering_control = steering_control or SteeringControl()
+         self.max_iterations = max_iterations
+         self.max_context_tokens = max_context_tokens
+         self.metrics = metrics or MetricsCollector()
+         self.permission_manager = permission_manager or PermissionManager(policy={"default": "allow"})
+         self.system_instructions = system_instructions
+         self.callbacks = callbacks or []
+         self.enable_steering = enable_steering  # 🆕 US1
+         self.tool_pipeline = ToolExecutionPipeline(
+             self.tools, permission_manager=self.permission_manager, metrics=self.metrics
+         )
+
+     async def _emit(self, event_type: str, payload: Dict) -> None:
+         if not self.callbacks:
+             return
+         enriched = dict(payload)
+         enriched.setdefault("ts", time.time())
+         enriched.setdefault("type", event_type)
+         for cb in self.callbacks:
+             try:
+                 await cb.on_event(event_type, enriched)
+             except Exception:
+                 # best-effort; don't fail agent execution on callback errors
+                 pass
+
+     async def execute(
+         self,
+         user_input: str,
+         cancel_token: Optional[asyncio.Event] = None,  # 🆕 US1: Cancellation support
+         correlation_id: Optional[str] = None,  # 🆕 US1: Request tracing
+     ) -> str:
+         """Non-streaming execution: a ReAct loop with tool calls (minimal implementation).
+
+         Args:
+             user_input: User query/instruction
+             cancel_token: Optional Event to signal cancellation (US1)
+             correlation_id: Optional correlation ID for request tracing (US1)
+
+         Returns:
+             Final agent response (or partial results if cancelled)
+         """
+         # Generate correlation_id if not provided
+         if correlation_id is None:
+             correlation_id = str(uuid4())
+
+         await self._emit("request_start", {
+             "input": user_input,
+             "source": "execute",
+             "iteration": 0,
+             "correlation_id": correlation_id,  # 🆕 US1
+         })
+
+         # Check cancellation before starting
+         if cancel_token and cancel_token.is_set():
+             await self._emit("agent_finish", {
+                 "content": "Request cancelled before execution",
+                 "source": "execute",
+                 "correlation_id": correlation_id,
+             })
+             return "Request cancelled before execution"
+
+         history = await self._load_history()
+
+         # 🆕 Step 1: RAG - automatically retrieve relevant documents (if a context_retriever is configured)
+         retrieved_docs = []
+         if self.context_retriever:
+             retrieved_docs = await self.context_retriever.retrieve_for_query(user_input)
+             if retrieved_docs:
+                 # Inject the retrieved document context
+                 if self.context_retriever.inject_as == "system":
+                     doc_context = self.context_retriever.format_documents(retrieved_docs)
+                     history.append(Message(
+                         role="system",
+                         content=doc_context,
+                         metadata={"type": "retrieved_context", "doc_count": len(retrieved_docs)}
+                     ))
+                 # Record retrieval metrics
+                 self.metrics.metrics.retrievals = getattr(self.metrics.metrics, "retrievals", 0) + 1
+                 await self._emit("retrieval_complete", {"doc_count": len(retrieved_docs), "source": "execute"})
+
+         # Step 2: Append the user message
+         history.append(Message(role="user", content=user_input))
+
+         # Step 3: Compression check
+         history = await self._maybe_compress(history)
+
+         # Step 4: Dynamically build the system prompt
+         context = {"retrieved_docs_count": len(retrieved_docs)} if retrieved_docs else None
+         system_prompt = build_system_prompt(self.tools, self.system_instructions, context)
+         history = self._inject_system_prompt(history, system_prompt)
+
+         if not self.llm.supports_tools or not self.tools:
+             try:
+                 # Create LLM task that can be cancelled
+                 llm_task = asyncio.create_task(self.llm.generate([m.__dict__ for m in history]))
+
+                 # Poll for cancellation while waiting for LLM
+                 while not llm_task.done():
+                     if cancel_token and cancel_token.is_set():
+                         llm_task.cancel()
+                         try:
+                             await llm_task
+                         except asyncio.CancelledError:
+                             pass
+                         partial_result = "Execution interrupted during LLM call"
+                         await self._emit("agent_finish", {
+                             "content": partial_result,
+                             "source": "execute",
+                             "correlation_id": correlation_id,
+                             "interrupted": True,
+                         })
+                         return partial_result
+                     await asyncio.sleep(0.1)  # Check every 100ms
+
+                 text = await llm_task
+             except asyncio.CancelledError:
+                 partial_result = "Execution interrupted during LLM call"
+                 await self._emit("agent_finish", {
+                     "content": partial_result,
+                     "source": "execute",
+                     "correlation_id": correlation_id,
+                     "interrupted": True,
+                 })
+                 return partial_result
+             except Exception as e:
+                 self.metrics.metrics.total_errors += 1
+                 await self._emit("error", {"stage": "llm_generate", "message": str(e)})
+                 raise
+             self.metrics.metrics.llm_calls += 1
+             if self.memory:
+                 await self.memory.add_message(Message(role="assistant", content=text))
+             await self._emit("agent_finish", {"content": text, "source": "execute"})
+             return text
+
+         tools_spec = self._serialize_tools()
+         iterations = 0
+         final_text = ""
+         while iterations < self.max_iterations:
+             # 🆕 US1: Check cancellation before each iteration
+             if cancel_token and cancel_token.is_set():
+                 partial_result = f"Execution interrupted after {iterations} iterations. Partial progress: {final_text or '(in progress)'}"
+                 await self._emit("agent_finish", {
+                     "content": partial_result,
+                     "source": "execute",
+                     "correlation_id": correlation_id,
+                     "interrupted": True,
+                     "iterations_completed": iterations,
+                 })
+                 return partial_result
+
+             try:
+                 resp = await self.llm.generate_with_tools([m.__dict__ for m in history], tools_spec)
+             except Exception as e:
+                 self.metrics.metrics.total_errors += 1
+                 await self._emit("error", {
+                     "stage": "llm_generate_with_tools",
+                     "message": str(e),
+                     "source": "execute",
+                     "iteration": iterations,
+                     "correlation_id": correlation_id,  # 🆕 US1
+                 })
+                 raise
+             self.metrics.metrics.llm_calls += 1
+             tool_calls = resp.get("tool_calls") or []
+             content = resp.get("content") or ""
+
+             if tool_calls:
+                 # Broadcast tool-call start (non-streaming path)
+                 try:
+                     meta = [
+                         {"id": str(tc.get("id", "")), "name": str(tc.get("name", ""))}
+                         for tc in tool_calls
+                     ]
+                     await self._emit("tool_calls_start", {"tool_calls": meta, "source": "execute", "iteration": iterations})
+                 except Exception:
+                     pass
+                 # Execute tools and write the results back as messages
+                 try:
+                     for tr in await self._execute_tool_batch(tool_calls):
+                         tool_msg = Message(role="tool", content=tr.content, tool_call_id=tr.tool_call_id)
+                         history.append(tool_msg)
+                         if self.memory:
+                             await self.memory.add_message(tool_msg)
+                         await self._emit("tool_result", {"tool_call_id": tr.tool_call_id, "content": tr.content, "source": "execute", "iteration": iterations})
+                 except Exception as e:
+                     self.metrics.metrics.total_errors += 1
+                     await self._emit("error", {"stage": "tool_execute", "message": str(e), "source": "execute", "iteration": iterations})
+                     raise
+                 iterations += 1
+                 self.metrics.metrics.total_iterations += 1
+                 history = await self._maybe_compress(history)
+                 continue
+
+             # No tool calls: treat the content as the final answer
+             final_text = content
+             if self.memory:
+                 await self.memory.add_message(Message(role="assistant", content=final_text))
+             await self._emit("agent_finish", {
+                 "content": final_text,
+                 "source": "execute",
+                 "correlation_id": correlation_id,  # 🆕 US1
+             })
+             break
+
+         return final_text
+
+     async def stream(self, user_input: str) -> AsyncGenerator[StreamEvent, None]:
+         """Streaming execution: emits text_delta, tool_calls_start, tool_result, and agent_finish events."""
+         yield StreamEvent(type="request_start")
+         await self._emit("request_start", {"input": user_input, "source": "stream", "iteration": 0})
+         history = await self._load_history()
+
+         # 🆕 RAG - automatically retrieve documents
+         retrieved_docs = []
+         if self.context_retriever:
+             retrieved_docs = await self.context_retriever.retrieve_for_query(user_input)
+             if retrieved_docs:
+                 if self.context_retriever.inject_as == "system":
+                     doc_context = self.context_retriever.format_documents(retrieved_docs)
+                     history.append(Message(
+                         role="system",
+                         content=doc_context,
+                         metadata={"type": "retrieved_context", "doc_count": len(retrieved_docs)}
+                     ))
+                 self.metrics.metrics.retrievals = getattr(self.metrics.metrics, "retrievals", 0) + 1
+                 # 🆕 Broadcast the retrieval event
+                 yield StreamEvent(type="retrieval_complete", metadata={"doc_count": len(retrieved_docs)})
+                 await self._emit("retrieval_complete", {"doc_count": len(retrieved_docs), "source": "stream"})
+
+         history.append(Message(role="user", content=user_input))
+
+         # Compression check
+         compressed = await self._maybe_compress(history)
+         if compressed is not history:
+             history = compressed
+             yield StreamEvent(type="compression_applied")
+
+         # Dynamically build the system prompt
+         context = {"retrieved_docs_count": len(retrieved_docs)} if retrieved_docs else None
+         system_prompt = build_system_prompt(self.tools, self.system_instructions, context)
+         history = self._inject_system_prompt(history, system_prompt)
+
+         if not self.llm.supports_tools or not self.tools:
+             try:
+                 async for delta in self.llm.stream([m.__dict__ for m in history]):
+                     yield StreamEvent(type="text_delta", content=delta)
+             except Exception as e:
+                 self.metrics.metrics.total_errors += 1
+                 await self._emit("error", {"stage": "llm_stream", "message": str(e), "source": "stream"})
+                 raise
+             yield StreamEvent(type="agent_finish")
+             return
+
+         tools_spec = self._serialize_tools()
+         iterations = 0
+         while iterations < self.max_iterations:
+             try:
+                 resp = await self.llm.generate_with_tools([m.__dict__ for m in history], tools_spec)
+             except Exception as e:
+                 self.metrics.metrics.total_errors += 1
+                 await self._emit("error", {"stage": "llm_generate_with_tools", "message": str(e), "source": "stream", "iteration": iterations})
+                 raise
+             self.metrics.metrics.llm_calls += 1
+             tool_calls = resp.get("tool_calls") or []
+             content = resp.get("content") or ""
+
+             if tool_calls:
+                 # Broadcast tool-call start
+                 tc_models = [self._to_tool_call(tc) for tc in tool_calls]
+                 yield StreamEvent(type="tool_calls_start", tool_calls=tc_models)
+                 await self._emit(
+                     "tool_calls_start",
+                     {"tool_calls": [{"id": t.id, "name": t.name} for t in tc_models], "source": "stream", "iteration": iterations},
+                 )
+                 # Execute tools
+                 try:
+                     async for tr in self._execute_tool_calls_async(tc_models):
+                         yield StreamEvent(type="tool_result", result=tr)
+                         await self._emit("tool_result", {"tool_call_id": tr.tool_call_id, "content": tr.content, "source": "stream", "iteration": iterations})
+                         tool_msg = Message(role="tool", content=tr.content, tool_call_id=tr.tool_call_id)
+                         history.append(tool_msg)
+                         if self.memory:
+                             await self.memory.add_message(tool_msg)
+                 except Exception as e:
+                     self.metrics.metrics.total_errors += 1
+                     await self._emit("error", {"stage": "tool_execute", "message": str(e), "source": "stream", "iteration": iterations})
+                     raise
+                 iterations += 1
+                 self.metrics.metrics.total_iterations += 1
+                 # Compression check at the end of each round
+                 history = await self._maybe_compress(history)
+                 continue
+
+             # No tool calls: emit the final text and finish
+             if content:
+                 yield StreamEvent(type="text_delta", content=content)
+             yield StreamEvent(type="agent_finish")
+             await self._emit("agent_finish", {"content": content})
+             if self.memory and content:
+                 await self.memory.add_message(Message(role="assistant", content=content))
+             break
+
+     async def _load_history(self) -> List[Message]:
+         if not self.memory:
+             return []
+         return await self.memory.get_messages()
+
+     async def _maybe_compress(self, history: List[Message]) -> List[Message]:
+         """Check if compression needed and apply if threshold reached.
+
+         US2: Automatic compression at 92% threshold with 8-segment summarization.
+         """
+         if not self.compressor:
+             return history
+
+         tokens_before = count_messages_tokens(history)
+
+         # Check if compression should be triggered (92% threshold)
+         if self.compressor.should_compress(tokens_before, self.max_context_tokens):
+             # Attempt compression
+             try:
+                 compressed_messages, metadata = await self.compressor.compress(history)
+
+                 # Update metrics
+                 self.metrics.metrics.compressions = getattr(self.metrics.metrics, "compressions", 0) + 1
+                 if metadata.key_topics == ["fallback"]:
+                     self.metrics.metrics.compression_fallbacks = getattr(self.metrics.metrics, "compression_fallbacks", 0) + 1
+
+                 # Emit compression event with metadata
+                 await self._emit(
+                     "compression_applied",
+                     {
+                         "before_tokens": metadata.original_tokens,
+                         "after_tokens": metadata.compressed_tokens,
+                         "compression_ratio": metadata.compression_ratio,
+                         "original_message_count": metadata.original_message_count,
+                         "compressed_message_count": metadata.compressed_message_count,
+                         "key_topics": metadata.key_topics,
+                         "fallback_used": metadata.key_topics == ["fallback"],
+                     },
+                 )
+
+                 return compressed_messages
+
+             except Exception as e:
+                 # Compression failed - continue without compression
+                 self.metrics.metrics.total_errors += 1
+                 await self._emit("error", {
+                     "stage": "compression",
+                     "message": str(e),
+                 })
+                 return history
+
+         return history
+
+     def _serialize_tools(self) -> List[Dict]:
+         tools_spec: List[Dict] = []
+         for t in self.tools.values():
+             schema = {}
+             try:
+                 schema = t.args_schema.model_json_schema()  # type: ignore[attr-defined]
+             except Exception:
+                 schema = {"type": "object", "properties": {}}
+             tools_spec.append(
+                 {
+                     "type": "function",
+                     "function": {
+                         "name": t.name,
+                         "description": getattr(t, "description", ""),
+                         "parameters": schema,
+                     },
+                 }
+             )
+         return tools_spec
+
+     def _to_tool_call(self, raw: Dict) -> ToolCall:
+         # Allow Rule/Mock LLMs to emit plain dicts
+         return ToolCall(id=str(raw.get("id", "call_0")), name=raw["name"], arguments=raw.get("arguments", {}))
+
+     async def _execute_tool_batch(self, tool_calls_raw: List[Dict]) -> List[ToolResult]:
+         tc_models = [self._to_tool_call(tc) for tc in tool_calls_raw]
+         results: List[ToolResult] = []
+         async for tr in self._execute_tool_calls_async(tc_models):
+             results.append(tr)
+         return results
+
+     async def _execute_tool_calls_async(self, tool_calls: List[ToolCall]):
+         async for tr in self.tool_pipeline.execute_calls(tool_calls):
+             yield tr
+
+     def _inject_system_prompt(self, history: List[Message], system_prompt: str) -> List[Message]:
+         """Inject or update the system prompt message."""
+         # If the first message is a system message, replace it; otherwise insert at the front
+         if history and history[0].role == "system":
+             history[0] = Message(role="system", content=system_prompt)
+         else:
+             history.insert(0, Message(role="system", content=system_prompt))
+         return history
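
The executor above only duck-types its llm dependency: it reads supports_tools and awaits generate / stream / generate_with_tools. The driver below is an illustrative sketch, not part of the package; the EchoLLM stub and the main() function are hypothetical, while the AgentExecutor constructor and the execute(..., cancel_token=...) signature come from the diff above. A real project would plug in one of the loom.builtin.llms implementations instead of the stub.

# Hypothetical usage sketch for the executor shown above (EchoLLM and main are not part of loom).
import asyncio

from loom.core.agent_executor import AgentExecutor


class EchoLLM:
    """Toy stand-in LLM: no tool support, echoes the last user message."""

    supports_tools = False

    async def generate(self, messages):
        # messages arrive as plain dicts ([m.__dict__ for m in history]) per the executor code
        last_user = next(m["content"] for m in reversed(messages) if m["role"] == "user")
        return f"echo: {last_user}"

    async def stream(self, messages):
        yield await self.generate(messages)


async def main() -> None:
    executor = AgentExecutor(llm=EchoLLM(), max_iterations=5)

    # Non-streaming path; a fresh correlation_id is generated when none is passed.
    print(await executor.execute("hello"))

    # Cooperative cancellation via the cancel_token event (US1).
    cancel = asyncio.Event()
    cancel.set()
    print(await executor.execute("this returns early", cancel_token=cancel))


asyncio.run(main())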
loom/core/circuit_breaker.py
@@ -0,0 +1,178 @@
+ """US5: Circuit Breaker Pattern
+
+ Implements the circuit breaker pattern to prevent cascading failures.
+
+ States:
+ - CLOSED: Normal operation, requests pass through
+ - OPEN: Failures exceeded threshold, requests fail fast
+ - HALF_OPEN: Testing if service recovered, limited requests allowed
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import time
+ from enum import Enum
+ from typing import Callable, Any, Optional
+ from dataclasses import dataclass
+
+
+ class CircuitState(str, Enum):
+     """Circuit breaker states."""
+     CLOSED = "closed"  # Normal operation
+     OPEN = "open"  # Failing fast
+     HALF_OPEN = "half_open"  # Testing recovery
+
+
+ @dataclass
+ class CircuitBreakerConfig:
+     """Configuration for circuit breaker."""
+     failure_threshold: int = 5  # Failures before opening circuit
+     success_threshold: int = 2  # Successes in half-open before closing
+     timeout_seconds: float = 60.0  # Time to wait before trying half-open
+     exclude_exceptions: tuple = ()  # Exceptions that don't count as failures
+
+
+ class CircuitBreakerOpenError(Exception):
+     """Raised when circuit breaker is open."""
+     pass
+
+
+ class CircuitBreaker:
+     """Circuit breaker implementation.
+
+     Example:
+         breaker = CircuitBreaker()
+
+         async def call_external_service():
+             async with breaker:
+                 return await some_external_api_call()
+     """
+
+     def __init__(self, config: Optional[CircuitBreakerConfig] = None):
+         """Initialize circuit breaker.
+
+         Args:
+             config: Circuit breaker configuration
+         """
+         self.config = config or CircuitBreakerConfig()
+         self.state = CircuitState.CLOSED
+         self.failure_count = 0
+         self.success_count = 0
+         self.last_failure_time: Optional[float] = None
+         self._lock = asyncio.Lock()
+
+     async def __aenter__(self):
+         """Context manager entry - check if circuit allows request."""
+         async with self._lock:
+             # Check if we should transition from OPEN to HALF_OPEN
+             if self.state == CircuitState.OPEN:
+                 if self._should_attempt_reset():
+                     self.state = CircuitState.HALF_OPEN
+                     self.success_count = 0
+                 else:
+                     raise CircuitBreakerOpenError(
+                         f"Circuit breaker is OPEN. "
+                         f"Retry after {self._time_until_retry():.1f}s"
+                     )
+
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit - record success/failure."""
+         async with self._lock:
+             if exc_type is None:
+                 # Success
+                 await self._on_success()
+             elif not isinstance(exc_val, self.config.exclude_exceptions):
+                 # Failure (unless excluded)
+                 await self._on_failure()
+
+         return False  # Don't suppress exceptions
+
+     async def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+         """Call a function through the circuit breaker.
+
+         Args:
+             func: Async function to call
+             *args: Positional arguments
+             **kwargs: Keyword arguments
+
+         Returns:
+             Result of the function
+
+         Raises:
+             CircuitBreakerOpenError: If circuit is open
+             Exception: Any exception from the function
+         """
+         async with self:
+             if asyncio.iscoroutinefunction(func):
+                 return await func(*args, **kwargs)
+             else:
+                 return func(*args, **kwargs)
+
+     async def _on_success(self) -> None:
+         """Handle successful call."""
+         if self.state == CircuitState.HALF_OPEN:
+             self.success_count += 1
+             if self.success_count >= self.config.success_threshold:
+                 # Recovered! Close the circuit
+                 self.state = CircuitState.CLOSED
+                 self.failure_count = 0
+                 self.success_count = 0
+         elif self.state == CircuitState.CLOSED:
+             # Reset failure count on success
+             self.failure_count = 0
+
+     async def _on_failure(self) -> None:
+         """Handle failed call."""
+         self.last_failure_time = time.time()
+
+         if self.state == CircuitState.HALF_OPEN:
+             # Failed during recovery - back to OPEN
+             self.state = CircuitState.OPEN
+             self.success_count = 0
+         elif self.state == CircuitState.CLOSED:
+             self.failure_count += 1
+             if self.failure_count >= self.config.failure_threshold:
+                 # Too many failures - open the circuit
+                 self.state = CircuitState.OPEN
+
+     def _should_attempt_reset(self) -> bool:
+         """Check if enough time has passed to attempt reset."""
+         if self.last_failure_time is None:
+             return True
+
+         elapsed = time.time() - self.last_failure_time
+         return elapsed >= self.config.timeout_seconds
+
+     def _time_until_retry(self) -> float:
+         """Calculate time until retry is allowed."""
+         if self.last_failure_time is None:
+             return 0.0
+
+         elapsed = time.time() - self.last_failure_time
+         remaining = self.config.timeout_seconds - elapsed
+         return max(0.0, remaining)
+
+     def get_state(self) -> dict:
+         """Get current circuit breaker state.
+
+         Returns:
+             Dictionary with state information
+         """
+         return {
+             "state": self.state.value,
+             "failure_count": self.failure_count,
+             "success_count": self.success_count,
+             "last_failure_time": self.last_failure_time,
+             "time_until_retry": self._time_until_retry() if self.state == CircuitState.OPEN else 0.0,
+         }
+
+     async def reset(self) -> None:
+         """Manually reset the circuit breaker to CLOSED state."""
+         async with self._lock:
+             self.state = CircuitState.CLOSED
+             self.failure_count = 0
+             self.success_count = 0
+             self.last_failure_time = None
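
For reference, a minimal driver for the breaker shown above; it is not part of the package. It assumes the module is importable as loom.core.circuit_breaker (matching the file list) and uses only names defined in the diff: CircuitBreakerConfig, call(), CircuitBreakerOpenError, and get_state(). With failure_threshold=2, two raised RuntimeErrors open the circuit and the third attempt fails fast.

# Illustrative sketch only (flaky_call and main are hypothetical).
import asyncio

from loom.core.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
    CircuitBreakerOpenError,
)


async def flaky_call() -> str:
    raise RuntimeError("upstream unavailable")


async def main() -> None:
    breaker = CircuitBreaker(CircuitBreakerConfig(failure_threshold=2, timeout_seconds=30.0))

    for attempt in range(3):
        try:
            await breaker.call(flaky_call)
        except CircuitBreakerOpenError as e:
            # Third attempt: circuit is OPEN, so the call is rejected without running flaky_call
            print(f"attempt {attempt}: fail-fast -> {e}")
        except RuntimeError as e:
            # First two attempts: the failure is counted by __aexit__ and re-raised
            print(f"attempt {attempt}: counted failure -> {e}")

    # After two counted failures the state is "open" and time_until_retry counts down from 30s.
    print(breaker.get_state())


asyncio.run(main())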