puffinflow-2.dev0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- puffinflow/__init__.py +132 -0
- puffinflow/core/__init__.py +110 -0
- puffinflow/core/agent/__init__.py +320 -0
- puffinflow/core/agent/base.py +1635 -0
- puffinflow/core/agent/checkpoint.py +50 -0
- puffinflow/core/agent/context.py +521 -0
- puffinflow/core/agent/decorators/__init__.py +90 -0
- puffinflow/core/agent/decorators/builder.py +454 -0
- puffinflow/core/agent/decorators/flexible.py +714 -0
- puffinflow/core/agent/decorators/inspection.py +144 -0
- puffinflow/core/agent/dependencies.py +57 -0
- puffinflow/core/agent/scheduling/__init__.py +21 -0
- puffinflow/core/agent/scheduling/builder.py +160 -0
- puffinflow/core/agent/scheduling/exceptions.py +35 -0
- puffinflow/core/agent/scheduling/inputs.py +137 -0
- puffinflow/core/agent/scheduling/parser.py +209 -0
- puffinflow/core/agent/scheduling/scheduler.py +413 -0
- puffinflow/core/agent/state.py +141 -0
- puffinflow/core/config.py +62 -0
- puffinflow/core/coordination/__init__.py +137 -0
- puffinflow/core/coordination/agent_group.py +359 -0
- puffinflow/core/coordination/agent_pool.py +629 -0
- puffinflow/core/coordination/agent_team.py +577 -0
- puffinflow/core/coordination/coordinator.py +720 -0
- puffinflow/core/coordination/deadlock.py +1759 -0
- puffinflow/core/coordination/fluent_api.py +421 -0
- puffinflow/core/coordination/primitives.py +478 -0
- puffinflow/core/coordination/rate_limiter.py +520 -0
- puffinflow/core/observability/__init__.py +47 -0
- puffinflow/core/observability/agent.py +139 -0
- puffinflow/core/observability/alerting.py +73 -0
- puffinflow/core/observability/config.py +127 -0
- puffinflow/core/observability/context.py +88 -0
- puffinflow/core/observability/core.py +147 -0
- puffinflow/core/observability/decorators.py +105 -0
- puffinflow/core/observability/events.py +71 -0
- puffinflow/core/observability/interfaces.py +196 -0
- puffinflow/core/observability/metrics.py +137 -0
- puffinflow/core/observability/tracing.py +209 -0
- puffinflow/core/reliability/__init__.py +27 -0
- puffinflow/core/reliability/bulkhead.py +96 -0
- puffinflow/core/reliability/circuit_breaker.py +149 -0
- puffinflow/core/reliability/leak_detector.py +122 -0
- puffinflow/core/resources/__init__.py +77 -0
- puffinflow/core/resources/allocation.py +790 -0
- puffinflow/core/resources/pool.py +645 -0
- puffinflow/core/resources/quotas.py +567 -0
- puffinflow/core/resources/requirements.py +217 -0
- puffinflow/version.py +21 -0
- puffinflow-2.dev0.dist-info/METADATA +334 -0
- puffinflow-2.dev0.dist-info/RECORD +55 -0
- puffinflow-2.dev0.dist-info/WHEEL +5 -0
- puffinflow-2.dev0.dist-info/entry_points.txt +3 -0
- puffinflow-2.dev0.dist-info/licenses/LICENSE +21 -0
- puffinflow-2.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1635 @@
"""Enhanced Agent with direct access and coordination features."""

import asyncio
import json
import logging
import pickle
import time
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union

from .checkpoint import AgentCheckpoint
from .context import Context
from .state import (
    AgentStatus,
    DeadLetter,
    PrioritizedState,
    Priority,
    RetryPolicy,
    StateMetadata,
    StateResult,
    StateStatus,
)

# Import scheduling components
try:
    from .scheduling.builder import ScheduleBuilder
    from .scheduling.scheduler import GlobalScheduler, ScheduledAgent

    _SCHEDULING_AVAILABLE = True
except ImportError:
    _SCHEDULING_AVAILABLE = False

# Import ResourceRequirements conditionally
try:
    from ..resources.requirements import ResourceRequirements

    _ResourceRequirements: Optional[type] = ResourceRequirements
except ImportError:
    _ResourceRequirements = None

# Import these conditionally to avoid circular imports
if TYPE_CHECKING:
    from ..coordination.agent_team import AgentTeam
    from ..coordination.primitives import CoordinationPrimitive
    from ..reliability.bulkhead import Bulkhead, BulkheadConfig
    from ..reliability.circuit_breaker import CircuitBreaker, CircuitBreakerConfig
    from ..resources.pool import ResourcePool

logger = logging.getLogger(__name__)


# Checkpoint persistence interfaces
class CheckpointStorage(Protocol):
    """Protocol for checkpoint storage backends."""

    async def save_checkpoint(
        self, agent_name: str, checkpoint: AgentCheckpoint
    ) -> str:
        """Save checkpoint and return checkpoint ID."""
        ...

    async def load_checkpoint(
        self, agent_name: str, checkpoint_id: Optional[str] = None
    ) -> Optional[AgentCheckpoint]:
        """Load checkpoint by ID or latest for agent."""
        ...

    async def list_checkpoints(self, agent_name: str) -> list[str]:
        """List available checkpoint IDs for agent."""
        ...

    async def delete_checkpoint(self, agent_name: str, checkpoint_id: str) -> bool:
        """Delete a specific checkpoint."""
        ...


class FileCheckpointStorage:
    """File-based checkpoint storage."""

    def __init__(self, base_path: str = "./checkpoints", format: str = "pickle"):
        """
        Initialize file storage.

        Args:
            base_path: Directory to store checkpoint files
            format: Storage format ('pickle' or 'json')
        """
        self.base_path = Path(base_path)
        self.format = format.lower()
        if self.format not in ("pickle", "json"):
            raise ValueError(f"Unsupported format: {format}. Use 'pickle' or 'json'")

        # Create directory if it doesn't exist
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _get_checkpoint_path(self, agent_name: str, checkpoint_id: str) -> Path:
        """Get file path for checkpoint."""
        agent_dir = self.base_path / agent_name
        agent_dir.mkdir(exist_ok=True)
        ext = "pkl" if self.format == "pickle" else "json"
        return agent_dir / f"{checkpoint_id}.{ext}"

    async def save_checkpoint(
        self, agent_name: str, checkpoint: AgentCheckpoint
    ) -> str:
        """Save checkpoint to file."""
        checkpoint_id = f"checkpoint_{int(checkpoint.timestamp)}"
        file_path = self._get_checkpoint_path(agent_name, checkpoint_id)

        try:
            if self.format == "pickle":
                with file_path.open("wb") as f:
                    pickle.dump(checkpoint, f)
            else:  # json
                # Convert checkpoint to JSON-serializable format
                checkpoint_data = {
                    "timestamp": checkpoint.timestamp,
                    "agent_name": checkpoint.agent_name,
                    "agent_status": checkpoint.agent_status.value,
                    "priority_queue": [
                        {
                            "priority": ps.priority,
                            "timestamp": ps.timestamp,
                            "state_name": ps.state_name,
                            "metadata": {
                                "status": ps.metadata.status.value,
                                "attempts": ps.metadata.attempts,
                                "max_retries": ps.metadata.max_retries,
                                "last_execution": ps.metadata.last_execution,
                                "last_success": ps.metadata.last_success,
                                "state_id": ps.metadata.state_id,
                                "priority": ps.metadata.priority.value,
                            },
                        }
                        for ps in checkpoint.priority_queue
                    ],
                    "state_metadata": {
                        k: {
                            "status": v.status.value,
                            "attempts": v.attempts,
                            "max_retries": v.max_retries,
                            "last_execution": v.last_execution,
                            "last_success": v.last_success,
                            "state_id": v.state_id,
                            "priority": v.priority.value,
                        }
                        for k, v in checkpoint.state_metadata.items()
                    },
                    "running_states": list(checkpoint.running_states),
                    "completed_states": list(checkpoint.completed_states),
                    "completed_once": list(checkpoint.completed_once),
                    "shared_state": checkpoint.shared_state,
                    "session_start": checkpoint.session_start,
                }

                with file_path.open("w") as f:
                    json.dump(checkpoint_data, f, indent=2, default=str)

            logger.info(f"Checkpoint saved to {file_path}")
            return checkpoint_id

        except Exception as e:
            logger.error(f"Failed to save checkpoint to {file_path}: {e}")
            raise

    async def load_checkpoint(
        self, agent_name: str, checkpoint_id: Optional[str] = None
    ) -> Optional[AgentCheckpoint]:
        """Load checkpoint from file."""
        if checkpoint_id is None:
            # Load latest checkpoint
            checkpoints = await self.list_checkpoints(agent_name)
            if not checkpoints:
                return None
            checkpoint_id = checkpoints[-1]  # Latest by timestamp

        file_path = self._get_checkpoint_path(agent_name, checkpoint_id)

        if not file_path.exists():
            logger.warning(f"Checkpoint file not found: {file_path}")
            return None

        try:
            if self.format == "pickle":
                with file_path.open("rb") as f:
                    checkpoint: AgentCheckpoint = pickle.load(f)
                    return checkpoint
            else:  # json
                with file_path.open("r") as f:
                    data = json.load(f)

                # Reconstruct checkpoint from JSON data
                from .state import (
                    AgentStatus,
                    PrioritizedState,
                    Priority,
                    StateMetadata,
                    StateStatus,
                )

                checkpoint = AgentCheckpoint(
                    timestamp=data["timestamp"],
                    agent_name=data["agent_name"],
                    agent_status=AgentStatus(data["agent_status"]),
                    priority_queue=[
                        PrioritizedState(
                            priority=ps["priority"],
                            timestamp=ps["timestamp"],
                            state_name=ps["state_name"],
                            metadata=StateMetadata(
                                status=StateStatus(ps["metadata"]["status"]),
                                attempts=ps["metadata"]["attempts"],
                                max_retries=ps["metadata"]["max_retries"],
                                last_execution=ps["metadata"]["last_execution"],
                                last_success=ps["metadata"]["last_success"],
                                state_id=ps["metadata"]["state_id"],
                                priority=Priority(ps["metadata"]["priority"]),
                            ),
                        )
                        for ps in data["priority_queue"]
                    ],
                    state_metadata={
                        k: StateMetadata(
                            status=StateStatus(v["status"]),
                            attempts=v["attempts"],
                            max_retries=v["max_retries"],
                            last_execution=v["last_execution"],
                            last_success=v["last_success"],
                            state_id=v["state_id"],
                            priority=Priority(v["priority"]),
                        )
                        for k, v in data["state_metadata"].items()
                    },
                    running_states=set(data["running_states"]),
                    completed_states=set(data["completed_states"]),
                    completed_once=set(data["completed_once"]),
                    shared_state=data["shared_state"],
                    session_start=data["session_start"],
                )

                return checkpoint

        except Exception as e:
            logger.error(f"Failed to load checkpoint from {file_path}: {e}")
            return None

    async def list_checkpoints(self, agent_name: str) -> list[str]:
        """List available checkpoint files."""
        agent_dir = self.base_path / agent_name
        if not agent_dir.exists():
            return []

        ext = "pkl" if self.format == "pickle" else "json"
        checkpoint_files = [f.stem for f in agent_dir.glob(f"*.{ext}") if f.is_file()]

        # Sort by timestamp (extract from filename)
        checkpoint_files.sort(
            key=lambda x: int(x.split("_")[-1]) if x.split("_")[-1].isdigit() else 0
        )
        return checkpoint_files

    async def delete_checkpoint(self, agent_name: str, checkpoint_id: str) -> bool:
        """Delete checkpoint file."""
        file_path = self._get_checkpoint_path(agent_name, checkpoint_id)

        try:
            if file_path.exists():
                file_path.unlink()
                logger.info(f"Deleted checkpoint: {file_path}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete checkpoint {file_path}: {e}")
            return False


class MemoryCheckpointStorage:
    """In-memory checkpoint storage for testing."""

    def __init__(self) -> None:
        self._checkpoints: dict[str, dict[str, AgentCheckpoint]] = {}

    async def save_checkpoint(
        self, agent_name: str, checkpoint: AgentCheckpoint
    ) -> str:
        """Save checkpoint to memory."""
        checkpoint_id = f"checkpoint_{int(checkpoint.timestamp)}"

        if agent_name not in self._checkpoints:
            self._checkpoints[agent_name] = {}

        # Deep copy to prevent modifications
        import copy

        self._checkpoints[agent_name][checkpoint_id] = copy.deepcopy(checkpoint)

        logger.info(f"Checkpoint saved to memory: {agent_name}/{checkpoint_id}")
        return checkpoint_id

    async def load_checkpoint(
        self, agent_name: str, checkpoint_id: Optional[str] = None
    ) -> Optional[AgentCheckpoint]:
        """Load checkpoint from memory."""
        if agent_name not in self._checkpoints:
            return None

        agent_checkpoints = self._checkpoints[agent_name]

        if checkpoint_id is None:
            # Get latest checkpoint
            if not agent_checkpoints:
                return None
            latest_id = max(
                agent_checkpoints.keys(), key=lambda x: int(x.split("_")[-1])
            )
            checkpoint_id = latest_id

        checkpoint = agent_checkpoints.get(checkpoint_id)
        if checkpoint:
            # Return deep copy to prevent modifications
            import copy

            return copy.deepcopy(checkpoint)
        return None

    async def list_checkpoints(self, agent_name: str) -> list[str]:
        """List checkpoints in memory."""
        if agent_name not in self._checkpoints:
            return []

        checkpoints = list(self._checkpoints[agent_name].keys())
        checkpoints.sort(
            key=lambda x: int(x.split("_")[-1]) if x.split("_")[-1].isdigit() else 0
        )
        return checkpoints

    async def delete_checkpoint(self, agent_name: str, checkpoint_id: str) -> bool:
        """Delete checkpoint from memory."""
        if (
            agent_name in self._checkpoints
            and checkpoint_id in self._checkpoints[agent_name]
        ):
            del self._checkpoints[agent_name][checkpoint_id]
            logger.info(f"Deleted checkpoint from memory: {agent_name}/{checkpoint_id}")
            return True
        return False

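Illustrative usage sketch (not part of the packaged file): wiring a FileCheckpointStorage into the Agent class defined later in this module. The import path follows the file layout in the listing above; the agent name "etl" is hypothetical, and it is assumed that a checkpoint can be saved even before any states have run.

    import asyncio
    from puffinflow.core.agent.base import Agent, FileCheckpointStorage

    async def main() -> None:
        # JSON checkpoints land in ./checkpoints/etl/checkpoint_<timestamp>.json
        storage = FileCheckpointStorage(base_path="./checkpoints", format="json")
        agent = Agent("etl", checkpoint_storage=storage)

        checkpoint_id = await agent.save_checkpoint()
        print(checkpoint_id, await storage.list_checkpoints("etl"))

    asyncio.run(main())
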
@dataclass
class AgentResult:
    """Rich result container for agent execution."""

    agent_name: str
    status: AgentStatus
    outputs: dict[str, Any] = field(default_factory=dict)
    variables: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
    metrics: dict[str, Any] = field(default_factory=dict)
    error: Optional[Exception] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    execution_duration: Optional[float] = None

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get output value."""
        return self.outputs.get(key, default)

    def get_variable(self, key: str, default: Any = None) -> Any:
        """Get variable value."""
        return self.variables.get(key, default)

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Get metadata value."""
        return self.metadata.get(key, default)

    def get_metric(self, key: str, default: Any = None) -> Any:
        """Get metric value."""
        return self.metrics.get(key, default)

    @property
    def is_success(self) -> bool:
        """Check if execution was successful."""
        return self.status == AgentStatus.COMPLETED and self.error is None

    @property
    def is_failed(self) -> bool:
        """Check if execution failed."""
        return self.status == AgentStatus.FAILED or self.error is not None


class ResourceTimeoutError(Exception):
    """Raised when resource acquisition times out."""

    pass

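Illustrative sketch (not part of the packaged file): constructing and inspecting an AgentResult directly, using only the fields and helpers shown above. The output key "rows_written" is hypothetical.

    result = AgentResult(
        agent_name="etl",
        status=AgentStatus.COMPLETED,
        outputs={"rows_written": 1200},
    )
    assert result.is_success and not result.is_failed
    print(result.get_output("rows_written", default=0))  # -> 1200
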
class Agent:
    """Enhanced Agent with direct variable access and coordination features."""

    def __init__(
        self,
        name: str,
        resource_pool: Optional["ResourcePool"] = None,
        retry_policy: Optional[RetryPolicy] = None,
        circuit_breaker_config: Optional["CircuitBreakerConfig"] = None,
        bulkhead_config: Optional["BulkheadConfig"] = None,
        max_concurrent: int = 5,
        enable_dead_letter: bool = True,
        state_timeout: Optional[float] = None,
        checkpoint_storage: Optional[CheckpointStorage] = None,
        **kwargs: Any,
    ) -> None:
        self.name = name
        self.states: dict[str, Callable] = {}
        self.state_metadata: dict[str, StateMetadata] = {}
        self.dependencies: dict[str, list[str]] = {}
        self.status = AgentStatus.IDLE
        self.shared_state: dict[str, Any] = {}
        self.priority_queue: list[PrioritizedState] = []
        self.running_states: set[str] = set()
        self.completed_states: set[str] = set()
        self.completed_once: set[str] = set()
        self.dead_letters: list[DeadLetter] = []
        self.session_start: Optional[float] = None

        # Configuration
        self.max_concurrent = max_concurrent
        self.state_timeout = state_timeout
        self.checkpoint_storage = checkpoint_storage or MemoryCheckpointStorage()

        # Enhanced features
        self._context: Optional[Context] = None
        self._variable_watchers: dict[str, list[Callable]] = {}
        self._shared_variable_watchers: dict[str, list[Callable]] = {}
        self._agent_variables: dict[str, Any] = {}
        self._persistent_variables: dict[str, Any] = {}
        self._property_definitions: dict[str, dict] = {}
        self._team: Optional[weakref.ReferenceType] = None
        self._message_handlers: dict[str, Callable] = {}
        self._event_handlers: dict[str, list[Callable]] = {}
        self._state_change_handlers: list[Callable] = []

        # Resource and reliability components - lazy initialization
        self._resource_pool = resource_pool
        self._circuit_breaker: Optional[CircuitBreaker] = None
        self._bulkhead: Optional[Bulkhead] = None
        self._circuit_breaker_config = circuit_breaker_config
        self._bulkhead_config = bulkhead_config

        self.retry_policy = retry_policy or RetryPolicy()
        self.enable_dead_letter = enable_dead_letter
        self._cleanup_handlers: list[Callable] = []

        # Create context
        self.context = self._create_context(self.shared_state)

    @property
    def resource_pool(self) -> "ResourcePool":
        """Get or create resource pool."""
        if self._resource_pool is None:
            from ..resources.pool import ResourcePool

            self._resource_pool = ResourcePool()
        return self._resource_pool

    @resource_pool.setter
    def resource_pool(self, value: "ResourcePool") -> None:
        """Set resource pool."""
        self._resource_pool = value

    @property
    def circuit_breaker(self) -> "CircuitBreaker":
        """Get or create circuit breaker."""
        if self._circuit_breaker is None:
            from ..reliability.circuit_breaker import (
                CircuitBreaker,
                CircuitBreakerConfig,
            )

            if self._circuit_breaker_config:
                config = self._circuit_breaker_config
            else:
                config = CircuitBreakerConfig(name=f"{self.name}_circuit_breaker")

            self._circuit_breaker = CircuitBreaker(config)
        return self._circuit_breaker

    @property
    def bulkhead(self) -> "Bulkhead":
        """Get or create bulkhead."""
        if self._bulkhead is None:
            from ..reliability.bulkhead import Bulkhead, BulkheadConfig

            if self._bulkhead_config:
                config = self._bulkhead_config
            else:
                config = BulkheadConfig(
                    name=f"{self.name}_bulkhead", max_concurrent=self.max_concurrent
                )

            self._bulkhead = Bulkhead(config)
        return self._bulkhead

    # Direct variable access methods
    def get_variable(self, key: str, default: Any = None) -> Any:
        """Get variable directly from agent context or internal storage."""
        if self._context:
            return self._context.get_variable(key, default)
        return self._agent_variables.get(key, default)

    def set_variable(self, key: str, value: Any) -> None:
        """Set variable directly on agent context or internal storage."""
        if self._context:
            old_value = self._context.get_variable(key)
            self._context.set_variable(key, value)
            self._trigger_variable_watchers(key, old_value, value)
        else:
            old_value = self._agent_variables.get(key)
            self._agent_variables[key] = value
            self._trigger_variable_watchers(key, old_value, value)

    def increment_variable(self, key: str, amount: Union[int, float] = 1) -> None:
        """Increment a numeric variable."""
        current = self.get_variable(key, 0)
        self.set_variable(key, current + amount)

    def append_variable(self, key: str, value: Any) -> None:
        """Append to a list variable."""
        current = self.get_variable(key, [])
        if not isinstance(current, list):
            current = [current]
        current.append(value)
        self.set_variable(key, current)

    def get_shared_variable(self, key: str, default: Any = None) -> Any:
        """Get shared variable accessible to all agents."""
        return self.shared_state.get(key, default)

    def set_shared_variable(self, key: str, value: Any) -> None:
        """Set shared variable accessible to all agents."""
        old_value = self.shared_state.get(key)
        self.shared_state[key] = value
        self._trigger_shared_variable_watchers(key, old_value, value)

    def get_agent_variable(self, key: str, default: Any = None) -> Any:
        """Get agent-specific variable (not shared)."""
        return self._agent_variables.get(key, default)

    def set_agent_variable(self, key: str, value: Any) -> None:
        """Set agent-specific variable (not shared)."""
        old_value = self._agent_variables.get(key)
        self._agent_variables[key] = value
        self._trigger_variable_watchers(key, old_value, value)

    def get_persistent_variable(self, key: str, default: Any = None) -> Any:
        """Get persistent variable that survives restarts."""
        return self._persistent_variables.get(key, default)

    def set_persistent_variable(self, key: str, value: Any) -> None:
        """Set persistent variable that survives restarts."""
        self._persistent_variables[key] = value

    # Context content access methods
    def get_output(self, key: str, default: Any = None) -> Any:
        """Get output value from context."""
        if self._context:
            return self._context.get_output(key, default)
        return default

    def set_output(self, key: str, value: Any) -> None:
        """Set output value in context."""
        if self._context:
            self._context.set_output(key, value)

    def get_all_outputs(self) -> dict[str, Any]:
        """Get all output values."""
        if self._context:
            output_keys = self._context.get_output_keys()
            return {key: self._context.get_output(key) for key in output_keys}
        return {}

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Get metadata value."""
        if self._context and hasattr(self._context, "get_metadata"):
            return self._context.get_metadata(key, default)
        return default

    def set_metadata(self, key: str, value: Any) -> None:
        """Set metadata value."""
        if self._context and hasattr(self._context, "set_metadata"):
            self._context.set_metadata(key, value)

    def get_cached(self, key: str, default: Any = None) -> Any:
        """Get cached value."""
        if self._context:
            return self._context.get_cached(key, default)
        return default

    def set_cached(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set cached value with optional TTL."""
        if self._context:
            self._context.set_cached(key, value, ttl)

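Illustrative sketch (not part of the packaged file): the direct accessors above read and write through the agent's Context when one exists, falling back to internal storage otherwise. Variable names here are hypothetical.

    agent = Agent("worker")
    agent.set_variable("batch_size", 100)
    agent.increment_variable("processed", 25)    # starts from 0 when unset
    agent.append_variable("errors", "timeout")   # wraps a non-list value in a list
    agent.set_shared_variable("run_id", "2024-06-01")
    print(agent.get_variable("processed"))       # -> 25
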
    # Property system
    def define_property(
        self,
        name: str,
        prop_type: type,
        default: Any = None,
        validator: Optional[Callable] = None,
    ) -> None:
        """Define a typed property with validation."""
        self._property_definitions[name] = {
            "type": prop_type,
            "default": default,
            "validator": validator,
        }

        # Set default value if not already set
        if name not in self._agent_variables:
            self.set_variable(name, default)

        # Create property accessor
        def getter(obj: Any) -> Any:
            return obj.get_variable(name, default)

        def setter(obj: Any, value: Any) -> None:
            if validator:
                value = validator(value)
            if not isinstance(value, prop_type) and value is not None:
                try:
                    value = prop_type(value)
                except (ValueError, TypeError) as e:
                    raise TypeError(f"Cannot convert {value} to {prop_type}") from e
            obj.set_variable(name, value)

        setattr(self.__class__, name, property(getter, setter))

    # Variable watching
    def watch_variable(self, key: str, handler: Callable) -> None:
        """Watch for changes to a variable."""
        if key not in self._variable_watchers:
            self._variable_watchers[key] = []
        self._variable_watchers[key].append(handler)

    def watch_shared_variable(self, key: str, handler: Callable) -> None:
        """Watch for changes to a shared variable."""
        if key not in self._shared_variable_watchers:
            self._shared_variable_watchers[key] = []
        self._shared_variable_watchers[key].append(handler)

    def _trigger_variable_watchers(
        self, key: str, old_value: Any, new_value: Any
    ) -> None:
        """Trigger watchers for variable changes."""
        if key in self._variable_watchers and old_value != new_value:
            for handler in self._variable_watchers[key]:
                try:
                    if asyncio.iscoroutinefunction(handler):
                        task = asyncio.create_task(handler(old_value, new_value))
                        # Store task reference to prevent it from being garbage collected
                        if not hasattr(self, "_background_tasks"):
                            self._background_tasks = set()
                        self._background_tasks.add(task)
                        task.add_done_callback(
                            lambda t: self._background_tasks.discard(t)
                        )
                    else:
                        handler(old_value, new_value)
                except Exception as e:
                    logger.error(f"Error in variable watcher for {key}: {e}")

    def _trigger_shared_variable_watchers(
        self, key: str, old_value: Any, new_value: Any
    ) -> None:
        """Trigger watchers for shared variable changes."""
        if key in self._shared_variable_watchers and old_value != new_value:
            for handler in self._shared_variable_watchers[key]:
                try:
                    if asyncio.iscoroutinefunction(handler):
                        task = asyncio.create_task(handler(old_value, new_value))
                        # Store task reference to prevent it from being garbage collected
                        if not hasattr(self, "_background_tasks"):
                            self._background_tasks = set()
                        self._background_tasks.add(task)
                        task.add_done_callback(
                            lambda t: self._background_tasks.discard(t)
                        )
                    else:
                        handler(old_value, new_value)
                except Exception as e:
                    logger.error(f"Error in shared variable watcher for {key}: {e}")

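Illustrative sketch (not part of the packaged file): a watcher registered with watch_variable receives (old_value, new_value) whenever the value actually changes; coroutine handlers are scheduled as background tasks, so a running event loop is needed for those. The synchronous handler below avoids that requirement.

    agent = Agent("worker")

    def on_threshold_change(old_value, new_value):
        print(f"threshold changed: {old_value!r} -> {new_value!r}")

    agent.watch_variable("threshold", on_threshold_change)
    agent.set_variable("threshold", 0.75)  # triggers the synchronous watcher
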
    # State change events
    def on_state_change(self, handler: Callable) -> None:
        """Register handler for state changes."""
        self._state_change_handlers.append(handler)

    def _trigger_state_change(self, old_state: Any, new_state: Any) -> None:
        """Trigger state change handlers."""
        for handler in self._state_change_handlers:
            try:
                if asyncio.iscoroutinefunction(handler):
                    task = asyncio.create_task(handler(old_state, new_state))
                    # Store task reference to prevent it from being garbage collected
                    if not hasattr(self, "_background_tasks"):
                        self._background_tasks = set()
                    self._background_tasks.add(task)
                    task.add_done_callback(lambda t: self._background_tasks.discard(t))
                else:
                    handler(old_state, new_state)
            except Exception as e:
                logger.error(f"Error in state change handler: {e}")

    # Team coordination
    def set_team(self, team: "AgentTeam") -> None:
        """Set the team this agent belongs to."""
        self._team = weakref.ref(team)

    def get_team(self) -> Optional["AgentTeam"]:
        """Get the team this agent belongs to."""
        if self._team:
            return self._team()
        return None

    # Messaging
    def message_handler(self, message_type: str) -> Callable:
        """Decorator for message handlers."""

        def decorator(func: Callable) -> Callable:
            self._message_handlers[message_type] = func
            return func

        return decorator

    async def send_message_to(
        self, agent_name: str, message: dict[str, Any]
    ) -> dict[str, Any]:
        """Send message to another agent."""
        team = self.get_team()
        if team:
            return await team.send_message(self.name, agent_name, message)
        raise RuntimeError("Agent must be part of a team to send messages")

    async def reply_to(self, sender_agent: str, message: dict[str, Any]) -> None:
        """Reply to a message from another agent."""
        await self.send_message_to(sender_agent, message)

    async def broadcast_message(self, message_type: str, data: dict[str, Any]) -> None:
        """Broadcast message to all agents in team."""
        team = self.get_team()
        if team:
            await team.broadcast_message(self.name, message_type, data)

    async def handle_message(
        self, message_type: str, message: dict[str, Any], sender: str
    ) -> dict[str, Any]:
        """Handle incoming message."""
        if message_type in self._message_handlers:
            result = await self._message_handlers[message_type](message, sender)
            return dict(result) if result else {}
        return {}

    # Event system
    def on_event(self, event_type: str) -> Callable:
        """Decorator for event handlers."""

        def decorator(func: Callable) -> Callable:
            if event_type not in self._event_handlers:
                self._event_handlers[event_type] = []
            self._event_handlers[event_type].append(func)
            return func

        return decorator

    async def emit_event(self, event_type: str, data: dict[str, Any]) -> None:
        """Emit an event."""
        team = self.get_team()
        if team:
            await team.emit_event(self.name, event_type, data)

        # Handle local events
        if event_type in self._event_handlers:
            for handler in self._event_handlers[event_type]:
                try:
                    if asyncio.iscoroutinefunction(handler):
                        await handler(self._context, data)
                    else:
                        handler(self._context, data)
                except Exception as e:
                    logger.error(f"Error in event handler for {event_type}: {e}")

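Illustrative sketch (not part of the packaged file): message handlers are awaited with (message, sender), so they should be coroutines, while event handlers receive (context, data) and may be sync or async. Sending to another agent requires membership in a team via set_team(); the message and event names below are hypothetical.

    agent = Agent("worker")

    @agent.message_handler("ping")
    async def handle_ping(message, sender):
        return {"pong": True, "from": sender}

    @agent.on_event("job_done")
    def handle_job_done(context, data):
        print("job finished:", data)

    # Without a team, handle_message can still be driven directly:
    # await agent.handle_message("ping", {"payload": 1}, sender="tester")
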
    # Variable synchronization
    async def sync_variables_with(
        self, agent_name: str, variable_names: list[str]
    ) -> None:
        """Sync specific variables with another agent."""
        team = self.get_team()
        if team:
            other_agent = team.get_agent(agent_name)
            if other_agent:
                for var_name in variable_names:
                    value = self.get_variable(var_name)
                    other_agent.set_variable(var_name, value)

    async def wait_for_agent_variable(
        self,
        agent_name: str,
        variable_name: str,
        expected_value: Any,
        timeout: Optional[float] = None,
    ) -> bool:
        """Wait for another agent's variable to reach a specific value."""
        team = self.get_team()
        if not team:
            return False

        other_agent = team.get_agent(agent_name)
        if not other_agent:
            return False

        start_time = time.time()
        while True:
            current_value = other_agent.get_variable(variable_name)
            if current_value == expected_value:
                return True

            if timeout and (time.time() - start_time) > timeout:
                return False

            await asyncio.sleep(0.1)

    def get_synced_variable(
        self, agent_name: str, variable_name: str, default: Any = None
    ) -> Any:
        """Get a variable from another agent."""
        team = self.get_team()
        if team:
            other_agent = team.get_agent(agent_name)
            if other_agent:
                return other_agent.get_variable(variable_name, default)
        return default

    # Context creation
    def _create_context(self, shared_state: dict[str, Any]) -> Context:
        """Create enhanced context with agent variables."""
        context = Context(shared_state)

        # Copy agent variables to context
        for key, value in self._agent_variables.items():
            context.set_variable(key, value)

        self._context = context
        return context

    # State management (keeping existing methods)
    def add_state(
        self,
        name: str,
        func: Callable,
        dependencies: Optional[list[str]] = None,
        resources: Optional[
            Any
        ] = None,  # Using Any since ResourceRequirements may not be available
        priority: Optional[Priority] = None,
        retry_policy: Optional[RetryPolicy] = None,
        coordination_primitives: Optional[list["CoordinationPrimitive"]] = None,
        max_retries: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Add a state to the agent."""
        self.states[name] = func
        self.dependencies[name] = dependencies or []

        # Extract decorator requirements if available
        decorator_requirements = self._extract_decorator_requirements(func)
        final_requirements = resources or decorator_requirements

        # Get priority from function if not provided
        final_priority = priority
        if hasattr(func, "_priority"):
            final_priority = func._priority
        elif final_priority is None:
            final_priority = Priority.NORMAL

        # Ensure final_priority is not None
        if final_priority is None:
            final_priority = Priority.NORMAL

        if final_requirements is None and _ResourceRequirements is not None:
            final_requirements = _ResourceRequirements()

        # Use max_retries parameter or fall back to retry_policy or agent default
        final_max_retries = max_retries or (
            retry_policy.max_retries if retry_policy else self.retry_policy.max_retries
        )

        # Create state metadata
        metadata = StateMetadata(
            status=StateStatus.PENDING,
            priority=final_priority,
            resources=final_requirements,
            retry_policy=retry_policy or self.retry_policy,
            coordination_primitives=coordination_primitives or [],
            max_retries=final_max_retries,
        )

        self.state_metadata[name] = metadata

    def _extract_decorator_requirements(
        self, func: Callable
    ) -> Optional[Any]:  # Using Any since ResourceRequirements may not be available
        """Extract resource requirements from decorator metadata."""
        if hasattr(func, "_resource_requirements"):
            requirements = func._resource_requirements
            if _ResourceRequirements is not None and isinstance(
                requirements, _ResourceRequirements
            ):
                return requirements
        return None

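Illustrative sketch (not part of the packaged file): states are async callables that take the context; add_state records dependencies, and a state's return value (a state name, or a list of names) is queued as the next transition by _handle_state_result further down. State and variable names are hypothetical.

    agent = Agent("pipeline")

    async def extract(context):
        context.set_variable("rows", [1, 2, 3])
        return "transform"            # queue the next state by name

    async def transform(context):
        rows = context.get_variable("rows", [])
        context.set_output("count", len(rows))
        return None                   # no further transitions

    agent.add_state("extract", extract)
    agent.add_state("transform", transform, dependencies=["extract"])
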
    # Checkpointing
    def create_checkpoint(self) -> AgentCheckpoint:
        """Create a checkpoint of current agent state."""
        return AgentCheckpoint.create_from_agent(self)

    async def restore_from_checkpoint(self, checkpoint: AgentCheckpoint) -> None:
        """Restore agent from checkpoint."""
        self.status = checkpoint.agent_status
        self.priority_queue = checkpoint.priority_queue.copy()
        self.state_metadata = checkpoint.state_metadata.copy()
        self.running_states = checkpoint.running_states.copy()
        self.completed_states = checkpoint.completed_states.copy()
        self.completed_once = checkpoint.completed_once.copy()
        self.shared_state = checkpoint.shared_state.copy()
        self.session_start = checkpoint.session_start

    async def save_checkpoint(self) -> str:
        """Save current state as checkpoint with persistent storage."""
        checkpoint = self.create_checkpoint()

        try:
            checkpoint_id = await self.checkpoint_storage.save_checkpoint(
                agent_name=self.name, checkpoint=checkpoint
            )
            logger.info(
                f"Checkpoint saved for agent {self.name} with ID: {checkpoint_id}"
            )
            return checkpoint_id

        except Exception as e:
            logger.error(f"Failed to save checkpoint for agent {self.name}: {e}")
            raise

    async def load_checkpoint(self, checkpoint_id: Optional[str] = None) -> bool:
        """
        Load agent state from a checkpoint.

        Args:
            checkpoint_id: Specific checkpoint ID to load, or None for latest

        Returns:
            True if checkpoint was loaded successfully, False otherwise
        """
        try:
            checkpoint = await self.checkpoint_storage.load_checkpoint(
                agent_name=self.name, checkpoint_id=checkpoint_id
            )

            if checkpoint is None:
                logger.warning(f"No checkpoint found for agent {self.name}")
                return False

            await self.restore_from_checkpoint(checkpoint)
            logger.info(
                f"Agent {self.name} restored from checkpoint {checkpoint_id or 'latest'}"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to load checkpoint for agent {self.name}: {e}")
            return False

    async def list_checkpoints(self) -> list[str]:
        """List available checkpoint IDs for this agent."""
        try:
            return await self.checkpoint_storage.list_checkpoints(self.name)
        except Exception as e:
            logger.error(f"Failed to list checkpoints for agent {self.name}: {e}")
            return []

    async def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """Delete a specific checkpoint."""
        try:
            success = await self.checkpoint_storage.delete_checkpoint(
                self.name, checkpoint_id
            )
            if success:
                logger.info(f"Deleted checkpoint {checkpoint_id} for agent {self.name}")
            return success
        except Exception as e:
            logger.error(
                f"Failed to delete checkpoint {checkpoint_id} for agent {self.name}: {e}"
            )
            return False

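Illustrative sketch (not part of the packaged file): a save/restore round trip through the agent's configured checkpoint storage (MemoryCheckpointStorage by default), using only the methods defined above.

    async def checkpoint_roundtrip(agent: Agent) -> None:
        checkpoint_id = await agent.save_checkpoint()
        # ... later, possibly after a restart ...
        restored = await agent.load_checkpoint(checkpoint_id)  # or None for the latest
        if not restored:
            print(f"nothing to restore for {agent.name}")
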
# Execution control
|
|
1010
|
+
async def pause(self) -> AgentCheckpoint:
|
|
1011
|
+
"""Pause agent execution and return checkpoint."""
|
|
1012
|
+
self.status = AgentStatus.PAUSED
|
|
1013
|
+
return self.create_checkpoint()
|
|
1014
|
+
|
|
1015
|
+
async def resume(self) -> None:
|
|
1016
|
+
"""Resume agent execution."""
|
|
1017
|
+
if self.status == AgentStatus.PAUSED:
|
|
1018
|
+
self.status = AgentStatus.RUNNING
|
|
1019
|
+
|
|
1020
|
+
# Find entry states
|
|
1021
|
+
def _find_entry_states(self) -> list[str]:
|
|
1022
|
+
"""Find states with no dependencies."""
|
|
1023
|
+
entry_states = []
|
|
1024
|
+
for state_name in self.states:
|
|
1025
|
+
deps = self.dependencies.get(state_name, [])
|
|
1026
|
+
if not deps:
|
|
1027
|
+
entry_states.append(state_name)
|
|
1028
|
+
return entry_states
|
|
1029
|
+
|
|
1030
|
+
async def _add_to_queue(self, state_name: str, priority_boost: int = 0) -> None:
|
|
1031
|
+
"""Add state to priority queue."""
|
|
1032
|
+
if state_name not in self.state_metadata:
|
|
1033
|
+
logger.error(f"State {state_name} not found in metadata")
|
|
1034
|
+
return
|
|
1035
|
+
|
|
1036
|
+
metadata = self.state_metadata[state_name]
|
|
1037
|
+
priority_value = metadata.priority.value + priority_boost
|
|
1038
|
+
|
|
1039
|
+
prioritized_state = PrioritizedState(
|
|
1040
|
+
priority=-priority_value, # Negative for max-heap behavior
|
|
1041
|
+
timestamp=time.time(),
|
|
1042
|
+
state_name=state_name,
|
|
1043
|
+
metadata=metadata,
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1046
|
+
import heapq
|
|
1047
|
+
|
|
1048
|
+
heapq.heappush(self.priority_queue, prioritized_state)
|
|
1049
|
+
|
|
1050
|
+
async def _get_ready_states(self) -> list[str]:
|
|
1051
|
+
"""Get states that are ready to run."""
|
|
1052
|
+
ready_states = []
|
|
1053
|
+
temp_queue = []
|
|
1054
|
+
|
|
1055
|
+
import heapq
|
|
1056
|
+
|
|
1057
|
+
while self.priority_queue:
|
|
1058
|
+
state = heapq.heappop(self.priority_queue)
|
|
1059
|
+
if await self._can_run(state.state_name):
|
|
1060
|
+
ready_states.append(state.state_name)
|
|
1061
|
+
# If it can't run, only put it back if it hasn't completed.
|
|
1062
|
+
elif state.state_name not in self.completed_once:
|
|
1063
|
+
temp_queue.append(state)
|
|
1064
|
+
|
|
1065
|
+
# Put non-ready states back
|
|
1066
|
+
for state in temp_queue:
|
|
1067
|
+
heapq.heappush(self.priority_queue, state)
|
|
1068
|
+
|
|
1069
|
+
return ready_states
|
|
1070
|
+
|
|
1071
|
+
async def run_state(self, state_name: str) -> None:
|
|
1072
|
+
"""Execute a single state."""
|
|
1073
|
+
if state_name in self.running_states:
|
|
1074
|
+
return
|
|
1075
|
+
|
|
1076
|
+
self.running_states.add(state_name)
|
|
1077
|
+
start_time = time.time()
|
|
1078
|
+
|
|
1079
|
+
try:
|
|
1080
|
+
await self._execute_state_with_circuit_breaker(state_name, start_time)
|
|
1081
|
+
finally:
|
|
1082
|
+
self.running_states.discard(state_name)
|
|
1083
|
+
|
|
1084
|
+
async def _execute_state_with_circuit_breaker(
|
|
1085
|
+
self, state_name: str, start_time: float
|
|
1086
|
+
) -> None:
|
|
1087
|
+
"""Execute state with circuit breaker protection."""
|
|
1088
|
+
try:
|
|
1089
|
+
async with self.circuit_breaker.protect():
|
|
1090
|
+
await self._execute_state_core(state_name, start_time)
|
|
1091
|
+
except Exception as e:
|
|
1092
|
+
await self._handle_state_failure(state_name, e, start_time)
|
|
1093
|
+
|
|
1094
|
+
async def _execute_state_core(self, state_name: str, start_time: float) -> None:
|
|
1095
|
+
"""Core state execution logic."""
|
|
1096
|
+
metadata = self.state_metadata[state_name]
|
|
1097
|
+
|
|
1098
|
+
# Get timeout from resources or default
|
|
1099
|
+
state_timeout = None
|
|
1100
|
+
if metadata.resources and hasattr(metadata.resources, "timeout"):
|
|
1101
|
+
state_timeout = metadata.resources.timeout
|
|
1102
|
+
|
|
1103
|
+
# Acquire resources (pass agent name for leak detection)
|
|
1104
|
+
resources = metadata.resources
|
|
1105
|
+
if resources is None and _ResourceRequirements is not None:
|
|
1106
|
+
resources = _ResourceRequirements()
|
|
1107
|
+
|
|
1108
|
+
# Only try to acquire resources if we have a valid ResourceRequirements object
|
|
1109
|
+
if resources is not None:
|
|
1110
|
+
resource_acquired = await self.resource_pool.acquire(
|
|
1111
|
+
state_name, resources, timeout=state_timeout, agent_name=self.name
|
|
1112
|
+
)
|
|
1113
|
+
|
|
1114
|
+
if not resource_acquired:
|
|
1115
|
+
raise ResourceTimeoutError(
|
|
1116
|
+
f"Failed to acquire resources for state {state_name}"
|
|
1117
|
+
)
|
|
1118
|
+
|
|
1119
|
+
try:
|
|
1120
|
+
# Execute the state function
|
|
1121
|
+
context = self._create_context(self.shared_state)
|
|
1122
|
+
|
|
1123
|
+
# Execute with timeout if specified
|
|
1124
|
+
if state_timeout:
|
|
1125
|
+
result = await asyncio.wait_for(
|
|
1126
|
+
self.states[state_name](context), timeout=state_timeout
|
|
1127
|
+
)
|
|
1128
|
+
else:
|
|
1129
|
+
result = await self.states[state_name](context)
|
|
1130
|
+
|
|
1131
|
+
time.time() - start_time
|
|
1132
|
+
|
|
1133
|
+
# Update metadata on success
|
|
1134
|
+
metadata.status = StateStatus.COMPLETED
|
|
1135
|
+
metadata.last_execution = time.time()
|
|
1136
|
+
metadata.last_success = time.time()
|
|
1137
|
+
metadata.attempts += 1
|
|
1138
|
+
|
|
1139
|
+
# Track completion
|
|
1140
|
+
self.completed_states.add(state_name)
|
|
1141
|
+
self.completed_once.add(state_name)
|
|
1142
|
+
|
|
1143
|
+
# Handle transitions/next states
|
|
1144
|
+
await self._handle_state_result(state_name, result)
|
|
1145
|
+
|
|
1146
|
+
# Update shared state from context
|
|
1147
|
+
self.shared_state.update(context.shared_state)
|
|
1148
|
+
|
|
1149
|
+
finally:
|
|
1150
|
+
# Always release resources if they were acquired
|
|
1151
|
+
if resources is not None:
|
|
1152
|
+
await self.resource_pool.release(state_name)
|
|
1153
|
+
|
|
1154
|
+
async def _handle_state_result(self, state_name: str, result: StateResult) -> None:
|
|
1155
|
+
"""Handle the result of state execution."""
|
|
1156
|
+
if result is None:
|
|
1157
|
+
return
|
|
1158
|
+
|
|
1159
|
+
if isinstance(result, str):
|
|
1160
|
+
# Single next state
|
|
1161
|
+
if result in self.states and result not in self.completed_states:
|
|
1162
|
+
await self._add_to_queue(result)
|
|
1163
|
+
elif isinstance(result, list):
|
|
1164
|
+
# Multiple next states
|
|
1165
|
+
for next_state in result:
|
|
1166
|
+
if (
|
|
1167
|
+
isinstance(next_state, str)
|
|
1168
|
+
and next_state in self.states
|
|
1169
|
+
and next_state not in self.completed_states
|
|
1170
|
+
):
|
|
1171
|
+
await self._add_to_queue(next_state)
|
|
1172
|
+
elif isinstance(next_state, tuple):
|
|
1173
|
+
# Handle agent transition: (agent, state)
|
|
1174
|
+
agent, state = next_state
|
|
1175
|
+
if hasattr(agent, "add_to_queue"):
|
|
1176
|
+
await agent._add_to_queue(state)
|
|
1177
|
+
|
|
1178
|
+
async def _handle_state_failure(
|
|
1179
|
+
self, state_name: str, error: Exception, start_time: float
|
|
1180
|
+
) -> None:
|
|
1181
|
+
"""Handle state execution failure."""
|
|
1182
|
+
metadata = self.state_metadata[state_name]
|
|
1183
|
+
metadata.attempts += 1
|
|
1184
|
+
|
|
1185
|
+
# Check if we've exceeded max retries
|
|
1186
|
+
if metadata.attempts >= metadata.max_retries:
|
|
1187
|
+
metadata.status = StateStatus.FAILED
|
|
1188
|
+
|
|
1189
|
+
# Check for compensation state
|
|
1190
|
+
compensation_state = f"{state_name}_compensation"
|
|
1191
|
+
if compensation_state in self.states:
|
|
1192
|
+
await self._add_to_queue(compensation_state)
|
|
1193
|
+
|
|
1194
|
+
# Determine if this should go to dead letter queue
|
|
1195
|
+
retry_policy = metadata.retry_policy or self.retry_policy
|
|
1196
|
+
should_dead_letter = (
|
|
1197
|
+
self.enable_dead_letter and retry_policy.dead_letter_on_max_retries
|
|
1198
|
+
)
|
|
1199
|
+
|
|
1200
|
+
if should_dead_letter:
|
|
1201
|
+
dead_letter = DeadLetter(
|
|
1202
|
+
state_name=state_name,
|
|
1203
|
+
agent_name=self.name,
|
|
1204
|
+
error_message=str(error),
|
|
1205
|
+
error_type=type(error).__name__,
|
|
1206
|
+
attempts=metadata.attempts,
|
|
1207
|
+
failed_at=time.time(),
|
|
1208
|
+
timeout_occurred=isinstance(error, asyncio.TimeoutError),
|
|
1209
|
+
context_snapshot=dict(self.shared_state),
|
|
1210
|
+
)
|
|
1211
|
+
self.dead_letters.append(dead_letter)
|
|
1212
|
+
else:
|
|
1213
|
+
# Retry the state
|
|
1214
|
+
metadata.status = StateStatus.PENDING
|
|
1215
|
+
if metadata.retry_policy:
|
|
1216
|
+
await metadata.retry_policy.wait(metadata.attempts - 1)
|
|
1217
|
+
await self._add_to_queue(state_name)
|
|
1218
|
+
|
|
1219
|
+
# Add alias for backward compatibility
|
|
1220
|
+
async def _handle_failure(
|
|
1221
|
+
self, state_name: str, error: Exception, start_time: Optional[float] = None
|
|
1222
|
+
) -> None:
|
|
1223
|
+
"""Handle state execution failure (alias for backward compatibility)."""
|
|
1224
|
+
if start_time is None:
|
|
1225
|
+
start_time = time.time()
|
|
1226
|
+
await self._handle_state_failure(state_name, error, start_time)
|
|
1227
|
+
|
|
1228
|
+
async def _resolve_dependencies(self, state_name: str) -> None:
|
|
1229
|
+
"""Resolve dependencies for a state."""
|
|
1230
|
+
deps = self.dependencies.get(state_name, [])
|
|
1231
|
+
unmet_deps = [dep for dep in deps if dep not in self.completed_states]
|
|
1232
|
+
|
|
1233
|
+
if unmet_deps:
|
|
1234
|
+
logger.warning(f"State {state_name} has unmet dependencies: {unmet_deps}")
|
|
1235
|
+
|
|
1236
|
+
async def _can_run(self, state_name: str) -> bool:
|
|
1237
|
+
"""Check if a state can run."""
|
|
1238
|
+
if state_name in self.running_states:
|
|
1239
|
+
return False
|
|
1240
|
+
|
|
1241
|
+
if state_name in self.completed_once:
|
|
1242
|
+
return False
|
|
1243
|
+
|
|
1244
|
+
# Check dependencies
|
|
1245
|
+
deps = self.dependencies.get(state_name, [])
|
|
1246
|
+
return all(dep in self.completed_states for dep in deps)
|
|
1247
|
+
|
|
1248
|
+
# State control
|
|
1249
|
+
def cancel_state(self, state_name: str) -> None:
|
|
1250
|
+
"""Cancel a running or queued state."""
|
|
1251
|
+
# Remove from queue
|
|
1252
|
+
import heapq
|
|
1253
|
+
|
|
1254
|
+
self.priority_queue = [
|
|
1255
|
+
s for s in self.priority_queue if s.state_name != state_name
|
|
1256
|
+
]
|
|
1257
|
+
heapq.heapify(self.priority_queue)
|
|
1258
|
+
|
|
1259
|
+
# Remove from running states
|
|
1260
|
+
self.running_states.discard(state_name)
|
|
1261
|
+
|
|
1262
|
+
# Update metadata
|
|
1263
|
+
if state_name in self.state_metadata:
|
|
1264
|
+
self.state_metadata[state_name].status = StateStatus.CANCELLED
|
|
1265
|
+
|
|
1266
|
+
async def cancel_all(self) -> None:
|
|
1267
|
+
"""Cancel all running and queued states."""
|
|
1268
|
+
self.priority_queue.clear()
|
|
1269
|
+
self.running_states.clear()
|
|
1270
|
+
self.status = AgentStatus.CANCELLED
|
|
1271
|
+
|
|
1272
|
+
# Information methods
|
|
1273
|
+
def get_resource_status(self) -> dict[str, Any]:
|
|
1274
|
+
"""Get current resource status."""
|
|
1275
|
+
status = {
|
|
1276
|
+
"available": dict(self.resource_pool.available),
|
|
1277
|
+
"allocated": self.resource_pool.get_state_allocations(),
|
|
1278
|
+
"waiting": list(self.resource_pool.get_waiting_states()),
|
|
1279
|
+
"preempted": list(self.resource_pool.get_preempted_states()),
|
|
1280
|
+
}
|
|
1281
|
+
return status
|
|
1282
|
+
|
|
1283
|
+
def get_state_info(self, state_name: str) -> dict[str, Any]:
|
|
1284
|
+
"""Get information about a specific state."""
|
|
1285
|
+
if state_name not in self.states:
|
|
1286
|
+
return {}
|
|
1287
|
+
|
|
1288
|
+
metadata = self.state_metadata.get(state_name)
|
|
1289
|
+
has_decorator = False
|
|
1290
|
+
try:
|
|
1291
|
+
from .decorators.inspection import is_puffinflow_state
|
|
1292
|
+
|
|
1293
|
+
has_decorator = is_puffinflow_state(self.states[state_name])
|
|
1294
|
+
except ImportError:
|
|
1295
|
+
pass
|
|
1296
|
+
|
|
1297
|
+
return {
|
|
1298
|
+
"name": state_name,
|
|
1299
|
+
"status": metadata.status if metadata else "unknown",
|
|
1300
|
+
"dependencies": self.dependencies.get(state_name, []),
|
|
1301
|
+
"has_decorator": has_decorator,
|
|
1302
|
+
"in_queue": any(s.state_name == state_name for s in self.priority_queue),
|
|
1303
|
+
"running": state_name in self.running_states,
|
|
1304
|
+
"completed": state_name in self.completed_states,
|
|
1305
|
+
}
|
|
1306
|
+
|
|
1307
|
+
def list_states(self) -> list[dict[str, Any]]:
|
|
1308
|
+
"""List all states with their information."""
|
|
1309
|
+
result = []
|
|
1310
|
+
for name in self.states:
|
|
1311
|
+
try:
|
|
1312
|
+
from .decorators.inspection import is_puffinflow_state
|
|
1313
|
+
|
|
1314
|
+
has_decorator = is_puffinflow_state(self.states[name])
|
|
1315
|
+
except ImportError:
|
|
1316
|
+
has_decorator = False
|
|
1317
|
+
|
|
1318
|
+
metadata = self.state_metadata.get(name)
|
|
1319
|
+
status = metadata.status if metadata is not None else "unknown"
|
|
1320
|
+
|
|
1321
|
+
result.append(
|
|
1322
|
+
{
|
|
1323
|
+
"name": name,
|
|
1324
|
+
"has_decorator": has_decorator,
|
|
1325
|
+
"dependencies": self.dependencies.get(name, []),
|
|
1326
|
+
"status": status,
|
|
1327
|
+
}
|
|
1328
|
+
)
|
|
1329
|
+
return result
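
For orientation, the list returned above holds one dict per registered state with the same four keys; the values below are made up purely to show the shape:

# Hypothetical list_states() output for a two-state workflow (illustrative values only).
states = [
    {"name": "extract", "has_decorator": True, "dependencies": [], "status": "completed"},
    {"name": "load", "has_decorator": False, "dependencies": ["extract"], "status": "pending"},
]
not_done = [s["name"] for s in states if s["status"] != "completed"]
print(not_done)  # ['load']
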
|
|
1330
|
+
|
|
1331
|
+
# Dead letter management
|
|
1332
|
+
def get_dead_letters(self) -> list[DeadLetter]:
|
|
1333
|
+
"""Get all dead letters."""
|
|
1334
|
+
return self.dead_letters.copy()
|
|
1335
|
+
|
|
1336
|
+
def clear_dead_letters(self) -> None:
|
|
1337
|
+
"""Clear all dead letters."""
|
|
1338
|
+
count = len(self.dead_letters)
|
|
1339
|
+
self.dead_letters.clear()
|
|
1340
|
+
logger.info(f"Cleared {count} dead letters for agent {self.name}")
|
|
1341
|
+
|
|
1342
|
+
def get_dead_letter_count(self) -> int:
|
|
1343
|
+
"""Get count of dead letters."""
|
|
1344
|
+
return len(self.dead_letters)
|
|
1345
|
+
|
|
1346
|
+
def get_dead_letters_by_state(self, state_name: str) -> list[DeadLetter]:
|
|
1347
|
+
"""Get dead letters for a specific state."""
|
|
1348
|
+
return [dl for dl in self.dead_letters if dl.state_name == state_name]
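
One common use of these accessors is spotting the most failure-prone state after a run; a hedged sketch with a stand-in record type (the real DeadLetter class has more fields than shown here):

from collections import Counter
from dataclasses import dataclass

@dataclass
class DeadLetterStub:
    # Stand-in carrying only the attribute the filter above relies on.
    state_name: str
    error_message: str

dead_letters = [
    DeadLetterStub("load", "timeout"),
    DeadLetterStub("load", "db unreachable"),
    DeadLetterStub("report", "bad template"),
]
worst_state, failures = Counter(dl.state_name for dl in dead_letters).most_common(1)[0]
print(worst_state, failures)  # load 2
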
|
|
1349
|
+
|
|
1350
|
+
# Circuit breaker control
|
|
1351
|
+
async def force_circuit_breaker_open(self) -> None:
|
|
1352
|
+
"""Force circuit breaker to open state."""
|
|
1353
|
+
await self.circuit_breaker.force_open()
|
|
1354
|
+
|
|
1355
|
+
async def force_circuit_breaker_close(self) -> None:
|
|
1356
|
+
"""Force circuit breaker to close state."""
|
|
1357
|
+
await self.circuit_breaker.force_close()
|
|
1358
|
+
|
|
1359
|
+
# Resource leak detection
|
|
1360
|
+
def check_resource_leaks(self) -> list[Any]:
|
|
1361
|
+
"""Check for resource leaks."""
|
|
1362
|
+
return self.resource_pool.check_leaks()
|
|
1363
|
+
|
|
1364
|
+
# Cleanup
|
|
1365
|
+
def add_cleanup_handler(self, handler: Callable) -> None:
|
|
1366
|
+
"""Add cleanup handler."""
|
|
1367
|
+
self._cleanup_handlers.append(handler)
|
|
1368
|
+
|
|
1369
|
+
async def cleanup(self) -> None:
|
|
1370
|
+
"""Cleanup resources and handlers."""
|
|
1371
|
+
for handler in self._cleanup_handlers:
|
|
1372
|
+
try:
|
|
1373
|
+
if asyncio.iscoroutinefunction(handler):
|
|
1374
|
+
await handler()
|
|
1375
|
+
else:
|
|
1376
|
+
handler()
|
|
1377
|
+
except Exception as e:
|
|
1378
|
+
logger.error(f"Error in cleanup handler: {e}")
|
|
1379
|
+
|
|
1380
|
+
def _get_execution_metadata(self) -> dict[str, Any]:
|
|
1381
|
+
"""Get execution metadata."""
|
|
1382
|
+
return {
|
|
1383
|
+
"states_completed": list(self.completed_states),
|
|
1384
|
+
"states_failed": [dl.state_name for dl in self.dead_letters],
|
|
1385
|
+
"total_states": len(self.states),
|
|
1386
|
+
"session_start": self.session_start,
|
|
1387
|
+
"dead_letter_count": len(self.dead_letters),
|
|
1388
|
+
}
|
|
1389
|
+
|
|
1390
|
+
def _get_execution_metrics(self) -> dict[str, Any]:
|
|
1391
|
+
"""Get execution metrics."""
|
|
1392
|
+
return {
|
|
1393
|
+
"completion_rate": len(self.completed_states) / max(1, len(self.states)),
|
|
1394
|
+
"error_rate": len(self.dead_letters) / max(1, len(self.states)),
|
|
1395
|
+
"resource_usage": self.get_resource_status(),
|
|
1396
|
+
"circuit_breaker_metrics": self.circuit_breaker.get_metrics(),
|
|
1397
|
+
"bulkhead_metrics": self.bulkhead.get_metrics(),
|
|
1398
|
+
}
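
Both rates are simple ratios over the number of registered states, with max(1, ...) guarding against division by zero on an agent with no states. A worked example with illustrative numbers:

# 10 registered states, 8 completed, 1 dead-lettered.
total_states = 10
completed = 8
dead_letters = 1

completion_rate = completed / max(1, total_states)  # 0.8
error_rate = dead_letters / max(1, total_states)    # 0.1
print(completion_rate, error_rate)
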
|
|
1399
|
+
|
|
1400
|
+
# Scheduling methods
|
|
1401
|
+
def schedule(self, when: str, **inputs: Any) -> "ScheduledAgent":
|
|
1402
|
+
"""Schedule this agent to run at specified times with given inputs.
|
|
1403
|
+
|
|
1404
|
+
Args:
|
|
1405
|
+
when: Schedule string (natural language or cron expression)
|
|
1406
|
+
Examples: "daily", "hourly", "every 5 minutes", "0 9 * * 1-5"
|
|
1407
|
+
**inputs: Input parameters with optional magic prefixes:
|
|
1408
|
+
- secret:value - Store as secret
|
|
1409
|
+
- const:value - Store as constant
|
|
1410
|
+
- cache:TTL:value - Store as cached with TTL
|
|
1411
|
+
- typed:value - Store as typed variable
|
|
1412
|
+
- output:value - Pre-set as output
|
|
1413
|
+
- value (no prefix) - Store as regular variable
|
|
1414
|
+
|
|
1415
|
+
Returns:
|
|
1416
|
+
ScheduledAgent instance for managing the scheduled execution
|
|
1417
|
+
|
|
1418
|
+
Raises:
|
|
1419
|
+
ImportError: If scheduling module is not available
|
|
1420
|
+
SchedulingError: If scheduling fails
|
|
1421
|
+
|
|
1422
|
+
Examples:
|
|
1423
|
+
# Basic scheduling
|
|
1424
|
+
agent.schedule("daily at 09:00", source="database")
|
|
1425
|
+
|
|
1426
|
+
# With magic prefixes
|
|
1427
|
+
agent.schedule(
|
|
1428
|
+
"every 30 minutes",
|
|
1429
|
+
api_key="secret:sk-1234567890abcdef",
|
|
1430
|
+
pool_size="const:10",
|
|
1431
|
+
config="cache:3600:{'timeout': 30}",
|
|
1432
|
+
source="warehouse"
|
|
1433
|
+
)
|
|
1434
|
+
"""
|
|
1435
|
+
if not _SCHEDULING_AVAILABLE:
|
|
1436
|
+
raise ImportError(
|
|
1437
|
+
"Scheduling module not available. Install required dependencies."
|
|
1438
|
+
)
|
|
1439
|
+
|
|
1440
|
+
scheduler = GlobalScheduler.get_instance_sync()
|
|
1441
|
+
return scheduler.schedule_agent(self, when, **inputs)
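
The magic-prefix convention means every input value is just a string whose handling is chosen by the text before the first colon. A hedged sketch of how such a value could be split; the package's actual parsing lives in the scheduling module and may differ, for example in how cache TTLs are handled:

# Hedged sketch only; the real parser is puffinflow/core/agent/scheduling/parser.py.
def split_prefix(value: str) -> tuple:
    known = ("secret", "const", "cache", "typed", "output")
    prefix, _, rest = value.partition(":")
    if prefix in known and rest:
        return prefix, rest
    return "variable", value

print(split_prefix("secret:sk-1234567890abcdef"))  # ('secret', 'sk-1234567890abcdef')
print(split_prefix("cache:3600:{'timeout': 30}"))  # ('cache', "3600:{'timeout': 30}")
print(split_prefix("warehouse"))                   # ('variable', 'warehouse')
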
|
|
1442
|
+
|
|
1443
|
+
def every(self, interval: str) -> "ScheduleBuilder":
|
|
1444
|
+
"""Start fluent API for scheduling with intervals.
|
|
1445
|
+
|
|
1446
|
+
Args:
|
|
1447
|
+
interval: Interval string like "5 minutes", "2 hours", "daily"
|
|
1448
|
+
|
|
1449
|
+
Returns:
|
|
1450
|
+
ScheduleBuilder for chaining
|
|
1451
|
+
|
|
1452
|
+
Examples:
|
|
1453
|
+
agent.every("5 minutes").with_inputs(source="api").run()
|
|
1454
|
+
agent.every("daily").with_secrets(api_key="sk-123").run()
|
|
1455
|
+
"""
|
|
1456
|
+
if not _SCHEDULING_AVAILABLE:
|
|
1457
|
+
raise ImportError(
|
|
1458
|
+
"Scheduling module not available. Install required dependencies."
|
|
1459
|
+
)
|
|
1460
|
+
|
|
1461
|
+
# Handle "every X" format - avoid double "every"
|
|
1462
|
+
if not interval.startswith("every "):
|
|
1463
|
+
interval = f"every {interval}"
|
|
1464
|
+
else:
|
|
1465
|
+
# If it already starts with "every", don't add another
|
|
1466
|
+
pass
|
|
1467
|
+
|
|
1468
|
+
return ScheduleBuilder(self, interval)
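
The only normalization every() performs is prepending "every " when the caller left it off, so both spellings produce the same schedule string. Written out as a standalone function:

def normalize_interval(interval: str) -> str:
    # Mirror of the normalization above: avoid producing "every every ...".
    if not interval.startswith("every "):
        interval = f"every {interval}"
    return interval

print(normalize_interval("5 minutes"))        # 'every 5 minutes'
print(normalize_interval("every 5 minutes"))  # 'every 5 minutes' (unchanged)
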
|
|
1469
|
+
|
|
1470
|
+
def daily(self, time_str: Optional[str] = None) -> "ScheduleBuilder":
|
|
1471
|
+
"""Start fluent API for daily scheduling.
|
|
1472
|
+
|
|
1473
|
+
Args:
|
|
1474
|
+
time_str: Optional time like "09:00" or "2pm"
|
|
1475
|
+
|
|
1476
|
+
Returns:
|
|
1477
|
+
ScheduleBuilder for chaining
|
|
1478
|
+
|
|
1479
|
+
Examples:
|
|
1480
|
+
agent.daily().with_inputs(batch_size=1000).run()
|
|
1481
|
+
agent.daily("09:00").with_secrets(db_pass="secret123").run()
|
|
1482
|
+
"""
|
|
1483
|
+
if not _SCHEDULING_AVAILABLE:
|
|
1484
|
+
raise ImportError(
|
|
1485
|
+
"Scheduling module not available. Install required dependencies."
|
|
1486
|
+
)
|
|
1487
|
+
|
|
1488
|
+
schedule_str = f"daily at {time_str}" if time_str else "daily"
|
|
1489
|
+
|
|
1490
|
+
return ScheduleBuilder(self, schedule_str)
|
|
1491
|
+
|
|
1492
|
+
def hourly(self, minute: Optional[int] = None) -> "ScheduleBuilder":
|
|
1493
|
+
"""Start fluent API for hourly scheduling.
|
|
1494
|
+
|
|
1495
|
+
Args:
|
|
1496
|
+
minute: Optional minute of the hour (0-59)
|
|
1497
|
+
|
|
1498
|
+
Returns:
|
|
1499
|
+
ScheduleBuilder for chaining
|
|
1500
|
+
|
|
1501
|
+
Examples:
|
|
1502
|
+
agent.hourly().with_inputs(check_status=True).run()
|
|
1503
|
+
agent.hourly(30).with_constants(timeout=60).run()
|
|
1504
|
+
"""
|
|
1505
|
+
if not _SCHEDULING_AVAILABLE:
|
|
1506
|
+
raise ImportError(
|
|
1507
|
+
"Scheduling module not available. Install required dependencies."
|
|
1508
|
+
)
|
|
1509
|
+
|
|
1510
|
+
schedule_str = f"every hour at {minute}" if minute is not None else "hourly"
|
|
1511
|
+
|
|
1512
|
+
return ScheduleBuilder(self, schedule_str)
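
daily() and hourly() reduce to building the same kind of schedule string that schedule() accepts; the mapping is easy to see when written out on its own:

from typing import Optional

# Schedule strings produced by the two convenience methods above.
def daily_schedule(time_str: Optional[str] = None) -> str:
    return f"daily at {time_str}" if time_str else "daily"

def hourly_schedule(minute: Optional[int] = None) -> str:
    return f"every hour at {minute}" if minute is not None else "hourly"

print(daily_schedule("09:00"))  # 'daily at 09:00'
print(daily_schedule())         # 'daily'
print(hourly_schedule(30))      # 'every hour at 30'
print(hourly_schedule())        # 'hourly'
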
|
|
1513
|
+
|
|
1514
|
+
# Main execution
|
|
1515
|
+
async def run(self, timeout: Optional[float] = None) -> AgentResult:
|
|
1516
|
+
"""Run the agent workflow with enhanced result tracking."""
|
|
1517
|
+
start_time = time.time()
|
|
1518
|
+
self.status = AgentStatus.RUNNING
|
|
1519
|
+
|
|
1520
|
+
if self.session_start is None:
|
|
1521
|
+
self.session_start = start_time
|
|
1522
|
+
|
|
1523
|
+
try:
|
|
1524
|
+
# Check if we have any states
|
|
1525
|
+
if not self.states:
|
|
1526
|
+
logger.info("No states defined, nothing to run")
|
|
1527
|
+
self.status = AgentStatus.IDLE
|
|
1528
|
+
return AgentResult(
|
|
1529
|
+
agent_name=self.name,
|
|
1530
|
+
status=self.status,
|
|
1531
|
+
start_time=start_time,
|
|
1532
|
+
end_time=time.time(),
|
|
1533
|
+
execution_duration=time.time() - start_time,
|
|
1534
|
+
)
|
|
1535
|
+
|
|
1536
|
+
# Create context with current shared state
|
|
1537
|
+
self._create_context(self.shared_state)
|
|
1538
|
+
|
|
1539
|
+
# Find entry states (states with no dependencies)
|
|
1540
|
+
entry_states = self._find_entry_states()
|
|
1541
|
+
if not entry_states:
|
|
1542
|
+
# If no entry states found, use all states (they will be filtered by dependencies during execution)
|
|
1543
|
+
entry_states = list(self.states.keys())
|
|
1544
|
+
|
|
1545
|
+
# Add entry states to queue
|
|
1546
|
+
for state_name in entry_states:
|
|
1547
|
+
await self._add_to_queue(state_name)
|
|
1548
|
+
|
|
1549
|
+
# Main execution loop
|
|
1550
|
+
while self.status == AgentStatus.RUNNING:
|
|
1551
|
+
if timeout and (time.time() - start_time) > timeout:
|
|
1552
|
+
logger.warning(
|
|
1553
|
+
f"Agent {self.name} timed out after {timeout} seconds."
|
|
1554
|
+
)
|
|
1555
|
+
self.status = AgentStatus.FAILED
|
|
1556
|
+
break
|
|
1557
|
+
|
|
1558
|
+
# Stop if there's nothing left to do
|
|
1559
|
+
if not self.priority_queue and not self.running_states:
|
|
1560
|
+
break
|
|
1561
|
+
|
|
1562
|
+
ready_states = await self._get_ready_states()
|
|
1563
|
+
|
|
1564
|
+
if ready_states:
|
|
1565
|
+
tasks = []
|
|
1566
|
+
for state_name in ready_states[: self.max_concurrent]:
|
|
1567
|
+
task = asyncio.create_task(self.run_state(state_name))
|
|
1568
|
+
tasks.append(task)
|
|
1569
|
+
|
|
1570
|
+
if tasks:
|
|
1571
|
+
await asyncio.gather(*tasks, return_exceptions=True)
|
|
1572
|
+
|
|
1573
|
+
elif self.priority_queue and not self.running_states:
|
|
1574
|
+
# States are in queue but none can run, and nothing is running
|
|
1575
|
+
# This indicates a deadlock or unmeetable dependencies
|
|
1576
|
+
logger.warning(
|
|
1577
|
+
f"Deadlock in agent {self.name}: States in queue but none can run."
|
|
1578
|
+
)
|
|
1579
|
+
self.status = AgentStatus.FAILED
|
|
1580
|
+
break
|
|
1581
|
+
else:
|
|
1582
|
+
# No states ready, but some are running. Wait for them.
|
|
1583
|
+
await asyncio.sleep(0.01) # Short sleep to yield control
|
|
1584
|
+
|
|
1585
|
+
# Determine final status
|
|
1586
|
+
if self.status == AgentStatus.RUNNING:
|
|
1587
|
+
has_failed_states = any(
|
|
1588
|
+
s.status == StateStatus.FAILED for s in self.state_metadata.values()
|
|
1589
|
+
)
|
|
1590
|
+
if has_failed_states:
|
|
1591
|
+
self.status = AgentStatus.FAILED
|
|
1592
|
+
else:
|
|
1593
|
+
self.status = AgentStatus.COMPLETED
|
|
1594
|
+
|
|
1595
|
+
end_time = time.time()
|
|
1596
|
+
|
|
1597
|
+
# Create result
|
|
1598
|
+
result = AgentResult(
|
|
1599
|
+
agent_name=self.name,
|
|
1600
|
+
status=self.status,
|
|
1601
|
+
outputs=self.get_all_outputs(),
|
|
1602
|
+
variables={**self._agent_variables, **self.shared_state},
|
|
1603
|
+
metadata=self._get_execution_metadata(),
|
|
1604
|
+
metrics=self._get_execution_metrics(),
|
|
1605
|
+
start_time=start_time,
|
|
1606
|
+
end_time=end_time,
|
|
1607
|
+
execution_duration=end_time - start_time,
|
|
1608
|
+
)
|
|
1609
|
+
|
|
1610
|
+
return result
|
|
1611
|
+
|
|
1612
|
+
except Exception as e:
|
|
1613
|
+
self.status = AgentStatus.FAILED
|
|
1614
|
+
end_time = time.time()
|
|
1615
|
+
|
|
1616
|
+
return AgentResult(
|
|
1617
|
+
agent_name=self.name,
|
|
1618
|
+
status=self.status,
|
|
1619
|
+
error=e,
|
|
1620
|
+
start_time=start_time,
|
|
1621
|
+
end_time=end_time,
|
|
1622
|
+
execution_duration=end_time - start_time,
|
|
1623
|
+
)
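
Stripped of resources, retries and result metadata, the control loop above is: seed the queue with entry states, repeatedly launch whatever is ready up to max_concurrent, and stop on timeout, on an empty queue, or on a deadlock (queued states that can never become ready). A reduced standalone sketch of that loop shape, with toy coroutines and no relation to the package's real scheduling internals; ready states are awaited to completion inside the loop, so the real loop's wait-on-running branch is omitted:

import asyncio
import time

# Toy workflow: state name -> (dependencies, coroutine function). Illustrative only.
async def extract() -> None:
    await asyncio.sleep(0.01)

async def load() -> None:
    await asyncio.sleep(0.01)

STATES = {"extract": ([], extract), "load": (["extract"], load)}

async def run_workflow(timeout: float = 5.0, max_concurrent: int = 2) -> str:
    start = time.time()
    completed: set = set()
    # Seed with entry states (no dependencies); fall back to every state.
    queue = [n for n, (deps, _) in STATES.items() if not deps] or list(STATES)

    while queue:
        if time.time() - start > timeout:
            return "FAILED (timeout)"
        ready = [n for n in queue if all(d in completed for d in STATES[n][0])]
        if not ready:
            return "FAILED (deadlock: queued states can never run)"
        batch = ready[:max_concurrent]
        await asyncio.gather(*(STATES[n][1]() for n in batch))
        completed.update(batch)
        queue = [n for n in STATES if n not in completed]
    return "COMPLETED"

print(asyncio.run(run_workflow()))  # COMPLETED
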
|
|
1624
|
+
|
|
1625
|
+
def __del__(self) -> None:
|
|
1626
|
+
"""Cleanup on deletion."""
|
|
1627
|
+
if self._cleanup_handlers:
|
|
1628
|
+
try:
|
|
1629
|
+
loop = asyncio.get_event_loop()
|
|
1630
|
+
if loop.is_running():
|
|
1631
|
+
# Create cleanup task but don't store reference as object is being destroyed
|
|
1632
|
+
task = loop.create_task(self.cleanup())
|
|
1633
|
+
task.add_done_callback(lambda t: None) # Prevent warnings
|
|
1634
|
+
except Exception:
|
|
1635
|
+
pass
|
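
The destructor can only do best-effort cleanup: it schedules the async cleanup on the running loop if one exists and otherwise gives up silently, since __del__ cannot await. A standalone sketch of that pattern outside the Agent class (CPython's reference counting is assumed for the immediate __del__ call):

import asyncio

class Resource:
    """Toy object that schedules its own async cleanup on deletion (best effort)."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def cleanup(self) -> None:
        print(f"cleaned up {self.name}")

    def __del__(self) -> None:
        try:
            # Mirrors the legacy get_event_loop() call used above.
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Fire-and-forget: the object is going away, so keep no reference.
                task = loop.create_task(self.cleanup())
                task.add_done_callback(lambda t: None)
        except Exception:
            pass  # no usable loop; nothing more __del__ can do

async def main() -> None:
    r = Resource("demo")
    del r                   # schedules cleanup on the running loop
    await asyncio.sleep(0)  # give the scheduled task a chance to run

asyncio.run(main())
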