agnt5 0.1.0__cp39-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agnt5/__init__.py +307 -0
- agnt5/__pycache__/__init__.cpython-311.pyc +0 -0
- agnt5/__pycache__/agent.cpython-311.pyc +0 -0
- agnt5/__pycache__/context.cpython-311.pyc +0 -0
- agnt5/__pycache__/durable.cpython-311.pyc +0 -0
- agnt5/__pycache__/extraction.cpython-311.pyc +0 -0
- agnt5/__pycache__/memory.cpython-311.pyc +0 -0
- agnt5/__pycache__/reflection.cpython-311.pyc +0 -0
- agnt5/__pycache__/runtime.cpython-311.pyc +0 -0
- agnt5/__pycache__/task.cpython-311.pyc +0 -0
- agnt5/__pycache__/tool.cpython-311.pyc +0 -0
- agnt5/__pycache__/tracing.cpython-311.pyc +0 -0
- agnt5/__pycache__/types.cpython-311.pyc +0 -0
- agnt5/__pycache__/workflow.cpython-311.pyc +0 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/agent.py +1086 -0
- agnt5/context.py +406 -0
- agnt5/durable.py +1050 -0
- agnt5/extraction.py +410 -0
- agnt5/llm/__init__.py +179 -0
- agnt5/llm/__pycache__/__init__.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/anthropic.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/azure.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/base.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/google.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/mistral.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/openai.cpython-311.pyc +0 -0
- agnt5/llm/__pycache__/together.cpython-311.pyc +0 -0
- agnt5/llm/anthropic.py +319 -0
- agnt5/llm/azure.py +348 -0
- agnt5/llm/base.py +315 -0
- agnt5/llm/google.py +373 -0
- agnt5/llm/mistral.py +330 -0
- agnt5/llm/model_registry.py +467 -0
- agnt5/llm/models.json +227 -0
- agnt5/llm/openai.py +334 -0
- agnt5/llm/together.py +377 -0
- agnt5/memory.py +746 -0
- agnt5/reflection.py +514 -0
- agnt5/runtime.py +699 -0
- agnt5/task.py +476 -0
- agnt5/testing.py +451 -0
- agnt5/tool.py +516 -0
- agnt5/tracing.py +624 -0
- agnt5/types.py +210 -0
- agnt5/workflow.py +897 -0
- agnt5-0.1.0.dist-info/METADATA +93 -0
- agnt5-0.1.0.dist-info/RECORD +49 -0
- agnt5-0.1.0.dist-info/WHEEL +4 -0
agnt5/workflow.py
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Workflow implementation for the AGNT5 SDK.
|
|
3
|
+
|
|
4
|
+
Workflows orchestrate complex multi-step processes with durability guarantees.
|
|
5
|
+
They support parallel execution, error handling, and state management.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import (
|
|
9
|
+
Any, Callable, Dict, List, Optional, Union, TypeVar, Generic,
|
|
10
|
+
AsyncIterator, Tuple, Set
|
|
11
|
+
)
|
|
12
|
+
import asyncio
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
import inspect
|
|
15
|
+
import logging
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
import json
|
|
18
|
+
|
|
19
|
+
from .types import (
|
|
20
|
+
WorkflowConfig,
|
|
21
|
+
ExecutionContext,
|
|
22
|
+
ExecutionState,
|
|
23
|
+
DurablePromise,
|
|
24
|
+
)
|
|
25
|
+
from .context import Context, get_context
|
|
26
|
+
from .durable import durable, DurableContext
|
|
27
|
+
# from .agent import Agent # Commented out to avoid dependency issues
|
|
28
|
+
# from .tool import Tool # Commented out to avoid dependency issues
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
T = TypeVar('T')
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class Step:
    """Bookkeeping record for one named step of a workflow.

    Attributes:
        name: Identifier the step is stored under.
        func: Callable that performs the step's work.
        args: Positional arguments passed to ``func``.
        kwargs: Keyword arguments passed to ``func``.
        dependencies: Names of other steps associated with this one.
        result: Value produced by ``func`` (``None`` until set).
        error: Exception recorded for the step, if any.
        started_at: Timestamp recorded when the step started.
        completed_at: Timestamp recorded when the step finished or failed.
        retries: Number of retries recorded for the step.
    """

    name: str
    func: Callable
    # An empty tuple is immutable, so it can be shared safely as a default.
    args: Tuple[Any, ...] = ()
    # Mutable containers need a factory so each Step owns its own instance.
    kwargs: Dict[str, Any] = field(default_factory=dict)
    dependencies: Set[str] = field(default_factory=set)
    result: Optional[Any] = None
    error: Optional[Exception] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    retries: int = 0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Workflow:
    """
    Orchestrates complex multi-step processes with durability.

    Example:
        ```python
        from agnt5 import Workflow, Agent

        class DataProcessingWorkflow(Workflow):
            '''Process data through multiple stages.'''

            async def run(self, data: List[str]) -> Dict[str, Any]:
                # Step 1: Validate data
                validated = await self.step("validate", self.validate_data, data)

                # Step 2: Process in parallel
                results = await self.parallel([
                    ("analyze", self.analyze_data, validated),
                    ("transform", self.transform_data, validated),
                ])

                # Step 3: Combine results
                final = await self.step("combine", self.combine_results, results)

                return final

            async def validate_data(self, data: List[str]) -> List[str]:
                # Validation logic
                return [d for d in data if d]

            async def analyze_data(self, data: List[str]) -> Dict[str, Any]:
                # Analysis logic
                return {"count": len(data), "unique": len(set(data))}

            async def transform_data(self, data: List[str]) -> List[str]:
                # Transformation logic
                return [d.upper() for d in data]

            async def combine_results(self, results: List[Any]) -> Dict[str, Any]:
                # Combine parallel results
                return {
                    "analysis": results[0],
                    "transformed": results[1],
                }

        # Use the workflow
        workflow = DataProcessingWorkflow(name="data-processor")
        result = await workflow.execute(["hello", "world"])
        ```
    """

    # NOTE: the class docstring above doubles as the default workflow
    # description (via inspect.getdoc in __init__), so keep its text stable.

    def __init__(
        self,
        name: str,
        *,
        description: Optional[str] = None,
        version: str = "1.0.0",
        config: Optional[WorkflowConfig] = None,
    ):
        """Initialize a Workflow.

        Args:
            name: Workflow name; used only when ``config`` is not supplied.
            description: Optional description. Falls back to the class
                docstring, then to ``"Workflow: <name>"``.
            version: Version string stored in the generated config.
            config: Pre-built configuration. When given (and truthy), it is
                used as-is and ``name``/``description``/``version`` are ignored.
        """
        if config:
            self.config = config
        else:
            self.config = WorkflowConfig(
                name=name,
                description=description or inspect.getdoc(self.__class__) or f"Workflow: {name}",
                version=version,
            )

        self._steps: Dict[str, Step] = {}
        self._execution_context: Optional[ExecutionContext] = None
        # Agent/Tool are not imported in this module (their imports are
        # commented out), so these registries are typed loosely.
        self._agents: Dict[str, Any] = {}
        self._tools: Dict[str, Any] = {}

    @property
    def name(self) -> str:
        """Workflow name, as stored on the config."""
        return self.config.name

    def add_agent(self, agent) -> None:
        """Add an agent to the workflow, keyed by ``agent.name``."""
        self._agents[agent.name] = agent

    def add_tool(self, tool) -> None:
        """Add a tool to the workflow, keyed by ``tool.name``."""
        self._tools[tool.name] = tool

    async def execute(self, *args, **kwargs) -> Any:
        """
        Execute the workflow.

        Creates a fresh ``ExecutionContext`` in the RUNNING state, runs the
        workflow (through the durable flow when ``config.enable_durability``
        is truthy, directly otherwise) and records COMPLETED or FAILED
        together with the completion time.

        Args:
            *args: Positional arguments for the run method
            **kwargs: Keyword arguments for the run method

        Returns:
            Workflow result

        Raises:
            Exception: Whatever the run raised, re-raised after the
                execution context has been marked FAILED.
        """
        ctx = get_context()

        # Create execution context
        self._execution_context = ExecutionContext(
            execution_id=ctx.execution_id,
            workflow_id=self.name,
            state=ExecutionState.RUNNING,
            started_at=datetime.utcnow(),
        )

        try:
            # Execute with durability if enabled (use durable flow)
            if self.config.enable_durability:
                result = await self._execute_as_durable_flow(*args, **kwargs)
            else:
                result = await self._execute_direct(*args, **kwargs)

            # Mark as completed
            self._execution_context.state = ExecutionState.COMPLETED
            self._execution_context.completed_at = datetime.utcnow()

            return result

        except Exception as e:
            # Mark as failed
            self._execution_context.state = ExecutionState.FAILED
            self._execution_context.completed_at = datetime.utcnow()
            self._execution_context.last_error = str(e)
            raise

    async def _execute_as_durable_flow(self, *args, **kwargs) -> Any:
        """
        Execute the workflow wrapped in a ``durable.flow``.

        A flow function is defined per call, configured from ``self.config``
        (missing attributes fall back to the defaults shown below). The flow
        records metadata and status in ``ctx.state``, runs the workflow via
        ``_run_with_enhanced_context`` and stores either the final result or
        the error details before re-raising.
        """
        @durable.flow(
            name=f"workflow_{self.name}",
            checkpoint_interval=getattr(self.config, 'checkpoint_interval', 1),
            max_retries=getattr(self.config, 'max_retries', 3),
            max_concurrent_steps=getattr(self.config, 'max_parallel_steps', 5),
            deterministic=True,
            timeout=getattr(self.config, 'timeout', None),
        )
        async def comprehensive_workflow_flow(ctx: DurableContext, *flow_args, **flow_kwargs) -> Any:
            # Initialize comprehensive workflow state
            workflow_metadata = {
                "workflow_name": self.name,
                "workflow_version": getattr(self.config, 'version', '1.0.0'),
                "execution_id": ctx.execution_id,
                "started_at": datetime.utcnow().isoformat(),
                "input_args": flow_args,
                "input_kwargs": flow_kwargs,
                "config": {
                    "checkpoint_interval": getattr(self.config, 'checkpoint_interval', 1),
                    "max_retries": getattr(self.config, 'max_retries', 3),
                    "max_parallel_steps": getattr(self.config, 'max_parallel_steps', 5),
                    "enable_durability": getattr(self.config, 'enable_durability', True),
                }
            }

            await ctx.state.set("workflow_metadata", workflow_metadata)
            await ctx.state.set("current_step", 0)
            await ctx.state.set("completed_steps", [])
            await ctx.state.set("step_results", {})
            await ctx.state.set("workflow_status", "running")

            # Create enhanced workflow context
            workflow_ctx = EnhancedWorkflowDurableContext(ctx, self)

            try:
                # Execute the workflow with comprehensive tracking
                result = await self._run_with_enhanced_context(workflow_ctx, *flow_args, **flow_kwargs)

                # Mark workflow as completed
                workflow_metadata["completed_at"] = datetime.utcnow().isoformat()
                workflow_metadata["status"] = "completed"
                workflow_metadata["result"] = result

                await ctx.state.set("workflow_metadata", workflow_metadata)
                await ctx.state.set("workflow_status", "completed")
                await ctx.state.set("final_result", result)

                logger.info(f"Workflow '{self.name}' completed successfully")
                return result

            except Exception as e:
                # Handle workflow failure with comprehensive error tracking
                workflow_metadata["failed_at"] = datetime.utcnow().isoformat()
                workflow_metadata["status"] = "failed"
                workflow_metadata["error"] = str(e)
                workflow_metadata["error_type"] = type(e).__name__

                await ctx.state.set("workflow_metadata", workflow_metadata)
                await ctx.state.set("workflow_status", "failed")
                await ctx.state.set("workflow_error", {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "failed_at": datetime.utcnow().isoformat(),
                })

                logger.error(f"Workflow '{self.name}' failed: {e}")
                raise

        # Execute the comprehensive durable flow.
        # NOTE(review): no ctx is passed here, so this presumably relies on
        # durable.flow supplying the DurableContext -- confirm against durable.py.
        return await comprehensive_workflow_flow(*args, **kwargs)

    @durable.function
    async def _execute_durable(self, *args, **kwargs) -> Any:
        """Execute with durability guarantees (legacy method)."""
        return await self.run(*args, **kwargs)

    async def _execute_direct(self, *args, **kwargs) -> Any:
        """Direct execution without durability."""
        return await self.run(*args, **kwargs)

    async def run(self, *args, **kwargs) -> Any:
        """
        Main workflow logic. Override this in subclasses.

        Args:
            *args: Workflow arguments
            **kwargs: Workflow keyword arguments

        Returns:
            Workflow result

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement the run method")

    async def step(
        self,
        name: str,
        func: Callable,
        *args,
        **kwargs,
    ) -> Any:
        """
        Execute a workflow step, retrying on failure.

        Coroutine functions are awaited directly; any other callable is run
        in the event loop's default executor. A failing step is retried up to
        ``config.max_retries`` times (default 3 when the config has no such
        attribute); after that the last exception is re-raised.

        Args:
            name: Step name
            func: Function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Step result

        Raises:
            Exception: The last exception raised by ``func`` once retries
                are exhausted.
        """
        # One Step record is shared across all attempts so that the retry
        # counter actually accumulates and the retry limit can be reached.
        step = Step(
            name=name,
            func=func,
            args=args,
            kwargs=kwargs,
        )
        self._steps[name] = step

        # Read the limit the same way the rest of the class reads config.
        max_retries = getattr(self.config, "max_retries", 3) or 0
        is_coroutine = inspect.iscoroutinefunction(func)

        while True:
            logger.info(f"Starting step '{name}' in workflow '{self.name}'")
            step.started_at = datetime.utcnow()

            try:
                if is_coroutine:
                    result = await func(*args, **kwargs)
                else:
                    # run_in_executor() takes positional arguments only, so
                    # bind args/kwargs in a closure before handing it over.
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(
                        None, lambda: func(*args, **kwargs)
                    )
            except Exception as e:
                step.error = e
                step.completed_at = datetime.utcnow()

                logger.error(f"Step '{name}' failed in workflow '{self.name}': {e}")

                if step.retries >= max_retries:
                    raise

                step.retries += 1
                logger.info(f"Retrying step '{name}' (attempt {step.retries})")
                continue

            # Success: record the outcome and clear any error from earlier attempts.
            step.result = result
            step.error = None
            step.completed_at = datetime.utcnow()

            # Checkpointing happens outside the try block so a checkpoint
            # failure cannot cause an already-successful step to run again.
            if self.config.checkpoint_on_step:
                await self._checkpoint()

            logger.info(f"Completed step '{name}' in workflow '{self.name}'")
            return result

    async def parallel(
        self,
        tasks: List[Tuple[Any, ...]],
    ) -> List[Any]:
        """
        Execute multiple steps in parallel.

        Tasks run in batches of ``config.max_parallel_steps``; a falsy or
        non-positive limit runs everything in a single batch.

        Args:
            tasks: List of (name, func, *args) tuples

        Returns:
            List of results in the same order as tasks

        Raises:
            ValueError: If a task tuple has fewer than two elements.
        """
        specs = []
        for task_spec in tasks:
            if len(task_spec) < 2:
                raise ValueError(
                    "Each parallel task must be a (name, func, *args) tuple"
                )
            name, func, *args = task_spec
            specs.append((name, func, args))

        if not specs:
            return []

        limit = self.config.max_parallel_steps
        batch_size = limit if limit and limit > 0 else len(specs)

        results: List[Any] = []
        for start in range(0, len(specs), batch_size):
            batch = specs[start:start + batch_size]
            # Coroutines are created per batch, so a failure in an earlier
            # batch does not leave later coroutines created but never awaited.
            batch_results = await asyncio.gather(
                *(self.step(name, func, *args) for name, func, args in batch)
            )
            results.extend(batch_results)
        return results

    async def call_agent(
        self,
        agent_name: str,
        message: str,
        **kwargs,
    ) -> Any:
        """
        Call an agent within the workflow.

        Args:
            agent_name: Name of the agent
            message: Message to send
            **kwargs: Additional arguments

        Returns:
            Agent response

        Raises:
            ValueError: If no agent is registered under ``agent_name``.
        """
        agent = self._agents.get(agent_name)
        if not agent:
            raise ValueError(f"Agent '{agent_name}' not found in workflow")

        return await agent.run(message, **kwargs)

    async def call_tool(
        self,
        tool_name: str,
        **kwargs,
    ) -> Any:
        """
        Call a tool within the workflow.

        Args:
            tool_name: Name of the tool
            **kwargs: Tool arguments

        Returns:
            Tool result

        Raises:
            ValueError: If no tool is registered under ``tool_name``.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found in workflow")

        return await tool.invoke(**kwargs)

    async def wait(self, seconds: float) -> None:
        """Wait for a specified duration."""
        await asyncio.sleep(seconds)

    @staticmethod
    def _snapshot_step(step: Step) -> Dict[str, Any]:
        """Return the serializable fields shared by checkpoints and saved state."""
        return {
            "result": step.result,
            "error": str(step.error) if step.error else None,
            "started_at": step.started_at.isoformat() if step.started_at else None,
            "completed_at": step.completed_at.isoformat() if step.completed_at else None,
            "retries": step.retries,
        }

    async def _checkpoint(self) -> None:
        """Save workflow state checkpoint on the in-memory execution context.

        Does nothing when there is no execution context.
        """
        if not self._execution_context:
            return

        checkpoint_data = {
            "steps": {
                name: self._snapshot_step(step)
                for name, step in self._steps.items()
            },
        }

        self._execution_context.checkpoint_data = checkpoint_data
        self._execution_context.last_checkpoint = datetime.utcnow()

        # TODO: Persist checkpoint to durable storage

    def get_step_result(self, name: str) -> Optional[Any]:
        """Get the result of a completed step.

        Returns ``None`` when the step is unknown or its result is ``None``.
        """
        step = self._steps.get(name)
        if step and step.result is not None:
            return step.result
        return None

    def get_execution_context(self) -> Optional[ExecutionContext]:
        """Get the current execution context."""
        return self._execution_context

    async def save_state(self) -> Dict[str, Any]:
        """Save workflow state for persistence.

        Returns:
            Dict with ``config``, ``execution_context`` (or ``None``) and a
            per-step summary under ``steps``. The config and context dicts
            are shallow copies, so mutating the returned mapping does not
            alter the live objects' attributes.
        """
        context = self._execution_context
        return {
            "config": dict(self.config.__dict__),
            "execution_context": dict(context.__dict__) if context else None,
            "steps": {
                name: {"name": step.name, **self._snapshot_step(step)}
                for name, step in self._steps.items()
            },
        }

    async def load_state(self, state: Dict[str, Any]) -> None:
        """Load workflow state from persistence.

        Only the execution context is restored, and only when
        ``state["execution_context"]`` is present and truthy.
        """
        # Restore execution context
        if state.get("execution_context"):
            ctx_data = state["execution_context"]
            self._execution_context = ExecutionContext(**ctx_data)

        # Note: Step results would need to be restored based on the workflow's
        # specific implementation and replay logic
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def workflow(
    func: Optional[Callable] = None,
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
    version: str = "1.0.0",
    enable_durability: bool = True,
) -> Union[Workflow, Callable]:
    """
    Decorator to create a workflow from a function.

    Usable bare (``@workflow``) or with keyword arguments
    (``@workflow(name="...")``). The decorated function is replaced by a
    ``Workflow`` instance whose ``run`` calls the function. Both coroutine
    functions and plain functions are accepted.

    Args:
        func: Function being decorated (set only in the bare form).
        name: Workflow name; defaults to the function's ``__name__``.
        description: Description; defaults to the function's docstring,
            then to ``"Workflow: <name>"``.
        version: Version string stored in the config.
        enable_durability: Stored in the config as ``enable_durability``.

    Returns:
        A ``Workflow`` instance (bare form) or a decorator producing one.

    Example:
        ```python
        @workflow
        async def process_order(order_id: str) -> Dict[str, Any]:
            '''Process an order through multiple stages.'''
            # Workflow logic
            return {"status": "completed", "order_id": order_id}
        ```
    """
    def decorator(f: Callable) -> Workflow:
        # Create a workflow class from the function
        class FunctionWorkflow(Workflow):
            async def run(self, *args, **kwargs):
                outcome = f(*args, **kwargs)
                # Await only when needed so synchronous functions work too.
                if inspect.isawaitable(outcome):
                    outcome = await outcome
                return outcome

        # Callables such as functools.partial have no __name__; fall back
        # to the type name instead of raising AttributeError.
        workflow_name = name or getattr(f, "__name__", type(f).__name__)
        workflow_desc = description or inspect.getdoc(f) or f"Workflow: {workflow_name}"

        config = WorkflowConfig(
            name=workflow_name,
            description=workflow_desc,
            version=version,
            enable_durability=enable_durability,
        )

        return FunctionWorkflow(name=workflow_name, config=config)

    if func is None:
        # Called with arguments: @workflow(name="...", ...)
        return decorator

    # Called without arguments: @workflow
    return decorator(func)
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
class WorkflowDurableContext:
    """
    Bridge between Workflow and DurableContext to provide workflow-specific APIs.

    Durable primitives (``call``, ``sleep``, ``get_object``, ``state``,
    ``execution_id``) are forwarded to the wrapped durable context; step,
    parallel, agent and tool calls are forwarded to the wrapped workflow.
    """

    def __init__(self, durable_ctx: "DurableContext", workflow: "Workflow"):
        """Wrap ``durable_ctx`` and ``workflow``.

        Also stores the durable context on the workflow as
        ``workflow._durable_context``.
        """
        self.durable_ctx = durable_ctx
        self.workflow = workflow
        self.workflow._durable_context = durable_ctx

    async def call(self, service: str, method: str, *args, **kwargs) -> Any:
        """Make a durable service call."""
        return await self.durable_ctx.call(service, method, *args, **kwargs)

    async def sleep(self, seconds: float) -> None:
        """Durable sleep."""
        await self.durable_ctx.sleep(seconds)

    async def get_object(self, object_class, object_id: str):
        """Get a durable object."""
        return await self.durable_ctx.get_object(object_class, object_id)

    @property
    def state(self):
        """Access durable state."""
        return self.durable_ctx.state

    @property
    def execution_id(self) -> str:
        """Get execution ID."""
        return self.durable_ctx.execution_id

    async def step(self, name: str, func: Callable, *args, **kwargs) -> Any:
        """Execute a workflow step by delegating to ``workflow.step``."""
        return await self.workflow.step(name, func, *args, **kwargs)

    # ``Tuple[str, Callable, ...]`` is not a valid typing form (an ellipsis is
    # only allowed as ``Tuple[T, ...]``), so the variadic tuple is typed loosely.
    async def parallel(self, tasks: List[Tuple[Any, ...]]) -> List[Any]:
        """Execute multiple steps in parallel by delegating to ``workflow.parallel``.

        Args:
            tasks: List of (name, func, *args) tuples.
        """
        return await self.workflow.parallel(tasks)

    async def call_agent(self, agent_name: str, message: str, **kwargs) -> Any:
        """Call an agent within the workflow."""
        return await self.workflow.call_agent(agent_name, message, **kwargs)

    async def call_tool(self, tool_name: str, **kwargs) -> Any:
        """Call a tool within the workflow."""
        return await self.workflow.call_tool(tool_name, **kwargs)
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
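# Helper written in method form (takes ``self``) for durable-context execution; NOTE: no assignment attaching it to Workflow is visible in this part of the file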
|
|
600
|
+
async def _run_with_durable_context(self, ctx: WorkflowDurableContext, *args, **kwargs) -> Any:
    """
    Run the workflow's ``run`` coroutine on behalf of a durable flow.

    ``ctx`` is accepted for signature compatibility but is not used here;
    the positional and keyword arguments are forwarded to ``self.run``
    unchanged and its result is returned.
    """
    outcome = await self.run(*args, **kwargs)
    return outcome
|
|
607
|
+
|
|
608
|
+
# Enhanced method for Workflow class durable context execution
|
|
609
|
+
async def _run_with_enhanced_context(self, ctx, *args, **kwargs) -> Any:
|
|
610
|
+
"""
|
|
611
|
+
Execute the workflow run method with an enhanced durable context.
|
|
612
|
+
|
|
613
|
+
This method provides the workflow with full access to durable primitives
|
|
614
|
+
and comprehensive state management.
|
|
615
|
+
"""
|
|
616
|
+
# Store the enhanced context for use in step methods
|
|
617
|
+
self._enhanced_context = ctx
|
|
618
|
+
|
|
619
|
+
try:
|
|
620
|
+
# Execute the workflow run method with enhanced context
|
|
621
|
+
result = await self.run(*args, **kwargs)
|
|
622
|
+
return result
|
|
623
|
+
finally:
|
|
624
|
+
# Clean up context reference
|
|
625
|
+
if hasattr(self, '_enhanced_context'):
|
|
626
|
+
delattr(self, '_enhanced_context')
|
|
627
|
+
|
|
628
|
+
# Monkey patch the enhanced method onto the Workflow class: the coroutine is
# defined at module level, and this assignment makes it callable as
# self._run_with_enhanced_context(...) on Workflow instances.
Workflow._run_with_enhanced_context = _run_with_enhanced_context
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
class EnhancedWorkflowDurableContext:
|
|
633
|
+
"""
|
|
634
|
+
Enhanced bridge between Workflow and DurableContext with advanced features.
|
|
635
|
+
|
|
636
|
+
This context provides:
|
|
637
|
+
- Comprehensive step tracking and state management
|
|
638
|
+
- Advanced parallel execution with concurrency control
|
|
639
|
+
- Integrated agent and tool management
|
|
640
|
+
- Automatic checkpointing and recovery
|
|
641
|
+
- Performance monitoring and metrics
|
|
642
|
+
- Error handling and retry coordination
|
|
643
|
+
"""
|
|
644
|
+
|
|
645
|
+
def __init__(self, durable_ctx: DurableContext, workflow):
|
|
646
|
+
self.durable_ctx = durable_ctx
|
|
647
|
+
self.workflow = workflow
|
|
648
|
+
self.workflow._durable_context = durable_ctx
|
|
649
|
+
self._step_counter = 0
|
|
650
|
+
self._parallel_execution_slots = getattr(workflow.config, 'max_parallel_steps', 5)
|
|
651
|
+
|
|
652
|
+
async def call(self, service: str, method: str, *args, **kwargs) -> Any:
|
|
653
|
+
"""Make a durable service call with workflow integration."""
|
|
654
|
+
# Track service call in workflow state
|
|
655
|
+
call_id = f"service_call_{self._step_counter}"
|
|
656
|
+
self._step_counter += 1
|
|
657
|
+
|
|
658
|
+
await self.durable_ctx.state.set(f"call_{call_id}", {
|
|
659
|
+
"service": service,
|
|
660
|
+
"method": method,
|
|
661
|
+
"args": args,
|
|
662
|
+
"kwargs": kwargs,
|
|
663
|
+
"started_at": datetime.utcnow().isoformat(),
|
|
664
|
+
})
|
|
665
|
+
|
|
666
|
+
try:
|
|
667
|
+
result = await self.durable_ctx.call(service, method, *args, **kwargs)
|
|
668
|
+
await self.durable_ctx.state.set(f"call_{call_id}_result", {
|
|
669
|
+
"result": result,
|
|
670
|
+
"completed_at": datetime.utcnow().isoformat(),
|
|
671
|
+
"status": "success",
|
|
672
|
+
})
|
|
673
|
+
return result
|
|
674
|
+
except Exception as e:
|
|
675
|
+
await self.durable_ctx.state.set(f"call_{call_id}_result", {
|
|
676
|
+
"error": str(e),
|
|
677
|
+
"failed_at": datetime.utcnow().isoformat(),
|
|
678
|
+
"status": "failed",
|
|
679
|
+
})
|
|
680
|
+
raise
|
|
681
|
+
|
|
682
|
+
async def sleep(self, seconds: float) -> None:
|
|
683
|
+
"""Durable sleep with workflow state tracking."""
|
|
684
|
+
sleep_id = f"sleep_{self._step_counter}"
|
|
685
|
+
self._step_counter += 1
|
|
686
|
+
|
|
687
|
+
await self.durable_ctx.state.set(f"sleep_{sleep_id}", {
|
|
688
|
+
"duration": seconds,
|
|
689
|
+
"started_at": datetime.utcnow().isoformat(),
|
|
690
|
+
})
|
|
691
|
+
|
|
692
|
+
await self.durable_ctx.sleep(seconds)
|
|
693
|
+
|
|
694
|
+
await self.durable_ctx.state.set(f"sleep_{sleep_id}_completed", {
|
|
695
|
+
"completed_at": datetime.utcnow().isoformat(),
|
|
696
|
+
})
|
|
697
|
+
|
|
698
|
+
async def get_object(self, object_class, object_id: str):
|
|
699
|
+
"""Get a durable object with workflow integration."""
|
|
700
|
+
# Track object access
|
|
701
|
+
access_id = f"object_access_{self._step_counter}"
|
|
702
|
+
self._step_counter += 1
|
|
703
|
+
|
|
704
|
+
await self.durable_ctx.state.set(f"object_{access_id}", {
|
|
705
|
+
"object_class": object_class.__name__,
|
|
706
|
+
"object_id": object_id,
|
|
707
|
+
"accessed_at": datetime.utcnow().isoformat(),
|
|
708
|
+
})
|
|
709
|
+
|
|
710
|
+
return await self.durable_ctx.get_object(object_class, object_id)
|
|
711
|
+
|
|
712
|
+
@property
|
|
713
|
+
def state(self):
|
|
714
|
+
"""Access durable state with workflow enhancements."""
|
|
715
|
+
return self.durable_ctx.state
|
|
716
|
+
|
|
717
|
+
@property
|
|
718
|
+
def execution_id(self) -> str:
|
|
719
|
+
"""Get execution ID."""
|
|
720
|
+
return self.durable_ctx.execution_id
|
|
721
|
+
|
|
722
|
+
async def step(self, name: str, func: Callable, *args, **kwargs) -> Any:
|
|
723
|
+
"""Execute a workflow step with comprehensive durable state tracking."""
|
|
724
|
+
step_start_time = datetime.utcnow()
|
|
725
|
+
|
|
726
|
+
# Update step tracking
|
|
727
|
+
current_step = await self.state.get("current_step", 0)
|
|
728
|
+
current_step += 1
|
|
729
|
+
await self.state.set("current_step", current_step)
|
|
730
|
+
|
|
731
|
+
# Store step metadata
|
|
732
|
+
step_metadata = {
|
|
733
|
+
"step_number": current_step,
|
|
734
|
+
"step_name": name,
|
|
735
|
+
"function_name": func.__name__ if hasattr(func, '__name__') else str(func),
|
|
736
|
+
"started_at": step_start_time.isoformat(),
|
|
737
|
+
"args": args,
|
|
738
|
+
"kwargs": kwargs,
|
|
739
|
+
}
|
|
740
|
+
await self.state.set(f"step_{current_step}_metadata", step_metadata)
|
|
741
|
+
|
|
742
|
+
try:
|
|
743
|
+
# Execute the step through the workflow
|
|
744
|
+
result = await self.workflow.step(name, func, *args, **kwargs)
|
|
745
|
+
|
|
746
|
+
# Update completion tracking
|
|
747
|
+
completed_steps = await self.state.get("completed_steps", [])
|
|
748
|
+
completed_steps.append(name)
|
|
749
|
+
await self.state.set("completed_steps", completed_steps)
|
|
750
|
+
|
|
751
|
+
# Store step result
|
|
752
|
+
step_results = await self.state.get("step_results", {})
|
|
753
|
+
step_results[name] = result
|
|
754
|
+
await self.state.set("step_results", step_results)
|
|
755
|
+
|
|
756
|
+
# Update step completion metadata
|
|
757
|
+
step_metadata["completed_at"] = datetime.utcnow().isoformat()
|
|
758
|
+
step_metadata["duration_seconds"] = (datetime.utcnow() - step_start_time).total_seconds()
|
|
759
|
+
step_metadata["status"] = "completed"
|
|
760
|
+
await self.state.set(f"step_{current_step}_metadata", step_metadata)
|
|
761
|
+
|
|
762
|
+
return result
|
|
763
|
+
|
|
764
|
+
except Exception as e:
|
|
765
|
+
# Track step failure
|
|
766
|
+
step_metadata["failed_at"] = datetime.utcnow().isoformat()
|
|
767
|
+
step_metadata["duration_seconds"] = (datetime.utcnow() - step_start_time).total_seconds()
|
|
768
|
+
step_metadata["status"] = "failed"
|
|
769
|
+
step_metadata["error"] = str(e)
|
|
770
|
+
await self.state.set(f"step_{current_step}_metadata", step_metadata)
|
|
771
|
+
raise
|
|
772
|
+
|
|
773
|
+
async def parallel(self, tasks: List[Tuple[Any, ...]]) -> List[Any]:
    """Run a batch of tasks through ``self.workflow.parallel`` and record the outcome.

    A metadata record is written to ``self.state`` before the batch starts
    and rewritten once it finishes or fails. The record is stored under the
    key ``parallel_<execution_id>``; because the execution id is itself
    ``parallel_<n>`` (``n`` being ``self._step_counter`` on entry), the key
    reads ``parallel_parallel_<n>``. That spelling is kept as-is so that
    previously stored state stays addressable.

    Args:
        tasks: Task tuples, passed to ``self.workflow.parallel`` unchanged.
            For the metadata record, element 0 is used as the task name,
            element 1 as the callable, and the remaining elements are
            counted as arguments. Shorter tuples are tolerated and get
            placeholder values in the record.

    Returns:
        The value returned by ``self.workflow.parallel(tasks)``.

    Raises:
        Exception: Any exception raised inside the tracked section is
            recorded on the metadata record (``status == "failed"``) and
            then re-raised unchanged.
    """
    parallel_execution_id = f"parallel_{self._step_counter}"
    self._step_counter += 1
    state_key = f"parallel_{parallel_execution_id}"

    # Keep the start time as a datetime so durations do not need to
    # re-parse the ISO string stored in the metadata.
    started_at = datetime.utcnow()

    task_summaries = []
    for index, task in enumerate(tasks):
        size = len(task)
        task_summaries.append({
            "name": task[0] if size > 0 else f"task_{index}",
            "function": (
                getattr(task[1], "__name__", "unknown") if size > 1 else "unknown"
            ),
            "args_count": max(size - 2, 0),
        })

    parallel_metadata = {
        "execution_id": parallel_execution_id,
        "task_count": len(tasks),
        "started_at": started_at.isoformat(),
        "tasks": task_summaries,
    }
    await self.state.set(state_key, parallel_metadata)

    try:
        # Scheduling is entirely up to the wrapped workflow; this method
        # only tracks the batch.
        results = await self.workflow.parallel(tasks)

        parallel_metadata["completed_at"] = datetime.utcnow().isoformat()
        parallel_metadata["duration_seconds"] = (
            datetime.utcnow() - started_at
        ).total_seconds()
        parallel_metadata["status"] = "completed"
        parallel_metadata["results_count"] = len(results)

        await self.state.set(state_key, parallel_metadata)
        return results

    except Exception as e:
        parallel_metadata["failed_at"] = datetime.utcnow().isoformat()
        parallel_metadata["duration_seconds"] = (
            datetime.utcnow() - started_at
        ).total_seconds()
        parallel_metadata["status"] = "failed"
        parallel_metadata["error"] = str(e)

        await self.state.set(state_key, parallel_metadata)
        raise
|
|
820
|
+
|
|
821
|
+
async def call_agent(self, agent_name: str, message: str, **kwargs) -> Any:
    """Invoke an agent through the wrapped workflow and record the call in state.

    Two state entries are written per call. The request entry (agent name,
    message, keyword arguments, start timestamp) is stored before the agent
    runs. The ``..._result`` entry holds either the result with a completion
    timestamp, or the error text with a failure timestamp. Both keys are
    derived from ``self._step_counter`` as it was on entry; the counter is
    incremented afterwards.

    Args:
        agent_name: Name of the agent, passed through to the workflow.
        message: Message passed through to the agent.
        **kwargs: Forwarded unchanged to ``self.workflow.call_agent`` and
            stored with the request entry.

    Returns:
        Whatever ``self.workflow.call_agent`` returns.

    Raises:
        Exception: Anything raised inside the tracked section is recorded
            in the result entry and re-raised unchanged.
    """
    call_id = f"agent_call_{self._step_counter}"
    self._step_counter += 1

    # The id already carries the "agent_call_" prefix, so the stored keys
    # read "agent_call_agent_call_<n>" and "..._result".
    request_key = f"agent_call_{call_id}"
    outcome_key = f"agent_call_{call_id}_result"

    request_record = {
        "agent_name": agent_name,
        "message": message,
        "kwargs": kwargs,
        "started_at": datetime.utcnow().isoformat(),
    }
    await self.state.set(request_key, request_record)

    try:
        result = await self.workflow.call_agent(agent_name, message, **kwargs)
        success_record = {
            "result": result,
            "completed_at": datetime.utcnow().isoformat(),
        }
        await self.state.set(outcome_key, success_record)
    except Exception as exc:
        failure_record = {
            "error": str(exc),
            "failed_at": datetime.utcnow().isoformat(),
        }
        await self.state.set(outcome_key, failure_record)
        raise
    return result
|
|
846
|
+
|
|
847
|
+
async def call_tool(self, tool_name: str, **kwargs) -> Any:
    """Invoke a tool through the wrapped workflow and record the call in state.

    The request entry (tool name, keyword arguments, start timestamp) is
    stored before the tool runs. A ``..._result`` entry is stored afterwards
    with either the result and a completion timestamp, or the error text
    and a failure timestamp. Both keys are derived from
    ``self._step_counter`` as it was on entry; the counter is incremented
    afterwards.

    Args:
        tool_name: Name of the tool, passed through to the workflow.
        **kwargs: Forwarded unchanged to ``self.workflow.call_tool`` and
            stored with the request entry.

    Returns:
        Whatever ``self.workflow.call_tool`` returns.

    Raises:
        Exception: Anything raised inside the tracked section is recorded
            in the result entry and re-raised unchanged.
    """
    call_id = f"tool_call_{self._step_counter}"
    self._step_counter += 1

    # The id already carries the "tool_call_" prefix, so the stored keys
    # read "tool_call_tool_call_<n>" and "..._result".
    request_key = f"tool_call_{call_id}"
    outcome_key = f"tool_call_{call_id}_result"

    request_record = {
        "tool_name": tool_name,
        "kwargs": kwargs,
        "started_at": datetime.utcnow().isoformat(),
    }
    await self.state.set(request_key, request_record)

    try:
        # Execution itself is delegated to the wrapped workflow.
        result = await self.workflow.call_tool(tool_name, **kwargs)
        success_record = {
            "result": result,
            "completed_at": datetime.utcnow().isoformat(),
        }
        await self.state.set(outcome_key, success_record)
    except Exception as exc:
        failure_record = {
            "error": str(exc),
            "failed_at": datetime.utcnow().isoformat(),
        }
        await self.state.set(outcome_key, failure_record)
        raise
    return result
|
|
872
|
+
|
|
873
|
+
async def get_workflow_progress(self) -> Dict[str, Any]:
    """Collect a snapshot of the workflow's tracked progress.

    Returns:
        A dict holding ``execution_id`` and ``workflow_name`` read from this
        context, followed by the values stored in ``self.state`` under
        ``current_step``, ``completed_steps``, ``workflow_status``,
        ``step_results`` and ``workflow_metadata``. A key missing from
        state falls back to ``0``, ``[]``, ``"unknown"``, ``{}`` and ``{}``
        respectively.
    """
    snapshot: Dict[str, Any] = {
        "execution_id": self.execution_id,
        "workflow_name": self.workflow.name,
    }

    # (state key, fallback) pairs, read one after another in this order.
    # The tuple is rebuilt on every call, so the mutable fallbacks are
    # never shared between calls.
    tracked_fields = (
        ("current_step", 0),
        ("completed_steps", []),
        ("workflow_status", "unknown"),
        ("step_results", {}),
        ("workflow_metadata", {}),
    )
    for field, fallback in tracked_fields:
        snapshot[field] = await self.state.get(field, fallback)

    return snapshot
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
class WorkflowDurableContext(EnhancedWorkflowDurableContext):
    """Legacy name for :class:`EnhancedWorkflowDurableContext`.

    This subclass adds and overrides nothing. It exists so that code written
    against the ``WorkflowDurableContext`` name keeps working while getting
    everything the enhanced context provides.
    """
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
# Monkey patch: bind `_run_with_durable_context` onto the Workflow class
# under the same attribute name.
setattr(Workflow, "_run_with_durable_context", _run_with_durable_context)
|