loom-agent 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of loom-agent might be problematic. Click here for more details.
- loom/__init__.py +77 -0
- loom/agent.py +217 -0
- loom/agents/__init__.py +10 -0
- loom/agents/refs.py +28 -0
- loom/agents/registry.py +50 -0
- loom/builtin/compression/__init__.py +4 -0
- loom/builtin/compression/structured.py +79 -0
- loom/builtin/embeddings/__init__.py +9 -0
- loom/builtin/embeddings/openai_embedding.py +135 -0
- loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
- loom/builtin/llms/__init__.py +8 -0
- loom/builtin/llms/mock.py +34 -0
- loom/builtin/llms/openai.py +168 -0
- loom/builtin/llms/rule.py +102 -0
- loom/builtin/memory/__init__.py +5 -0
- loom/builtin/memory/in_memory.py +21 -0
- loom/builtin/memory/persistent_memory.py +278 -0
- loom/builtin/retriever/__init__.py +9 -0
- loom/builtin/retriever/chroma_store.py +265 -0
- loom/builtin/retriever/in_memory.py +106 -0
- loom/builtin/retriever/milvus_store.py +307 -0
- loom/builtin/retriever/pinecone_store.py +237 -0
- loom/builtin/retriever/qdrant_store.py +274 -0
- loom/builtin/retriever/vector_store.py +128 -0
- loom/builtin/retriever/vector_store_config.py +217 -0
- loom/builtin/tools/__init__.py +32 -0
- loom/builtin/tools/calculator.py +49 -0
- loom/builtin/tools/document_search.py +111 -0
- loom/builtin/tools/glob.py +27 -0
- loom/builtin/tools/grep.py +56 -0
- loom/builtin/tools/http_request.py +86 -0
- loom/builtin/tools/python_repl.py +73 -0
- loom/builtin/tools/read_file.py +32 -0
- loom/builtin/tools/task.py +158 -0
- loom/builtin/tools/web_search.py +64 -0
- loom/builtin/tools/write_file.py +31 -0
- loom/callbacks/base.py +9 -0
- loom/callbacks/logging.py +12 -0
- loom/callbacks/metrics.py +27 -0
- loom/callbacks/observability.py +248 -0
- loom/components/agent.py +107 -0
- loom/core/agent_executor.py +450 -0
- loom/core/circuit_breaker.py +178 -0
- loom/core/compression_manager.py +329 -0
- loom/core/context_retriever.py +185 -0
- loom/core/error_classifier.py +193 -0
- loom/core/errors.py +66 -0
- loom/core/message_queue.py +167 -0
- loom/core/permission_store.py +62 -0
- loom/core/permissions.py +69 -0
- loom/core/scheduler.py +125 -0
- loom/core/steering_control.py +47 -0
- loom/core/structured_logger.py +279 -0
- loom/core/subagent_pool.py +232 -0
- loom/core/system_prompt.py +141 -0
- loom/core/system_reminders.py +283 -0
- loom/core/tool_pipeline.py +113 -0
- loom/core/types.py +269 -0
- loom/interfaces/compressor.py +59 -0
- loom/interfaces/embedding.py +51 -0
- loom/interfaces/llm.py +33 -0
- loom/interfaces/memory.py +29 -0
- loom/interfaces/retriever.py +179 -0
- loom/interfaces/tool.py +27 -0
- loom/interfaces/vector_store.py +80 -0
- loom/llm/__init__.py +14 -0
- loom/llm/config.py +228 -0
- loom/llm/factory.py +111 -0
- loom/llm/model_health.py +235 -0
- loom/llm/model_pool_advanced.py +305 -0
- loom/llm/pool.py +170 -0
- loom/llm/registry.py +201 -0
- loom/mcp/__init__.py +4 -0
- loom/mcp/client.py +86 -0
- loom/mcp/registry.py +58 -0
- loom/mcp/tool_adapter.py +48 -0
- loom/observability/__init__.py +5 -0
- loom/patterns/__init__.py +5 -0
- loom/patterns/multi_agent.py +123 -0
- loom/patterns/rag.py +262 -0
- loom/plugins/registry.py +55 -0
- loom/resilience/__init__.py +5 -0
- loom/tooling.py +72 -0
- loom/utils/agent_loader.py +218 -0
- loom/utils/token_counter.py +19 -0
- loom_agent-0.0.1.dist-info/METADATA +457 -0
- loom_agent-0.0.1.dist-info/RECORD +89 -0
- loom_agent-0.0.1.dist-info/WHEEL +4 -0
- loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/core/errors.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ErrorCategory(Enum):
    """Error classification used by retry logic (T018 - US5).

    The comment above each member names the typical sources of that category
    and whether such errors are intended to be retried.
    """

    # httpx.TimeoutException, httpx.ConnectError - retryable
    NETWORK_ERROR = "network_error"
    # asyncio.TimeoutError - retryable
    TIMEOUT_ERROR = "timeout_error"
    # 429 responses - retryable with backoff
    RATE_LIMIT_ERROR = "rate_limit_error"
    # Pydantic ValidationError - non-retryable
    VALIDATION_ERROR = "validation_error"
    # PermissionDeniedError - non-retryable
    PERMISSION_ERROR = "permission_error"
    # 401/403 - non-retryable
    AUTHENTICATION_ERROR = "authentication_error"
    # 5xx errors - retryable
    SERVICE_ERROR = "service_error"
    # 404, ToolNotFoundError - non-retryable
    NOT_FOUND_ERROR = "not_found_error"
    # Catch-all - non-retryable by default
    UNKNOWN_ERROR = "unknown_error"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LoomException(Exception):
    """Base exception for the Loom framework.

    Every instance carries an ``ErrorCategory`` on ``self.category``. The
    category defaults to ``ErrorCategory.UNKNOWN_ERROR`` when the raiser does
    not specify one.
    """

    def __init__(self, message: str, category: ErrorCategory = ErrorCategory.UNKNOWN_ERROR) -> None:
        # Record the classification first; Exception.__init__ only stores args.
        self.category = category
        super().__init__(message)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ToolNotFoundError(LoomException):
    """Raised when a requested tool is not present in the registry.

    Always categorized as ``ErrorCategory.NOT_FOUND_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.NOT_FOUND_ERROR
        super().__init__(message, category)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ToolValidationError(LoomException):
    """Raised when the arguments passed to a tool fail validation.

    Always categorized as ``ErrorCategory.VALIDATION_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.VALIDATION_ERROR
        super().__init__(message, category)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class PermissionDeniedError(LoomException):
    """Raised when a permission check rejects an operation.

    Always categorized as ``ErrorCategory.PERMISSION_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.PERMISSION_ERROR
        super().__init__(message, category)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ToolExecutionTimeout(LoomException):
    """Raised when a tool runs longer than its allotted timeout.

    Always categorized as ``ErrorCategory.TIMEOUT_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.TIMEOUT_ERROR
        super().__init__(message, category)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ExecutionAbortedError(LoomException):
    """Raised when execution is aborted by the user.

    Always categorized as ``ErrorCategory.UNKNOWN_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.UNKNOWN_ERROR
        super().__init__(message, category)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class RecursionLimitError(LoomException):
    """Raised when sub-agent recursion exceeds the allowed depth (US3).

    Always categorized as ``ErrorCategory.VALIDATION_ERROR``.
    """

    def __init__(self, message: str) -> None:
        category = ErrorCategory.VALIDATION_ERROR
        super().__init__(message, category)
|
|
66
|
+
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""Message Queue for h2A Real-Time Steering (US1)
|
|
2
|
+
|
|
3
|
+
This module implements the async priority message queue that enables:
|
|
4
|
+
- Real-time agent interruption and cancellation
|
|
5
|
+
- Priority-based message processing
|
|
6
|
+
- Graceful shutdown with partial results
|
|
7
|
+
- Correlation ID tracking for multi-agent workflows
|
|
8
|
+
|
|
9
|
+
Architecture: h2A async message queue (Claude Code inspired)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from uuid import uuid4
|
|
17
|
+
|
|
18
|
+
from loom.core.types import MessageQueueItem
|
|
19
|
+
from loom.core.errors import ExecutionAbortedError
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MessageQueue:
    """h2A async priority message queue for real-time steering.

    Features:
    - Priority-based ordering (10 = highest, 0 = lowest)
    - FIFO within the same priority level
    - Cancel-all support for graceful shutdown; ``get()`` calls blocked on an
      empty queue are woken when the queue is cancelled
    - Optional external cancellation through a shared ``asyncio.Event``

    Usage:
        queue = MessageQueue()
        await queue.put(MessageQueueItem(role="user", content="Task", priority=5))
        item = await queue.get()  # Blocks until an item is available
        await queue.cancel_all()  # Drop pending cancellable items
    """

    def __init__(self, cancel_token: Optional[asyncio.Event] = None) -> None:
        """Initialize message queue.

        Args:
            cancel_token: Optional Event; once set, the queue behaves exactly
                as if ``cancel_all()`` had marked it cancelled (without draining).
        """
        # Entries are (-priority, insertion_order, item) tuples. Negating the
        # priority makes the min-heap PriorityQueue pop the highest priority
        # first; the unique insertion counter breaks ties FIFO, so the items
        # themselves are never compared.
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._insertion_counter = 0  # For FIFO within same priority
        self._cancelled = asyncio.Event()
        self._external_cancel_token = cancel_token

    async def put(self, item: MessageQueueItem) -> None:
        """Add item to queue with priority ordering.

        Args:
            item: Message queue item with priority (0-10); higher numbers are
                dequeued first.

        Raises:
            ExecutionAbortedError: If the queue has been cancelled.
        """
        if self._is_cancelled():
            raise ExecutionAbortedError("Queue cancelled, cannot add new items")

        insertion_order = self._insertion_counter
        self._insertion_counter += 1
        await self._queue.put((-item.priority, insertion_order, item))

    async def get(self, timeout: Optional[float] = None) -> MessageQueueItem:
        """Get highest priority item from queue.

        Args:
            timeout: Optional timeout in seconds (None = wait forever)

        Returns:
            MessageQueueItem: Highest priority item

        Raises:
            ExecutionAbortedError: If the queue is cancelled before or while
                waiting, or if the dequeued item is cancellable and the queue
                was cancelled concurrently (that item is discarded).
            asyncio.TimeoutError: If timeout expires
        """
        # Same fail-fast rule with or without a timeout.
        if self._is_cancelled():
            raise ExecutionAbortedError("Queue cancelled")

        if self._queue.empty():
            entry = await self._wait_for_entry(timeout)
        else:
            entry = self._queue.get_nowait()

        _priority, _order, item = entry
        if self._is_cancelled() and item.cancellable:
            # Cancelled while we were dequeuing: drop cancellable work.
            raise ExecutionAbortedError("Queue cancelled")
        # Non-cancellable items are still processed.
        return item

    async def _wait_for_entry(self, timeout: Optional[float]) -> tuple:
        """Wait for a queue entry, racing it against cancellation and *timeout*.

        Returns the raw ``(-priority, order, item)`` entry. Raises
        ExecutionAbortedError when a cancel signal fires first and
        asyncio.TimeoutError when *timeout* expires first.
        """
        getter = asyncio.ensure_future(self._queue.get())
        watchers = [asyncio.ensure_future(self._cancelled.wait())]
        if self._external_cancel_token is not None:
            watchers.append(asyncio.ensure_future(self._external_cancel_token.wait()))

        try:
            done, _pending = await asyncio.wait(
                {getter, *watchers},
                timeout=timeout,
                return_when=asyncio.FIRST_COMPLETED,
            )
        except BaseException:
            # Our own task was cancelled: never lose an item the getter took.
            self._release_getter(getter)
            raise
        finally:
            for watcher in watchers:
                watcher.cancel()

        if getter in done:
            return getter.result()

        self._release_getter(getter)
        if done:
            # A cancellation watcher fired first.
            raise ExecutionAbortedError("Queue cancelled")
        raise asyncio.TimeoutError()

    def _release_getter(self, getter: "asyncio.Future") -> None:
        """Cancel a still-pending getter, or requeue the entry it already took."""
        if not getter.done():
            getter.cancel()
        elif not getter.cancelled() and getter.exception() is None:
            # Re-inserting the original tuple keeps its priority/FIFO position.
            self._queue.put_nowait(getter.result())

    async def cancel_all(self) -> None:
        """Cancel all pending items and prevent new additions.

        Marks the queue cancelled (which also wakes any ``get()`` blocked on an
        empty queue) and discards every queued item whose ``cancellable`` flag
        is true. Non-cancellable items are kept in the queue.

        NOTE(review): ``get()`` refuses to dequeue once the queue is cancelled,
        so kept items are only observable through ``peek()``/``qsize()``.
        """
        self._cancelled.set()

        kept = []
        while True:
            try:
                entry = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if not entry[2].cancellable:
                kept.append(entry)  # Keep non-cancellable

        # Re-add non-cancellable items with their original ordering keys.
        for entry in kept:
            await self._queue.put(entry)

    def is_empty(self) -> bool:
        """Check if queue is empty."""
        return self._queue.empty()

    def is_cancelled(self) -> bool:
        """Check if queue has been cancelled (internally or via the external token)."""
        return self._is_cancelled()

    def _is_cancelled(self) -> bool:
        """Internal cancellation check (includes external token)."""
        if self._cancelled.is_set():
            return True
        token = self._external_cancel_token
        return token is not None and token.is_set()

    def qsize(self) -> int:
        """Return approximate queue size."""
        return self._queue.qsize()

    async def peek(self) -> Optional[MessageQueueItem]:
        """Peek at highest priority item without removing it.

        Returns:
            MessageQueueItem if queue not empty, None otherwise
        """
        if self._queue.empty():
            return None

        # Pop and immediately re-insert the same tuple; its sort key is
        # unchanged, so it stays at the head of the queue.
        entry = self._queue.get_nowait()
        self._queue.put_nowait(entry)
        return entry[2]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
DEFAULT_DIR = Path.home() / ".loom"
DEFAULT_PATH = DEFAULT_DIR / "config.json"


@dataclass
class PermissionStore:
    """Persistent allow-list for tools (framework capability).

    Schema stored at ``path`` (``~/.loom/config.json`` by default), UTF-8 JSON:
        {
            "allowed_tools": ["*" | "tool_name", ...]
        }

    The wildcard ``"*"`` allows every tool. Loading and saving are
    best-effort: an unreadable or malformed file yields an empty allow-list,
    and failures while saving are ignored.
    """

    allowed_tools: List[str] = field(default_factory=list)
    path: Path = DEFAULT_PATH

    @classmethod
    def load(cls, path: Path = DEFAULT_PATH) -> "PermissionStore":
        """Load the allow-list stored at *path*.

        Returns an empty store bound to *path* when the file is missing,
        unreadable, not valid UTF-8 JSON, or does not contain a JSON object
        whose ``allowed_tools`` entry is a list.
        """
        path = Path(path)
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing/unreadable file, bad encoding or invalid JSON
            # (JSONDecodeError and UnicodeDecodeError are ValueErrors).
            return cls(path=path)

        allowed = data.get("allowed_tools", []) if isinstance(data, dict) else None
        if isinstance(allowed, list):
            return cls(allowed_tools=[str(x) for x in allowed], path=path)
        return cls(path=path)

    @classmethod
    def load_default(cls) -> "PermissionStore":
        """Load the allow-list from ``DEFAULT_PATH``."""
        return cls.load(DEFAULT_PATH)

    def is_allowed(self, tool_name: str) -> bool:
        """Return True if *tool_name* (or the ``"*"`` wildcard) is allowed."""
        if "*" in self.allowed_tools:
            return True
        return tool_name in self.allowed_tools

    def grant(self, tool_name: str) -> None:
        """Add *tool_name* to the allow-list (no-op if present); keeps it sorted."""
        if tool_name not in self.allowed_tools:
            self.allowed_tools.append(tool_name)
            self.allowed_tools.sort()

    def revoke(self, tool_name: str) -> None:
        """Remove *tool_name* from the allow-list; missing names are ignored."""
        try:
            self.allowed_tools.remove(tool_name)
        except ValueError:
            pass

    def save(self) -> None:
        """Write the allow-list to ``self.path`` as UTF-8 JSON (best-effort).

        Parent directories of ``self.path`` are created as needed. I/O and
        serialization errors are swallowed so a failed save never raises.
        """
        target = Path(self.path)
        data = {"allowed_tools": self.allowed_tools}
        try:
            # Create the directory of the configured path, not just ~/.loom.
            target.parent.mkdir(parents=True, exist_ok=True)
            # ensure_ascii=False emits raw non-ASCII, so the encoding must be
            # explicit and match load() rather than the locale default.
            target.write_text(
                json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
            )
        except (OSError, TypeError, ValueError):
            # best-effort; ignore IO failures
            pass
|
|
62
|
+
|
loom/core/permissions.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any, Callable, Dict, Optional
|
|
5
|
+
from .permission_store import PermissionStore
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PermissionAction(str, Enum):
    """Outcome of a permission check for a tool call."""

    ALLOW = "allow"
    DENY = "deny"
    ASK = "ask"


# Callback invoked as handler(tool_name, arguments); a truthy return approves the call.
ConfirmHandler = Callable[[str, Dict[str, Any]], bool]


class PermissionManager:
    """Permission gateway (framework capability).

    - Policy: the action for a tool comes from ``policy[tool_name]``, else
      ``policy["default"]``, else the ``default`` argument. Unrecognised
      values resolve to DENY.
    - Optional safe mode: explicit ALLOW/DENY policies still win; for tools
      whose policy is ASK, the persistent allow-list (``PermissionStore``) is
      consulted and the tool is allowed if listed, otherwise ASK is returned.
    - After the user approves via ``ask_handler`` in safe mode, the grant can
      be persisted to the store (``persist_on_approve``).
    """

    def __init__(
        self,
        policy: Optional[Dict[str, str]] = None,
        default: str = "deny",
        ask_handler: Optional[ConfirmHandler] = None,
        *,
        safe_mode: bool = False,
        permission_store: Optional[PermissionStore] = None,
        persist_on_approve: bool = True,
    ) -> None:
        """Create a permission manager.

        Args:
            policy: Mapping of tool name (or ``"default"``) to an action string.
            default: Action used when ``policy`` has no matching entry.
            ask_handler: Callback used by ``confirm``; without one, confirm denies.
            safe_mode: Enable allow-list lookup for ASK tools.
            permission_store: Allow-list store; when None, the default store is
                loaded from disk.
            persist_on_approve: In safe mode, persist approvals to the store.
        """
        self.policy = {**(policy or {})}  # copy so caller's dict is not mutated
        self.default = default
        self.ask_handler = ask_handler
        self.safe_mode = safe_mode
        # Only fall back to the on-disk store when none was supplied (a
        # caller-provided store must be used even if it happens to be falsy).
        self.permission_store = (
            permission_store if permission_store is not None else PermissionStore.load_default()
        )
        self.persist_on_approve = persist_on_approve

    def _policy_action(self, tool_name: str) -> PermissionAction:
        """Resolve the configured action for *tool_name* (fail closed on bad values)."""
        action = self.policy.get(tool_name, self.policy.get("default", self.default))
        try:
            return PermissionAction(action)
        except (ValueError, TypeError):
            # Misconfigured policy value: deny rather than guess.
            return PermissionAction.DENY

    def check(self, tool_name: str, arguments: Dict[str, Any]) -> PermissionAction:
        """Return the permission decision for calling *tool_name*.

        ``arguments`` is accepted for interface symmetry with ``confirm`` but is
        not consulted by the decision.
        """
        # 1) Policy precedence
        policy_action = self._policy_action(tool_name)
        if not self.safe_mode:
            return policy_action

        # safe_mode enabled: explicit ALLOW/DENY still decides
        if policy_action in (PermissionAction.ALLOW, PermissionAction.DENY):
            return policy_action

        # ASK → consult the persistent allow-list
        if self.permission_store.is_allowed(tool_name):
            return PermissionAction.ALLOW
        return PermissionAction.ASK

    def confirm(self, tool_name: str, arguments: Dict[str, Any]) -> bool:
        """Ask ``ask_handler`` to approve the call; returns False if there is no handler.

        In safe mode with ``persist_on_approve``, an approval is granted in the
        store and saved.
        """
        approved = bool(self.ask_handler(tool_name, arguments)) if self.ask_handler else False
        if approved and self.safe_mode and self.persist_on_approve and self.permission_store is not None:
            # persist allow for this tool
            self.permission_store.grant(tool_name)
            self.permission_store.save()
        return approved
|
loom/core/scheduler.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, AsyncGenerator, Dict, Iterable, Set, Tuple
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from loom.interfaces.tool import BaseTool
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class SchedulerConfig:
    """Tuning knobs for ``Scheduler``."""

    max_concurrency: int = 10  # size of the shared semaphore for parallel runs
    timeout_seconds: int = 120  # per-tool timeout passed to asyncio.wait_for
    enable_priority: bool = True  # NOTE(review): not read by Scheduler here
    detect_file_conflicts: bool = True  # US4: File write conflict detection


class Scheduler:
    """Smart scheduler: concurrency/timeout control plus US4 file-conflict detection."""

    def __init__(self, config: SchedulerConfig | None = None) -> None:
        self.config = config or SchedulerConfig()
        self._semaphore = asyncio.Semaphore(self.config.max_concurrency)
        self._file_locks: Dict[str, asyncio.Lock] = {}  # US4: Per-file locks

    async def schedule_batch(
        self, tool_calls: Iterable[Tuple[BaseTool, Dict]]
    ) -> AsyncGenerator[Any, None]:
        """US4: Run a batch of tool calls, yielding each result as it is produced.

        Tools are grouped into three phases, executed in this order:
        1. Concurrent-safe (``is_concurrency_safe``) tools with no detected
           file write - run in parallel; results arrive in completion order.
        2. Concurrent-safe file writers (when ``detect_file_conflicts``) -
           run in parallel but serialized per resolved file path.
        3. Tools that are not concurrency-safe - run one at a time, in order.

        If a tool raises, or the consumer stops iterating early, tasks still
        running in the current phase are cancelled and awaited.
        """
        concurrent_safe: list[Tuple[BaseTool, Dict]] = []
        file_writers: list[Tuple[BaseTool, Dict, str]] = []  # (tool, args, file_path)
        sequential_only: list[Tuple[BaseTool, Dict]] = []

        for tool, args in tool_calls:
            if not tool.is_concurrency_safe:
                sequential_only.append((tool, args))
                continue
            # Only pay for path detection when conflict detection is enabled.
            file_path = (
                self._detect_file_write(tool, args)
                if self.config.detect_file_conflicts
                else None
            )
            if file_path:
                file_writers.append((tool, args, file_path))
            else:
                concurrent_safe.append((tool, args))

        if concurrent_safe:
            stream = self._execute_concurrent(concurrent_safe)
            try:
                async for result in stream:
                    yield result
            finally:
                # Close explicitly so the inner generator's cleanup runs now,
                # not whenever it is garbage-collected.
                await stream.aclose()

        if file_writers:
            stream = self._execute_file_writers(file_writers)
            try:
                async for result in stream:
                    yield result
            finally:
                await stream.aclose()

        for tool, args in sequential_only:
            yield await self._execute_single(tool, args)

    def _detect_file_write(self, tool: BaseTool, args: Dict) -> str | None:
        """US4: Detect if tool is writing to a file.

        Heuristics:
        - Tool name contains 'write', 'edit', 'save' or 'create'
        - Args contain 'file_path', 'path', 'filename' or 'target'
          (first match wins)

        Returns the resolved file path if detected (the raw string if
        resolution fails), None otherwise.
        """
        tool_name_lower = tool.name.lower()
        if not any(kw in tool_name_lower for kw in ('write', 'edit', 'save', 'create')):
            return None

        for key in ('file_path', 'path', 'filename', 'target'):
            if key in args:
                file_path = str(args[key])
                try:
                    return str(Path(file_path).resolve())
                except Exception:
                    return file_path

        return None

    async def _execute_file_writers(
        self, file_writers: list[Tuple[BaseTool, Dict, str]]
    ) -> AsyncGenerator[Any, None]:
        """US4: Execute file-writing tools with per-file serialization."""

        async def run_with_lock(tool: BaseTool, args: Dict, file_path: str) -> Any:
            # Get or create lock for this file
            if file_path not in self._file_locks:
                self._file_locks[file_path] = asyncio.Lock()
            lock = self._file_locks[file_path]

            async with lock:  # Serialize writes to same file
                async with self._semaphore:  # Respect global concurrency limit
                    return await self._execute_single(tool, args)

        tasks = [asyncio.create_task(run_with_lock(t, a, fp)) for t, a, fp in file_writers]
        try:
            for next_done in asyncio.as_completed(tasks):
                yield await next_done
        finally:
            await self._settle(tasks)

    async def _execute_concurrent(
        self, tool_calls: Iterable[Tuple[BaseTool, Dict]]
    ) -> AsyncGenerator[Any, None]:
        """Run tools in parallel (bounded by the semaphore), yielding in completion order."""

        async def run(tool: BaseTool, args: Dict) -> Any:
            async with self._semaphore:
                return await self._execute_single(tool, args)

        tasks = [asyncio.create_task(run(t, a)) for t, a in tool_calls]
        try:
            for next_done in asyncio.as_completed(tasks):
                yield await next_done
        finally:
            await self._settle(tasks)

    @staticmethod
    async def _settle(tasks: list[asyncio.Task]) -> None:
        """Cancel unfinished *tasks*, then await all of them.

        Awaiting with ``return_exceptions=True`` lets cancellations complete and
        marks every task's exception as retrieved, so none are leaked.
        """
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _execute_single(self, tool: BaseTool, args: Dict) -> Any:
        """Run one tool with the configured per-call timeout."""
        return await asyncio.wait_for(tool.run(**args), timeout=self.config.timeout_seconds)
|
|
125
|
+
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Steering Control: Simplified abort/pause signal management.
|
|
2
|
+
|
|
3
|
+
Replaces legacy EventBus with focus on cancel signals only.
|
|
4
|
+
For real-time steering, use cancel_token (US1 pattern) instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SteeringControl:
    """Minimal holder for abort and pause signals.

    Each signal is an ``asyncio.Event``: setting it raises the flag and
    clearing it lowers the flag. Prefer passing a ``cancel_token``
    (asyncio.Event) to ``Agent.run(input, cancel_token=token)`` for
    cancellation; this class is kept for legacy compatibility and may be
    removed in v5.0.0.
    """

    def __init__(self) -> None:
        self._abort_signal, self._pause_signal = asyncio.Event(), asyncio.Event()

    def abort(self) -> None:
        """Raise the abort flag."""
        self._abort_signal.set()

    def is_aborted(self) -> bool:
        """Return True once abort() has been called (until reset())."""
        flag = self._abort_signal
        return flag.is_set()

    def pause(self) -> None:
        """Raise the pause flag."""
        self._pause_signal.set()

    def resume(self) -> None:
        """Lower the pause flag."""
        self._pause_signal.clear()

    def is_paused(self) -> bool:
        """Return True while the pause flag is raised."""
        flag = self._pause_signal
        return flag.is_set()

    def reset(self) -> None:
        """Lower both the abort and the pause flags."""
        for signal in (self._abort_signal, self._pause_signal):
            signal.clear()
|