agnt5 0.3.2a1__cp310-abi3-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agnt5 might be problematic. Click here for more details.
- agnt5/__init__.py +119 -0
- agnt5/_compat.py +16 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +196 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_sentry.py +515 -0
- agnt5/_telemetry.py +279 -0
- agnt5/agent/__init__.py +48 -0
- agnt5/agent/context.py +581 -0
- agnt5/agent/core.py +1782 -0
- agnt5/agent/decorator.py +112 -0
- agnt5/agent/handoff.py +105 -0
- agnt5/agent/registry.py +68 -0
- agnt5/agent/result.py +39 -0
- agnt5/checkpoint.py +246 -0
- agnt5/client.py +1556 -0
- agnt5/context.py +288 -0
- agnt5/emit.py +197 -0
- agnt5/entity.py +1230 -0
- agnt5/events.py +567 -0
- agnt5/exceptions.py +110 -0
- agnt5/function.py +330 -0
- agnt5/journal.py +212 -0
- agnt5/lm.py +1266 -0
- agnt5/memoization.py +379 -0
- agnt5/memory.py +521 -0
- agnt5/tool.py +721 -0
- agnt5/tracing.py +300 -0
- agnt5/types.py +111 -0
- agnt5/version.py +19 -0
- agnt5/worker.py +2094 -0
- agnt5/workflow.py +1632 -0
- agnt5-0.3.2a1.dist-info/METADATA +26 -0
- agnt5-0.3.2a1.dist-info/RECORD +35 -0
- agnt5-0.3.2a1.dist-info/WHEEL +4 -0
agnt5/workflow.py
ADDED
|
@@ -0,0 +1,1632 @@
|
|
|
1
|
+
"""Workflow component implementation for AGNT5 SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import functools
|
|
7
|
+
import inspect
|
|
8
|
+
import logging
|
|
9
|
+
import time
|
|
10
|
+
import uuid
|
|
11
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast
|
|
12
|
+
|
|
13
|
+
from ._schema_utils import extract_function_metadata, extract_function_schemas
|
|
14
|
+
from .context import Context, set_current_context
|
|
15
|
+
from .entity import Entity, EntityState, _get_state_adapter
|
|
16
|
+
from .function import FunctionContext
|
|
17
|
+
from .types import HandlerFunc, WorkflowConfig
|
|
18
|
+
from ._telemetry import setup_module_logger
|
|
19
|
+
|
|
20
|
+
# Module-level logger, created through the SDK's telemetry helper.
logger = setup_module_logger(__name__)

# Generic result type used in the annotations of the WorkflowContext
# orchestration methods (step, parallel, gather).
T = TypeVar("T")

# Global workflow registry: str key -> WorkflowConfig.
# NOTE(review): keys are presumably workflow names registered by the @workflow
# decorator elsewhere in this module — confirm against the registration code.
_WORKFLOW_REGISTRY: Dict[str, WorkflowConfig] = {}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class WorkflowContext(Context):
|
|
29
|
+
"""
|
|
30
|
+
Context for durable workflows.
|
|
31
|
+
|
|
32
|
+
Extends base Context with:
|
|
33
|
+
- State management via WorkflowEntity.state
|
|
34
|
+
- Step tracking and replay
|
|
35
|
+
- Orchestration (task, parallel, gather)
|
|
36
|
+
- Checkpointing (step)
|
|
37
|
+
- Memory scoping (session_id, user_id for multi-level memory)
|
|
38
|
+
|
|
39
|
+
WorkflowContext delegates state to the underlying WorkflowEntity,
|
|
40
|
+
which provides durability and state change tracking for AI workflows.
|
|
41
|
+
|
|
42
|
+
Memory Scoping:
|
|
43
|
+
- run_id: Unique workflow run identifier
|
|
44
|
+
- session_id: For multi-turn conversations (optional)
|
|
45
|
+
- user_id: For user-scoped long-term memory (optional)
|
|
46
|
+
These identifiers enable agents to automatically select the appropriate
|
|
47
|
+
memory scope (run/session/user) via context propagation.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(
    self,
    workflow_entity: "WorkflowEntity",  # Forward reference
    run_id: str,
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    attempt: int = 0,
    runtime_context: Optional[Any] = None,
    checkpoint_callback: Optional[Callable[[dict], None]] = None,
    checkpoint_client: Optional[Any] = None,
    is_streaming: bool = False,
    tenant_id: Optional[str] = None,
    delta_callback: Optional[Callable[[str, str, int, int, int], None]] = None,
) -> None:
    """
    Set up a workflow context bound to a WorkflowEntity.

    Args:
        workflow_entity: WorkflowEntity that owns this run's durable state and step records
        run_id: Unique workflow run identifier
        session_id: Conversation/session identifier; falls back to run_id when falsy
        user_id: Identifier for user-scoped memory (optional)
        attempt: Retry attempt number (0-indexed)
        runtime_context: RuntimeContext used for trace correlation
        checkpoint_callback: Optional callback for real-time checkpoint delivery
        checkpoint_client: Optional CheckpointClient enabling platform-side memoization
        is_streaming: True for streaming requests (real-time SSE log delivery)
        tenant_id: Tenant identifier for multi-tenant deployments
        delta_callback: Optional sink for streaming events forwarded from nested
            components, called as
            (event_type, output_data, content_index, sequence, source_timestamp_ns)
    """
    # Settings shared by every context type are handled by the base Context.
    super().__init__(
        run_id=run_id,
        attempt=attempt,
        tenant_id=tenant_id,
        runtime_context=runtime_context,
        is_streaming=is_streaming,
        delta_callback=delta_callback,
        checkpoint_callback=checkpoint_callback,
    )

    # Durable state owner and optional platform memoization client.
    self._workflow_entity = workflow_entity
    self._checkpoint_client = checkpoint_client

    # Memory scoping ids (private attrs because the public properties are read-only).
    # Without a session id the session is the run itself (ephemeral).
    self._user_id = user_id
    self._session_id = session_id or run_id

    # Counters: step numbering, checkpoint sequence, and a separate delta-event sequence.
    self._step_counter: int = 0
    self._sequence_number: int = 0
    self._delta_sequence: int = 0

    # Event ids of the steps currently executing (innermost last); used to tag
    # checkpoints emitted from nested steps with their parent_event_id.
    self._step_event_stack: List[str] = []
|
|
103
|
+
|
|
104
|
+
# === State Management ===
|
|
105
|
+
|
|
106
|
+
def _forward_delta(self, event_type: str, output_data: str, content_index: int = 0, source_timestamp_ns: int = 0) -> None:
    """
    Forward a streaming delta event from a nested component.

    Used by step executors to forward events from streaming agents/functions
    to the client via the delta queue. Each forwarded event receives the next
    value of this context's delta sequence counter. Nothing happens (and the
    counter does not advance) when no delta callback is configured.

    Args:
        event_type: Event type (e.g., "agent.started", "lm.message.delta")
        output_data: JSON-serialized event data
        content_index: Content index for parallel events (default: 0)
        source_timestamp_ns: Nanosecond timestamp when the event was created.
            0 (or any falsy value) means "not provided", in which case the
            current time is stamped here.
    """
    if not self._delta_callback:
        return

    if not source_timestamp_ns:
        # Honour the documented contract: generate a timestamp instead of
        # forwarding a meaningless 0 to the client.
        source_timestamp_ns = time.time_ns()

    self._delta_callback(event_type, output_data, content_index, self._delta_sequence, source_timestamp_ns)
    # Advance only after the callback accepted the event.
    self._delta_sequence += 1
|
|
122
|
+
|
|
123
|
+
async def _consume_streaming_result(self, async_gen: Any, step_name: str) -> Any:
    """
    Drain an async generator, forwarding each item to the client as it arrives.

    Used when a step's handler (a streaming agent or function) returns an async
    generator. Every yielded item is forwarded through ``_forward_delta``, and the
    value handed to the next step is derived from what was yielded:

    - ``Event`` items are forwarded with their own event type; an
      ``agent.completed`` event supplies the result via ``data["output"]``.
    - Any other item is forwarded as a JSON-encoded ``output.delta`` event and
      collected.

    The generator is always closed (``aclose()``) on exit, so its cleanup code
    runs promptly even when forwarding fails part-way through.

    Args:
        async_gen: Async generator yielding Event objects or raw values
        step_name: Name of the current step (for logging)

    Returns:
        The output of the last ``agent.completed`` event if it was not None;
        otherwise the single collected raw item, a list of them when there were
        several, or None when no raw item was yielded.
    """
    import json
    from .events import Event, EventType

    final_result = None
    collected_output = []  # Raw (non-Event) items yielded by streaming functions

    try:
        async for item in async_gen:
            if isinstance(item, Event):
                event_data = item.to_response_fields()
                raw_output = event_data.get("output_data", b"")
                if isinstance(raw_output, bytes):
                    raw_output = raw_output.decode("utf-8")
                # Missing/empty payloads (b"", "", None) all become an empty JSON object.
                output_str = str(raw_output or "{}")

                self._forward_delta(
                    event_type=event_data.get("event_type", ""),
                    output_data=output_str,
                    content_index=event_data.get("content_index", 0),
                    source_timestamp_ns=item.source_timestamp_ns,
                )

                if item.event_type == EventType.AGENT_COMPLETED:
                    # For agents, the step result is the completed event's output.
                    final_result = item.data.get("output", "")
                    logger.debug(f"Step '{step_name}': Captured agent output from agent.completed")
                # NOTE(review): OUTPUT_STOP and typed output.delta Events are only
                # forwarded; their payloads are not added to collected_output.
                continue

            # Raw value: forward as output.delta (always valid JSON) and collect it.
            try:
                chunk_json = json.dumps(item)
            except (TypeError, ValueError):
                chunk_json = json.dumps(str(item))

            self._forward_delta(
                event_type="output.delta",
                output_data=chunk_json,
                source_timestamp_ns=time.time_ns(),
            )
            collected_output.append(item)
    finally:
        # Close deterministically; a no-op when the generator was fully consumed.
        aclose = getattr(async_gen, "aclose", None)
        if aclose is not None:
            await aclose()

    if final_result is not None:
        return final_result
    if not collected_output:
        return None
    # A single chunk is returned as-is; several chunks are returned as a list.
    return collected_output[0] if len(collected_output) == 1 else collected_output
|
|
201
|
+
|
|
202
|
+
def _send_checkpoint(self, checkpoint_type: str, checkpoint_data: dict) -> None:
    """
    Emit a checkpoint event through the context's unified ``emit`` API.

    Routing (streaming delta queue vs. buffered checkpoint queue) is decided by
    ``self.emit``. When a step is currently executing, the innermost step's
    event id is attached as ``parent_event_id`` on a copy of the payload; the
    caller's dict is never modified.

    Args:
        checkpoint_type: Type of checkpoint (e.g., "workflow.state.changed")
        checkpoint_data: Checkpoint payload (should include event_id if needed)
    """
    payload = checkpoint_data
    active_steps = self._step_event_stack
    if active_steps:
        # Nested call: link this checkpoint to the step that encloses it.
        payload = dict(checkpoint_data, parent_event_id=active_steps[-1])

    self.emit.emit(checkpoint_type, payload)
|
|
227
|
+
|
|
228
|
+
@property
def state(self):
    """
    Durable workflow state, delegated to the underlying WorkflowEntity.

    Each access re-attaches this context's ``_send_checkpoint`` to the state
    object (when it supports ``_set_checkpoint_callback``), so that state
    changes can be emitted as checkpoints in real time.

    Returns:
        The ``state`` object of the workflow entity.

    Example:
        ctx.state.set("status", "processing")
        status = ctx.state.get("status")
    """
    entity_state = self._workflow_entity.state
    supports_callback = hasattr(entity_state, "_set_checkpoint_callback")
    if supports_callback:
        entity_state._set_checkpoint_callback(self._send_checkpoint)
    return entity_state
|
|
245
|
+
|
|
246
|
+
# === Orchestration ===
|
|
247
|
+
|
|
248
|
+
async def step(
    self,
    name_or_handler: Union[str, Callable, Awaitable[T]],
    func_or_awaitable: Union[Callable[..., Awaitable[T]], Awaitable[T], Any] = None,
    *args: Any,
    **kwargs: Any,
) -> T:
    """
    Run a durable step whose result is checkpointed.

    Completed steps are replayed from their recorded result instead of being
    re-executed when a workflow restarts. The first argument selects the mode:

    1. **A @function reference (recommended)** — name derived from the function:
       ```python
       result = await ctx.step(process_data, arg1, arg2, kwarg=value)
       ```

    2. **An explicit name plus an awaitable**:
       ```python
       result = await ctx.step("load_data", fetch_expensive_data())
       ```

    3. **An explicit name plus a callable and its arguments**:
       ```python
       result = await ctx.step("compute", my_function, arg1, arg2)
       ```

    4. **Legacy: the name of a registered @function**:
       ```python
       result = await ctx.step("function_name", input=data)
       ```

    A bare awaitable is also accepted; its ``__name__`` (or "awaitable") is
    used as the step name.

    Args:
        name_or_handler: @function reference, step name (str), or awaitable
        func_or_awaitable: Callable/awaitable when a name is given; otherwise
            the first positional argument for the @function
        *args: Further positional arguments
        **kwargs: Keyword arguments for the function

    Returns:
        The step result (the recorded result on replay)

    Raises:
        ValueError: When the first argument is an undecorated callable, an
            unknown function name without a second argument, or any other
            unsupported type.

    Example (@function call):
        ```python
        @function
        async def process_data(ctx: FunctionContext, data: list, multiplier: int = 2):
            return [x * multiplier for x in data]

        @workflow
        async def my_workflow(ctx: WorkflowContext):
            result = await ctx.step(process_data, [1, 2, 3], multiplier=3)
            return result
        ```

    Example (checkpoint awaitable):
        ```python
        @workflow
        async def my_workflow(ctx: WorkflowContext):
            # Checkpoint expensive external call
            data = await ctx.step("fetch_api", fetch_from_external_api())
            return data
        ```
    """
    # Pattern 1: decorated @function reference.
    is_decorated = callable(name_or_handler) and hasattr(name_or_handler, "_agnt5_config")
    if is_decorated:
        return await self._step_function(name_or_handler, func_or_awaitable, *args, **kwargs)

    if isinstance(name_or_handler, str):
        from .function import FunctionRegistry

        # Pattern 4: the string names a registered @function.
        if FunctionRegistry.get(name_or_handler) is not None:
            return await self._step_function(name_or_handler, func_or_awaitable, *args, **kwargs)

        # A name alone (not registered) has nothing to execute.
        if func_or_awaitable is None:
            raise ValueError(
                f"Function '{name_or_handler}' not found in registry. "
                f"Either register it with @function decorator, or use "
                f"ctx.step('{name_or_handler}', awaitable) to checkpoint an arbitrary operation."
            )

        # Patterns 2/3: named checkpoint of an awaitable or callable.
        return await self._step_checkpoint(name_or_handler, func_or_awaitable, *args, **kwargs)

    if inspect.iscoroutine(name_or_handler) or inspect.isawaitable(name_or_handler):
        # Bare awaitable: derive the step name from the object.
        label = getattr(name_or_handler, '__name__', 'awaitable')
        return await self._step_checkpoint(label, name_or_handler)

    if callable(name_or_handler):
        raise ValueError(
            f"Function '{name_or_handler.__name__}' is not a registered @function. "
            f"Did you forget to add the @function decorator? "
            f"Or use ctx.step('name', callable) for non-decorated functions."
        )

    raise ValueError(
        f"step() first argument must be a @function, string name, or awaitable. "
        f"Got: {type(name_or_handler)}"
    )
|
|
354
|
+
|
|
355
|
+
async def _step_function(
    self,
    handler: Union[str, Callable],
    first_arg: Any = None,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Internal: Execute a @function as a durable step.

    Handles both function references and legacy string-based calls. The step
    is keyed as ``"<handler_name>_<counter>"``; if the WorkflowEntity already
    holds a completed result for that key, it is returned without re-executing.

    Emits ``workflow.step.started`` and then either ``workflow.step.completed``
    or ``workflow.step.failed``. While the function runs, this step's event id
    is on ``_step_event_stack``, so nested steps are tagged with it as their
    parent. Every push is matched by exactly one pop.

    Args:
        handler: A @function-decorated callable or a registered function name
        first_arg: First positional argument, as split out by ``step()``
        *args: Remaining positional arguments for the function
        **kwargs: Keyword arguments for the function. With no positional args,
            a lone ``input=`` kwarg is passed positionally (legacy pattern).

    Returns:
        The function's result (the recorded result on replay)

    Raises:
        ValueError: If the handler is not a registered @function
        Exception: Anything raised by the function is re-raised after the
            ``workflow.step.failed`` checkpoint is emitted
    """
    import json

    from .function import FunctionRegistry
    from .tracing import create_span

    # step() always forwards its second parameter here, so first_arg is None both
    # when nothing was passed and when None was passed explicitly. Only treat it as
    # absent when no further positional args follow; otherwise an explicit None
    # would be dropped and every later argument would shift one position left.
    if first_arg is not None or args:
        args = (first_arg,) + args

    # Extract handler name from function reference or use string
    if callable(handler):
        handler_name = handler.__name__
        if not hasattr(handler, "_agnt5_config"):
            raise ValueError(
                f"Function '{handler_name}' is not a registered @function. "
                f"Did you forget to add the @function decorator?"
            )
    else:
        handler_name = handler

    # Generate unique step name for durability
    step_name = f"{handler_name}_{self._step_counter}"
    self._step_counter += 1

    # Generate unique event_id for this step (for hierarchy tracking)
    step_event_id = str(uuid.uuid4())

    # Check if step already completed (for replay)
    if self._workflow_entity.has_completed_step(step_name):
        result = self._workflow_entity.get_completed_step(step_name)
        self._logger.info(f"🔄 Replaying cached step: {step_name}")
        return result

    # Resolve the function before emitting or pushing anything, so that an unknown
    # name cannot leave a dangling id on the step event stack.
    func_config = FunctionRegistry.get(handler_name)
    if func_config is None:
        raise ValueError(f"Function '{handler_name}' not found in registry")

    # Input as reported in checkpoints and the step record. kwargs is never mutated
    # below, so the legacy ``input=`` value survives in every record.
    step_input = args or kwargs

    # Span attribute only: tracing must never fail the step, so non-JSON
    # arguments fall back to repr().
    if args or kwargs:
        try:
            input_repr = json.dumps({"args": args, "kwargs": kwargs})
        except (TypeError, ValueError):
            input_repr = repr({"args": args, "kwargs": kwargs})
    else:
        input_repr = "{}"

    def pop_own_event(where: str) -> None:
        # Pop this step's id, warning if nested bookkeeping went out of sync.
        if self._step_event_stack:
            popped_id = self._step_event_stack.pop()
            if popped_id != step_event_id:
                self._logger.warning(
                    f"Step event stack mismatch in {where}: expected {step_event_id}, got {popped_id}"
                )

    # Emit workflow.step.started checkpoint (parent = enclosing step, if any)
    self._send_checkpoint(
        "workflow.step.started",
        {
            "step_name": step_name,
            "handler_name": handler_name,
            "input": step_input,
            "event_id": step_event_id,  # Include for hierarchy tracking
        },
    )

    self._logger.info(f"▶️ Executing new step: {step_name}")

    # Create span for task execution (contextvar handles parent-child linking)
    with create_span(
        f"workflow.task.{handler_name}",
        "function",
        self._runtime_context,
        {
            "step_name": step_name,
            "handler_name": handler_name,
            "run_id": self.run_id,
            "input.data": input_repr,
        },
    ) as span:
        func_ctx = FunctionContext(
            run_id=f"{self.run_id}:task:{handler_name}",
            runtime_context=self._runtime_context,
        )

        # Push right before the try so every push is matched by exactly one pop.
        self._step_event_stack.append(step_event_id)
        try:
            if not args and "input" in kwargs:
                # Legacy pattern: ctx.step("func_name", input=data) — pass input positionally
                call_kwargs = dict(kwargs)
                input_data = call_kwargs.pop("input")
                handler_result = func_config.handler(func_ctx, input_data, **call_kwargs)
            else:
                # Type-safe pattern - pass all args/kwargs
                handler_result = func_config.handler(func_ctx, *args, **kwargs)

            # Streaming functions/agents return async generators: consume them while
            # forwarding their events; plain coroutines are awaited.
            if inspect.isasyncgen(handler_result):
                result = await self._consume_streaming_result(handler_result, step_name)
            elif inspect.iscoroutine(handler_result):
                result = await handler_result
            else:
                result = handler_result

            # Add output data to span
            try:
                output_repr = json.dumps(result)
                span.set_attribute("output.data", output_repr)
            except (TypeError, ValueError):
                # If result is not JSON serializable, use repr
                span.set_attribute("output.data", repr(result))

            # Record step completion in WorkflowEntity
            self._workflow_entity.record_step_completion(step_name, handler_name, step_input, result)
        except Exception as e:
            pop_own_event("task() error path")

            self._send_checkpoint(
                "workflow.step.failed",
                {
                    "step_name": step_name,
                    "handler_name": handler_name,
                    "input": step_input,
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "event_id": step_event_id,  # Include for consistency
                },
            )

            # Record error in span
            span.set_attribute("error", "true")
            span.set_attribute("error.message", str(e))
            span.set_attribute("error.type", type(e).__name__)

            # Re-raise to propagate failure
            raise

        # Success: pop exactly once, outside the try, so a failure while emitting
        # the completion event cannot pop the enclosing step's id as well.
        pop_own_event("task()")

        self._send_checkpoint(
            "workflow.step.completed",
            {
                "step_name": step_name,
                "handler_name": handler_name,
                "input": step_input,
                "result": result,
                "event_id": step_event_id,  # Include for consistency
            },
        )

        return result
|
|
526
|
+
|
|
527
|
+
async def parallel(self, *tasks: Awaitable[T]) -> List[T]:
    """
    Await several awaitables concurrently and collect their results.

    Args:
        *tasks: Awaitables to run concurrently

    Returns:
        A list of results, positionally matching ``tasks``

    Example:
        result1, result2 = await ctx.parallel(
            fetch_data(source1),
            fetch_data(source2)
        )
    """
    outcomes = await asyncio.gather(*tasks)
    return list(outcomes)
|
|
546
|
+
|
|
547
|
+
async def gather(self, **tasks: Awaitable[T]) -> Dict[str, T]:
    """
    Await keyword-named awaitables concurrently and map each name to its result.

    Args:
        **tasks: Awaitables to run concurrently, keyed by the name to report

    Returns:
        Dictionary from each keyword name to its result, in argument order

    Example:
        results = await ctx.gather(
            db=query_database(),
            api=fetch_api()
        )
    """
    names = list(tasks)
    outcomes = await asyncio.gather(*tasks.values())
    return {name: outcome for name, outcome in zip(names, outcomes)}
|
|
569
|
+
|
|
570
|
+
async def task(
    self,
    handler: Union[str, Callable],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """
    Deprecated alias that runs ``handler`` through :meth:`step`.

    .. deprecated::
        Use :meth:`step` instead. ``task()`` will be removed in a future version.

    Emits a DeprecationWarning, then forwards all arguments unchanged to
    ``self.step``.

    Args:
        handler: A @function reference or its registered name
        *args: Positional arguments for the function
        **kwargs: Keyword arguments for the function

    Returns:
        Whatever ``self.step`` returns for the same arguments
    """
    import warnings

    deprecation_message = (
        "ctx.task() is deprecated, use ctx.step() instead. "
        "task() will be removed in a future version."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return await self.step(handler, *args, **kwargs)
|
|
602
|
+
|
|
603
|
+
async def _step_checkpoint(
    self,
    name: str,
    func_or_awaitable: Union[Callable[..., Awaitable[T]], Awaitable[T]],
    *args: Any,
    **kwargs: Any,
) -> T:
    """
    Internal: Checkpoint an arbitrary awaitable or callable for durability.

    Replay is checked first on the platform (when a CheckpointClient is
    configured, via ``step_started`` with key ``"step:<name>:<counter>"``) and
    then in the local WorkflowEntity (keyed by ``name``). On replay the
    operation is skipped; a passed coroutine object is closed so it is not left
    un-awaited.

    Otherwise the operation is executed, recorded locally and (best effort) on
    the platform, and ``workflow.step.started`` is followed by exactly one of
    ``workflow.step.completed`` / ``workflow.step.failed``.

    Args:
        name: Name for this checkpoint (also the local replay key)
        func_or_awaitable: An async generator, an awaitable, or a callable
            returning one of those (or a plain value)
        *args: Arguments to pass if func_or_awaitable is callable
        **kwargs: Keyword arguments to pass if func_or_awaitable is callable

    Returns:
        The result of the function/awaitable (the recorded result on replay)

    Raises:
        ValueError: If func_or_awaitable is neither awaitable nor callable
        Exception: Anything raised by the operation, after the failure has
            been reported
    """
    import inspect
    import json
    import time

    # Generate step key for platform memoization
    step_key = f"step:{name}:{self._step_counter}"
    self._step_counter += 1

    # Generate unique event_id for this step (for hierarchy tracking)
    step_event_id = str(uuid.uuid4())

    replayed = False
    cached_value = None

    # Check platform-side memoization first
    if self._checkpoint_client:
        try:
            memo = await self._checkpoint_client.step_started(
                self.run_id,
                step_key,
                name,
                "checkpoint",
            )
            if memo.memoized and memo.cached_output:
                cached_value = json.loads(memo.cached_output.decode("utf-8"))
                self._logger.info(f"🔄 Replaying memoized step from platform: {name}")
                # Also record locally for consistency
                self._workflow_entity.record_step_completion(name, "checkpoint", None, cached_value)
                replayed = True
        except Exception as e:
            self._logger.warning(f"Platform memoization check failed, falling back to local: {e}")

    # Fall back to local memoization (for backward compatibility).
    # NOTE(review): the local key is `name` alone (unlike the counter-suffixed
    # platform key), so reusing a name within one run replays the first result.
    if not replayed and self._workflow_entity.has_completed_step(name):
        cached_value = self._workflow_entity.get_completed_step(name)
        self._logger.info(f"🔄 Replaying checkpoint from local cache: {name}")
        replayed = True

    if replayed:
        # The operation is skipped; close a not-yet-started coroutine so it is not
        # reported as "never awaited" and its frame is released.
        if inspect.iscoroutine(func_or_awaitable):
            func_or_awaitable.close()
        return cached_value

    def pop_own_event(where: str) -> None:
        # Pop this step's id, warning if nested bookkeeping went out of sync.
        if self._step_event_stack:
            popped_id = self._step_event_stack.pop()
            if popped_id != step_event_id:
                self._logger.warning(
                    f"Step event stack mismatch in {where}: expected {step_event_id}, got {popped_id}"
                )

    # Emit workflow.step.started checkpoint for observability
    self._send_checkpoint(
        "workflow.step.started",
        {
            "step_name": name,
            "handler_name": "checkpoint",
            "event_id": step_event_id,  # Include for hierarchy tracking
        },
    )

    # Push this step's event_id onto the stack for nested calls
    self._step_event_stack.append(step_event_id)

    # Monotonic clock: latency must not be skewed by wall-clock adjustments.
    start_time = time.monotonic()
    try:
        if inspect.isasyncgen(func_or_awaitable):
            # Direct async generator - consume while forwarding events
            result = await self._consume_streaming_result(func_or_awaitable, name)
        elif inspect.iscoroutine(func_or_awaitable) or inspect.isawaitable(func_or_awaitable):
            result = await func_or_awaitable
        elif callable(func_or_awaitable):
            call_result = func_or_awaitable(*args, **kwargs)
            if inspect.isasyncgen(call_result):
                # Callable returned async generator - consume while forwarding events
                result = await self._consume_streaming_result(call_result, name)
            elif inspect.iscoroutine(call_result) or inspect.isawaitable(call_result):
                result = await call_result
            else:
                result = call_result
        else:
            raise ValueError(f"step() second argument must be awaitable or callable, got {type(func_or_awaitable)}")

        latency_ms = int((time.monotonic() - start_time) * 1000)

        # Record step completion locally for in-memory replay
        self._workflow_entity.record_step_completion(name, "checkpoint", None, result)

        # Record to platform for persistent memoization (best effort)
        if self._checkpoint_client:
            try:
                output_bytes = json.dumps(result).encode("utf-8")
                await self._checkpoint_client.step_completed(
                    self.run_id,
                    step_key,
                    name,
                    "checkpoint",
                    output_bytes,
                    latency_ms,
                )
            except Exception as e:
                self._logger.warning(f"Failed to record step completion to platform: {e}")
    except Exception as e:
        pop_own_event("step() error path")

        # Record failure to platform (best effort)
        if self._checkpoint_client:
            try:
                await self._checkpoint_client.step_failed(
                    self.run_id,
                    step_key,
                    name,
                    "checkpoint",
                    str(e),
                    type(e).__name__,
                )
            except Exception as cp_err:
                self._logger.warning(f"Failed to record step failure to platform: {cp_err}")

        self._send_checkpoint(
            "workflow.step.failed",
            {
                "step_name": name,
                "handler_name": "checkpoint",
                "error": str(e),
                "error_type": type(e).__name__,
                "event_id": step_event_id,  # Include for consistency
            },
        )
        raise

    # Success: pop exactly once, outside the try, so a failure while emitting the
    # completion event neither pops the enclosing step's id nor reports a failure
    # for a step the platform already recorded as completed.
    pop_own_event("step()")

    # Emit workflow.step.completed checkpoint to journal for crash recovery
    self._send_checkpoint(
        "workflow.step.completed",
        {
            "step_name": name,
            "handler_name": "checkpoint",
            "result": result,
            "event_id": step_event_id,  # Include for consistency
        },
    )

    self._logger.info(f"✅ Checkpoint completed: {name} ({latency_ms}ms)")
    return result
|
|
776
|
+
|
|
777
|
+
async def sleep(self, seconds: float, name: Optional[str] = None) -> None:
    """
    Sleep for ``seconds`` in a way that is checkpointed on the workflow entity.

    On the first execution the sleep's start time and requested duration are
    recorded under the step key ``sleep:<name>`` before sleeping. If that key
    is already present (replay after a restart), only the part of the
    recorded duration that has not yet elapsed is slept. If the whole
    duration has already passed, nothing is slept at all. Unlike a bare
    ``asyncio.sleep()``, a restart therefore does not restart the timer.

    Args:
        seconds: Duration to sleep in seconds.
        name: Name for the sleep checkpoint. When omitted (or empty),
            ``sleep_<n>`` is used, where ``n`` is the current step counter.

    Example:
        ```python
        @workflow
        async def delayed_notification(ctx: WorkflowContext, user_id: str):
            # Send immediate acknowledgment
            await ctx.step(send_ack, user_id)

            # Wait 24 hours (survives restarts!)
            await ctx.sleep(24 * 60 * 60, name="wait_24h")

            # Send follow-up
            await ctx.step(send_followup, user_id)
        ```
    """
    entity = self._workflow_entity

    # Every sleep consumes a step-counter slot, whether or not it is named.
    sleep_name = name if name else f"sleep_{self._step_counter}"
    self._step_counter += 1
    step_key = f"sleep:{sleep_name}"

    if entity.has_completed_step(step_key):
        # Replay: an earlier attempt already recorded when this sleep began.
        previous = entity.get_completed_step(step_key)
        began_at = previous.get("start_time", 0)
        total = previous.get("duration", seconds)
        elapsed = time.time() - began_at

        if elapsed >= total:
            self._logger.info(f"🔄 Sleep '{sleep_name}' already completed (elapsed: {elapsed:.1f}s)")
            return

        remaining = total - elapsed
        self._logger.info(f"⏰ Resuming sleep '{sleep_name}': {remaining:.1f}s remaining")
        await asyncio.sleep(remaining)
        return

    # First execution: persist the start time *before* sleeping so a crash
    # mid-sleep can resume with the remaining duration.
    entity.record_step_completion(
        step_key,
        "sleep",
        None,
        {"start_time": time.time(), "duration": seconds},
    )

    event_id = str(uuid.uuid4())

    def checkpoint_payload():
        # A fresh dict per event, so the started/completed events never share state.
        return {
            "step_name": sleep_name,
            "handler_name": "sleep",
            "duration_seconds": seconds,
            "event_id": event_id,
        }

    self._send_checkpoint("workflow.step.started", checkpoint_payload())

    self._logger.info(f"💤 Starting durable sleep '{sleep_name}': {seconds}s")
    await asyncio.sleep(seconds)

    self._send_checkpoint("workflow.step.completed", checkpoint_payload())
    self._logger.info(f"⏰ Sleep '{sleep_name}' completed")
|
|
861
|
+
|
|
862
|
+
async def wait_for_user(
    self, question: str, input_type: str = "text", options: Optional[List[Dict]] = None
) -> str:
    """
    Pause the workflow until the user answers ``question``.

    Each call claims the next pause slot (``_pause_index`` on the workflow
    entity) and looks for a stored answer under
    ``user_response:<run_id>:<pause_index>``. If an answer is present (for
    example one injected with ``WorkflowEntity.inject_user_response`` before a
    replay), it is returned immediately. Otherwise the current state snapshot
    is captured and ``WaitingForUserInputException`` is raised to suspend the
    run. Giving each call its own slot lets one workflow ask several questions
    in sequence.

    Args:
        question: Question to ask the user.
        input_type: Type of input - "text", "approval", or "choice".
        options: For approval/choice, list of option dicts with 'id' and 'label'.

    Returns:
        The stored user response for this pause slot.

    Raises:
        WaitingForUserInputException: When no stored response exists yet for
            this pause slot.

    Example (text input):
        ```python
        city = await ctx.wait_for_user("Which city?")
        ```

    Example (approval):
        ```python
        decision = await ctx.wait_for_user(
            "Approve this action?",
            input_type="approval",
            options=[
                {"id": "approve", "label": "Approve"},
                {"id": "reject", "label": "Reject"}
            ]
        )
        ```

    Example (choice):
        ```python
        model = await ctx.wait_for_user(
            "Which model?",
            input_type="choice",
            options=[
                {"id": "gpt4", "label": "GPT-4"},
                {"id": "claude", "label": "Claude"}
            ]
        )
        ```
    """
    from .exceptions import WaitingForUserInputException

    entity = self._workflow_entity

    # Claim this call's pause slot, then advance the counter unconditionally.
    # The next wait_for_user() gets a distinct slot whether this one replays
    # or pauses.
    pause_index = entity._pause_index
    response_key = f"user_response:{self.run_id}:{pause_index}"
    entity._pause_index = pause_index + 1

    if entity.has_completed_step(response_key):
        stored_response = entity.get_completed_step(response_key)
        self._logger.info(f"🔄 Replaying user response from checkpoint (pause {pause_index})")
        return stored_response

    # No answer yet: snapshot the current state (if any) so it travels with the pause.
    current_state = getattr(entity, "_state", None)
    checkpoint_state = current_state.get_state_snapshot() if current_state is not None else {}

    self._logger.info(f"⏸️ Pausing workflow for user input: {question}")

    raise WaitingForUserInputException(
        question=question,
        input_type=input_type,
        options=options,
        checkpoint_state=checkpoint_state,
        pause_index=pause_index,  # identifies which pause the eventual answer belongs to
    )
|
|
944
|
+
|
|
945
|
+
|
|
946
|
+
# ============================================================================
|
|
947
|
+
# Helper functions for workflow execution
|
|
948
|
+
# ============================================================================
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
def _sanitize_for_json(obj: Any) -> Any:
|
|
952
|
+
"""
|
|
953
|
+
Sanitize data for JSON serialization by removing or converting non-serializable objects.
|
|
954
|
+
|
|
955
|
+
Specifically handles:
|
|
956
|
+
- WorkflowContext objects (replaced with placeholder)
|
|
957
|
+
- Nested structures (recursively sanitized)
|
|
958
|
+
|
|
959
|
+
Args:
|
|
960
|
+
obj: Object to sanitize
|
|
961
|
+
|
|
962
|
+
Returns:
|
|
963
|
+
JSON-serializable version of the object
|
|
964
|
+
"""
|
|
965
|
+
# Handle None, primitives
|
|
966
|
+
if obj is None or isinstance(obj, (str, int, float, bool)):
|
|
967
|
+
return obj
|
|
968
|
+
|
|
969
|
+
# Handle WorkflowContext - replace with placeholder
|
|
970
|
+
if isinstance(obj, WorkflowContext):
|
|
971
|
+
return "<WorkflowContext>"
|
|
972
|
+
|
|
973
|
+
# Handle tuples/lists - recursively sanitize
|
|
974
|
+
if isinstance(obj, (tuple, list)):
|
|
975
|
+
sanitized = [_sanitize_for_json(item) for item in obj]
|
|
976
|
+
return sanitized if isinstance(obj, list) else tuple(sanitized)
|
|
977
|
+
|
|
978
|
+
# Handle dicts - recursively sanitize values
|
|
979
|
+
if isinstance(obj, dict):
|
|
980
|
+
return {k: _sanitize_for_json(v) for k, v in obj.items()}
|
|
981
|
+
|
|
982
|
+
# For other objects, try to serialize or convert to string
|
|
983
|
+
try:
|
|
984
|
+
import json
|
|
985
|
+
json.dumps(obj)
|
|
986
|
+
return obj
|
|
987
|
+
except (TypeError, ValueError):
|
|
988
|
+
# Not JSON serializable, use string representation
|
|
989
|
+
return repr(obj)
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
# ============================================================================
|
|
993
|
+
# WorkflowEntity: Entity specialized for workflow execution state
|
|
994
|
+
# ============================================================================
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
class WorkflowEntity(Entity):
    """
    Entity specialized for workflow execution state.

    Extends Entity with workflow-specific capabilities:
    - Step tracking for replay and crash recovery
    - State change tracking for debugging and audit (AI workflows)
    - Completed step cache for efficient replay
    - State persistence via ``_persist_state()``

    The entity key, and therefore where state is stored, depends on the memory
    scope chosen in ``__init__``: user, session, or run. The workflow
    decorator is expected to call ``_persist_state()`` after execution so that
    state survives crashes and is visible to later invocations.
    """

    def __init__(
        self,
        run_id: str,
        session_id: Optional[str] = None,
        user_id: Optional[str] = None,
        component_name: Optional[str] = None,
    ):
        """
        Initialize workflow entity with memory scope.

        Args:
            run_id: Unique workflow run identifier
            session_id: Session identifier for multi-turn conversations (optional)
            user_id: User identifier for user-scoped memory (optional)
            component_name: Workflow component name for session-scoped entities (optional)

        Memory Scope Priority:
            - user_id present → key: user:{user_id} (shared across workflows)
            - session_id present (and != run_id) → key: workflow:{component_name}:session:{session_id}
            - else → key: run:{run_id}

        Note: For session scope, component_name enables listing sessions by workflow name.
        User scope is shared across all workflows (not per-workflow).
        """
        # Determine entity key based on memory scope priority
        if user_id:
            # User scope: shared across all workflows (not per-workflow)
            entity_key = f"user:{user_id}"
            memory_scope = "user"
        elif session_id and session_id != run_id:
            # Session scope: include workflow name for queryability
            if component_name:
                entity_key = f"workflow:{component_name}:session:{session_id}"
            else:
                # Fallback for backward compatibility
                entity_key = f"session:{session_id}"
            memory_scope = "session"
        else:
            entity_key = f"run:{run_id}"
            memory_scope = "run"

        # Initialize as entity with scoped key pattern
        super().__init__(key=entity_key)

        # Store run_id separately for tracking (even if key is session/user scoped)
        self._run_id = run_id
        self._memory_scope = memory_scope
        self._component_name = component_name
        # Scope identifiers, used by _persist_state() to pick scope_id
        self._session_id = session_id
        self._user_id = user_id

        # Step tracking for replay and recovery
        self._step_events: list[Dict[str, Any]] = []  # JSON-sanitized step records
        self._completed_steps: Dict[str, Any] = {}  # step name -> raw (unsanitized) result

        # HITL pause tracking - each wait_for_user gets unique index
        self._pause_index: int = 0

        # State change tracking for debugging/audit (appended to by WorkflowState)
        self._state_changes: list[Dict[str, Any]] = []

        logger.debug(f"Created WorkflowEntity: run={run_id}, scope={memory_scope}, key={entity_key}, component={component_name}")

    @property
    def run_id(self) -> str:
        """Get run_id for this workflow execution."""
        return self._run_id

    def record_step_completion(
        self, step_name: str, handler_name: str, input_data: Any, result: Any
    ) -> None:
        """
        Record completed step for replay and recovery.

        The event appended to ``_step_events`` holds JSON-sanitized copies of
        the input and result (WorkflowContext objects and other
        non-serializable values are replaced). The replay cache
        ``_completed_steps`` keeps the original, unsanitized result so that
        replay returns exactly what the step produced.

        Args:
            step_name: Unique step identifier
            handler_name: Function handler name
            input_data: Input data passed to function
            result: Function result
        """
        sanitized_input = _sanitize_for_json(input_data)
        sanitized_result = _sanitize_for_json(result)

        self._step_events.append(
            {
                "step_name": step_name,
                "handler_name": handler_name,
                "input": sanitized_input,
                "result": sanitized_result,
            }
        )
        self._completed_steps[step_name] = result
        logger.debug(f"Recorded step completion: {step_name}")

    def get_completed_step(self, step_name: str) -> Optional[Any]:
        """
        Get result of completed step (for replay).

        Args:
            step_name: Step identifier

        Returns:
            Step result if completed, None otherwise. A step that legitimately
            returned None is indistinguishable here; use has_completed_step().
        """
        return self._completed_steps.get(step_name)

    def has_completed_step(self, step_name: str) -> bool:
        """Check if step has been completed."""
        return step_name in self._completed_steps

    def inject_user_response(self, response: str) -> None:
        """
        Inject user response as a completed step for workflow resume.

        Stores the response under ``user_response:<run_id>:<pause_index>``,
        the same key format wait_for_user() uses, so the matching
        wait_for_user() call returns it on replay instead of pausing. The
        current ``_pause_index`` is used, so it should be restored from
        metadata before calling this method for multi-step HITL.

        Args:
            response: User's response to inject

        Example:
            # Platform resumes workflow with user response
            workflow_entity.inject_user_response("yes")
            # On replay, wait_for_user() returns "yes" from cache
        """
        response_key = f"user_response:{self.run_id}:{self._pause_index}"
        self._completed_steps[response_key] = response

        # Also add to step_events so it is serialized along with the other
        # steps; this keeps earlier user responses available across resumes.
        self._step_events.append({
            "step_name": response_key,
            "handler_name": "user_response",
            "input": None,
            "result": response,
        })

        logger.info(f"Injected user response for {self.run_id} at pause {self._pause_index}: {response}")

    def get_agent_data(self, agent_name: str) -> Dict[str, Any]:
        """
        Get agent conversation data from workflow state.

        Args:
            agent_name: Name of the agent

        Returns:
            Value stored under state key ``agent.<agent_name>`` (conversation
            data such as messages/metadata), or an empty dict if absent

        Example:
            ```python
            agent_data = workflow_entity.get_agent_data("ResearchAgent")
            messages = agent_data.get("messages", [])
            ```
        """
        return self.state.get(f"agent.{agent_name}", {})

    def get_agent_messages(self, agent_name: str) -> list[Dict[str, Any]]:
        """
        Get agent messages from workflow state.

        Args:
            agent_name: Name of the agent

        Returns:
            The ``messages`` list from the agent's data, or an empty list

        Example:
            ```python
            messages = workflow_entity.get_agent_messages("ResearchAgent")
            for msg in messages:
                print(f"{msg['role']}: {msg['content']}")
            ```
        """
        agent_data = self.get_agent_data(agent_name)
        return agent_data.get("messages", [])

    def list_agents(self) -> list[str]:
        """
        List all agents with data in this workflow.

        Returns:
            Names of agents whose data is stored under an ``agent.<name>`` state key

        Example:
            ```python
            agents = workflow_entity.list_agents()
            # ['ResearchAgent', 'AnalysisAgent', 'SynthesisAgent']
            ```
        """
        prefix = "agent."
        return [key[len(prefix):] for key in self.state._state if key.startswith(prefix)]

    async def _persist_state(self) -> None:
        """
        Internal method to persist workflow state to entity storage.

        This is prefixed with _ so it won't be wrapped by the entity method wrapper.
        Loads the current stored version (for optimistic locking) and saves
        the current state snapshot against it, using this entity's memory
        scope and scope identifier.

        Raises:
            Exception: Any error from the state adapter is logged and re-raised.
        """
        logger.debug("_persist_state() called for workflow %s", self.run_id)

        try:
            from .entity import _get_state_adapter

            # Get the state adapter (must be in Worker context)
            adapter = _get_state_adapter()
            logger.debug("Got state adapter: %s", type(adapter).__name__)

            state_dict = self.state.get_state_snapshot()
            logger.debug("State snapshot has %d keys: %s", len(state_dict), list(state_dict.keys()))

            # Determine scope and scope_id based on memory scope
            scope = self._memory_scope  # "session", "user", or "run"
            scope_id = ""
            if self._memory_scope == "session" and self._session_id:
                scope_id = self._session_id
            elif self._memory_scope == "user" and self._user_id:
                scope_id = self._user_id
            elif self._memory_scope == "run":
                scope_id = self._run_id

            # Load current version (for optimistic locking) with proper scope
            _, current_version = await adapter.load_with_version(
                self._entity_type, self._key, scope=scope, scope_id=scope_id
            )
            logger.debug(
                "Current version %s (scope=%s, scope_id=%s)", current_version, scope, scope_id
            )

            # Save state with version check and proper scope
            new_version = await adapter.save_state(
                self._entity_type, self._key, state_dict, current_version,
                scope=scope, scope_id=scope_id,
            )

            logger.info(
                f"✅ SUCCESS: Persisted WorkflowEntity state for {self.run_id} "
                f"(version {current_version} -> {new_version}, {len(state_dict)} keys)"
            )
        except Exception as e:
            logger.error(
                f"❌ ERROR: Failed to persist workflow state for {self.run_id}: {e}",
                exc_info=True
            )
            # Re-raise to let caller handle
            raise

    @property
    def state(self) -> "WorkflowState":
        """
        Get workflow state with change tracking.

        Lazily creates an empty WorkflowState the first time it is accessed
        while ``_state`` is None. WorkflowState records every mutation for
        debugging and replay of AI workflows.
        """
        if self._state is None:
            # Initialize with empty state dict - will be populated by entity system
            self._state = WorkflowState({}, self)
        return self._state
|
|
1291
|
+
|
|
1292
|
+
|
|
1293
|
+
class WorkflowState(EntityState):
    """
    EntityState variant that records every mutation made through it.

    Each ``set``, ``delete`` and ``clear`` is first applied by the base class.
    It is then
    - appended as a change record to the owning WorkflowEntity's
      ``_state_changes`` list (AI workflow debugging, audit trail, replay), and
    - forwarded to the optional checkpoint callback as a
      ``"workflow.state.changed"`` event for real-time state streaming.
    """

    def __init__(self, state_dict: Dict[str, Any], workflow_entity: WorkflowEntity):
        """
        Initialize workflow state.

        Args:
            state_dict: Dictionary used as the backing store for state.
            workflow_entity: Owning workflow entity; receives change records.
        """
        super().__init__(state_dict)
        self._workflow_entity = workflow_entity
        self._checkpoint_callback: Optional[Callable[[str, dict], None]] = None

    def _set_checkpoint_callback(self, callback: Callable[[str, dict], None]) -> None:
        """
        Install the callback that receives state-change checkpoint events.

        Args:
            callback: Called as ``callback(event_type, payload)`` after each mutation.
        """
        self._checkpoint_callback = callback

    def _append_change_record(self, key: str, value: Any, deleted: bool) -> None:
        """Append one timestamped change record to the owning entity's audit list."""
        record = {"key": key, "value": value, "timestamp": time.time(), "deleted": deleted}
        self._workflow_entity._state_changes.append(record)

    def _notify_state_changed(self, payload: Dict[str, Any]) -> None:
        """Forward a state-change payload to the checkpoint callback, if one is set."""
        callback = self._checkpoint_callback
        if callback:
            callback("workflow.state.changed", payload)

    def set(self, key: str, value: Any) -> None:
        """Set ``key`` to ``value``, then record and stream the change."""
        super().set(key, value)
        self._append_change_record(key, value, deleted=False)
        self._notify_state_changed({"key": key, "value": value, "operation": "set"})

    def delete(self, key: str) -> None:
        """Delete ``key``, then record and stream the deletion."""
        super().delete(key)
        self._append_change_record(key, None, deleted=True)
        self._notify_state_changed({"key": key, "operation": "delete"})

    def clear(self) -> None:
        """Remove all keys, then record (as key ``__clear__``) and stream the clear."""
        super().clear()
        self._append_change_record("__clear__", None, deleted=True)
        self._notify_state_changed({"operation": "clear"})

    def has_changes(self) -> bool:
        """Return True if the owning entity has recorded at least one state change."""
        return bool(self._workflow_entity._state_changes)

    def get_state_snapshot(self) -> Dict[str, Any]:
        """Return a shallow copy of the current state dictionary."""
        return dict(self._state)
|
|
1377
|
+
|
|
1378
|
+
|
|
1379
|
+
class WorkflowRegistry:
    """Static accessors over the module-level ``_WORKFLOW_REGISTRY`` (workflow name -> WorkflowConfig)."""

    @staticmethod
    def register(config: WorkflowConfig) -> None:
        """
        Add ``config`` to the registry under ``config.name``.

        Raises:
            ValueError: If a workflow with this name is already registered.
                The collision, with both defining modules, is also logged at
                error level before raising.
        """
        workflow_name = config.name

        if workflow_name not in _WORKFLOW_REGISTRY:
            _WORKFLOW_REGISTRY[workflow_name] = config
            logger.debug(f"Registered workflow '{workflow_name}'")
            return

        already_registered = _WORKFLOW_REGISTRY[workflow_name]
        logger.error(
            f"Workflow name collision detected: '{workflow_name}'\n"
            f" First defined in: {already_registered.handler.__module__}\n"
            f" Also defined in: {config.handler.__module__}\n"
            " This is a bug - workflows must have unique names."
        )
        raise ValueError(
            f"Workflow '{workflow_name}' is already registered. "
            "Use @workflow(name='unique_name') to specify a different name."
        )

    @staticmethod
    def get(name: str) -> Optional[WorkflowConfig]:
        """Return the config registered under ``name``, or None if there is none."""
        return _WORKFLOW_REGISTRY.get(name)

    @staticmethod
    def all() -> Dict[str, WorkflowConfig]:
        """Return a shallow copy of the whole registry."""
        return _WORKFLOW_REGISTRY.copy()

    @staticmethod
    def list_names() -> list[str]:
        """Return the names of all registered workflows, in registration order."""
        return list(_WORKFLOW_REGISTRY)

    @staticmethod
    def clear() -> None:
        """Remove every registered workflow."""
        _WORKFLOW_REGISTRY.clear()
|
|
1425
|
+
|
|
1426
|
+
|
|
1427
|
+
def workflow(
|
|
1428
|
+
_func: Optional[Callable[..., Any]] = None,
|
|
1429
|
+
*,
|
|
1430
|
+
name: Optional[str] = None,
|
|
1431
|
+
chat: bool = False,
|
|
1432
|
+
cron: Optional[str] = None,
|
|
1433
|
+
webhook: bool = False,
|
|
1434
|
+
webhook_secret: Optional[str] = None,
|
|
1435
|
+
) -> Callable[..., Any]:
|
|
1436
|
+
"""
|
|
1437
|
+
Decorator to mark a function as an AGNT5 durable workflow.
|
|
1438
|
+
|
|
1439
|
+
Workflows use WorkflowEntity for state management and WorkflowContext
|
|
1440
|
+
for orchestration. State changes are automatically tracked for replay.
|
|
1441
|
+
|
|
1442
|
+
Args:
|
|
1443
|
+
name: Custom workflow name (default: function's __name__)
|
|
1444
|
+
chat: Enable chat mode for multi-turn conversation workflows (default: False)
|
|
1445
|
+
cron: Cron expression for scheduled execution (e.g., "0 9 * * *" for daily at 9am)
|
|
1446
|
+
webhook: Enable webhook triggering for this workflow (default: False)
|
|
1447
|
+
webhook_secret: Optional secret for HMAC-SHA256 signature verification
|
|
1448
|
+
|
|
1449
|
+
Example (standard workflow):
|
|
1450
|
+
@workflow
|
|
1451
|
+
async def process_order(ctx: WorkflowContext, order_id: str) -> dict:
|
|
1452
|
+
# Durable state - survives crashes
|
|
1453
|
+
ctx.state.set("status", "processing")
|
|
1454
|
+
ctx.state.set("order_id", order_id)
|
|
1455
|
+
|
|
1456
|
+
# Validate order
|
|
1457
|
+
order = await ctx.task(validate_order, input={"order_id": order_id})
|
|
1458
|
+
|
|
1459
|
+
# Process payment (checkpointed - won't re-execute on crash)
|
|
1460
|
+
payment = await ctx.step("payment", process_payment(order["total"]))
|
|
1461
|
+
|
|
1462
|
+
# Fulfill order
|
|
1463
|
+
await ctx.task(ship_order, input={"order_id": order_id})
|
|
1464
|
+
|
|
1465
|
+
ctx.state.set("status", "completed")
|
|
1466
|
+
return {"status": ctx.state.get("status")}
|
|
1467
|
+
|
|
1468
|
+
Example (chat workflow):
|
|
1469
|
+
@workflow(chat=True)
|
|
1470
|
+
async def customer_support(ctx: WorkflowContext, message: str) -> dict:
|
|
1471
|
+
# Initialize conversation state
|
|
1472
|
+
if not ctx.state.get("messages"):
|
|
1473
|
+
ctx.state.set("messages", [])
|
|
1474
|
+
|
|
1475
|
+
# Add user message
|
|
1476
|
+
messages = ctx.state.get("messages")
|
|
1477
|
+
messages.append({"role": "user", "content": message})
|
|
1478
|
+
ctx.state.set("messages", messages)
|
|
1479
|
+
|
|
1480
|
+
# Generate AI response
|
|
1481
|
+
response = await ctx.task(generate_response, messages=messages)
|
|
1482
|
+
|
|
1483
|
+
# Add assistant response
|
|
1484
|
+
messages.append({"role": "assistant", "content": response})
|
|
1485
|
+
ctx.state.set("messages", messages)
|
|
1486
|
+
|
|
1487
|
+
return {"response": response, "turn_count": len(messages) // 2}
|
|
1488
|
+
|
|
1489
|
+
Example (scheduled workflow):
|
|
1490
|
+
@workflow(name="daily_report", cron="0 9 * * *")
|
|
1491
|
+
async def daily_report(ctx: WorkflowContext) -> dict:
|
|
1492
|
+
# Runs automatically every day at 9am
|
|
1493
|
+
sales = await ctx.task(get_sales_data, report_type="sales")
|
|
1494
|
+
report = await ctx.task(generate_pdf, input=sales)
|
|
1495
|
+
await ctx.task(send_email, to="team@company.com", attachment=report)
|
|
1496
|
+
return {"status": "sent", "report_id": report["id"]}
|
|
1497
|
+
|
|
1498
|
+
Example (webhook workflow):
|
|
1499
|
+
@workflow(name="on_payment", webhook=True, webhook_secret="your_secret_key")
|
|
1500
|
+
async def on_payment(ctx: WorkflowContext, webhook_data: dict) -> dict:
|
|
1501
|
+
# Triggered by webhook POST /v1/webhooks/on_payment
|
|
1502
|
+
# webhook_data contains: payload, headers, source_ip, timestamp
|
|
1503
|
+
payment = webhook_data["payload"]
|
|
1504
|
+
|
|
1505
|
+
if payment.get("status") == "succeeded":
|
|
1506
|
+
await ctx.task(fulfill_order, order_id=payment["order_id"])
|
|
1507
|
+
await ctx.task(send_receipt, customer_email=payment["email"])
|
|
1508
|
+
return {"status": "processed", "order_id": payment["order_id"]}
|
|
1509
|
+
|
|
1510
|
+
return {"status": "skipped", "reason": "payment not successful"}
|
|
1511
|
+
"""
|
|
1512
|
+
|
|
1513
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    """Register ``func`` as a workflow and return a context-providing wrapper.

    Signature validation, schema/metadata extraction and registration with
    ``WorkflowRegistry`` all happen once, at decoration time.

    Args:
        func: The workflow function. Its first parameter must be named ``ctx``.
            It may be sync or async; sync functions are wrapped in a coroutine.

    Returns:
        An async wrapper that runs the workflow with a ``WorkflowContext`` set
        as the current context. The registered ``WorkflowConfig`` is exposed as
        ``wrapper._agnt5_config`` for introspection.

    Raises:
        ValueError: If ``func`` takes no parameters or its first parameter is
            not named ``ctx``.
    """
    workflow_name = name or func.__name__

    # The wrapper always hands the context to the handler as its first argument.
    params = list(inspect.signature(func).parameters.values())
    if not params or params[0].name != "ctx":
        raise ValueError(
            f"Workflow '{workflow_name}' must have 'ctx: WorkflowContext' as first parameter"
        )

    if inspect.iscoroutinefunction(func):
        handler_func = cast(HandlerFunc, func)
    else:
        # The sync function is called directly on the event-loop thread (it is
        # not offloaded), so long-running sync work blocks the loop meanwhile.
        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        handler_func = cast(HandlerFunc, async_wrapper)

    # Schemas come from type hints; metadata carries description etc.
    input_schema, output_schema = extract_function_schemas(func)
    metadata = extract_function_metadata(func)

    # Trigger/mode flags are recorded as string metadata values.
    if chat:
        metadata["chat"] = "true"
    if cron:
        metadata["cron"] = cron
    if webhook:
        metadata["webhook"] = "true"
        # The secret is only recorded when webhooks are enabled.
        if webhook_secret:
            metadata["webhook_secret"] = webhook_secret

    config = WorkflowConfig(
        name=workflow_name,
        handler=handler_func,
        input_schema=input_schema,
        output_schema=output_schema,
        metadata=metadata,
    )
    WorkflowRegistry.register(config)

    async def _execute(
        ctx: WorkflowContext,
        entity: Optional[WorkflowEntity],
        call_args: tuple,
        call_kwargs: Dict[str, Any],
    ) -> Any:
        """Run the handler with ``ctx`` as the current context, then persist state.

        ``entity`` is the workflow entity to persist. When it is ``None``, the
        entity is read from ``ctx._workflow_entity`` after the handler returns,
        inside the non-fatal persistence ``try``.
        """
        token = set_current_context(ctx)
        try:
            result = await handler_func(*call_args, **call_kwargs)

            # Persistence is best-effort: a failure is logged, and the
            # workflow result is still returned.
            try:
                target = entity if entity is not None else ctx._workflow_entity
                await target._persist_state()
            except Exception as e:
                logger.error(
                    "Failed to persist workflow state (non-fatal): %s", e, exc_info=True
                )

            return result
        finally:
            # Always reset the context so it does not leak to the caller.
            # Imported lazily here; presumably to avoid an import cycle with
            # .context — NOTE(review): confirm.
            from .context import _current_context

            _current_context.reset(token)

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Caller supplied the context positionally: use it as-is.
        if args and isinstance(args[0], WorkflowContext):
            return await _execute(args[0], None, args, kwargs)

        # Caller supplied the context by keyword (ctx=...): use it rather than
        # injecting a second context, which would raise "multiple values for
        # argument 'ctx'".
        keyword_ctx = kwargs.get("ctx")
        if isinstance(keyword_ctx, WorkflowContext):
            return await _execute(keyword_ctx, None, args, kwargs)

        # Direct call without a context: create a fresh entity and context.
        run_id = f"workflow-{uuid.uuid4().hex[:8]}"
        workflow_entity = WorkflowEntity(run_id=run_id)
        ctx = WorkflowContext(
            workflow_entity=workflow_entity,
            run_id=run_id,
        )
        return await _execute(ctx, workflow_entity, (ctx, *args), kwargs)

    wrapper._agnt5_config = config  # type: ignore
    return wrapper
|
|
1627
|
+
|
|
1628
|
+
# Handle both @workflow and @workflow(...) syntax
|
|
1629
|
+
if _func is None:
|
|
1630
|
+
return decorator
|
|
1631
|
+
else:
|
|
1632
|
+
return decorator(_func)
|