agnt5-0.3.2a1-cp310-abi3-manylinux_2_34_aarch64.whl

This diff shows the content of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.

This version of agnt5 has been flagged as potentially problematic.

agnt5/workflow.py ADDED
@@ -0,0 +1,1632 @@
+ """Workflow component implementation for AGNT5 SDK."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import functools
+ import inspect
+ import logging
+ import time
+ import uuid
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, cast
+
+ from ._schema_utils import extract_function_metadata, extract_function_schemas
+ from .context import Context, set_current_context
+ from .entity import Entity, EntityState, _get_state_adapter
+ from .function import FunctionContext
+ from .types import HandlerFunc, WorkflowConfig
+ from ._telemetry import setup_module_logger
+
+ logger = setup_module_logger(__name__)
+
+ T = TypeVar("T")
+
+ # Global workflow registry
+ _WORKFLOW_REGISTRY: Dict[str, WorkflowConfig] = {}
+
+
+ class WorkflowContext(Context):
+     """
+     Context for durable workflows.
+
+     Extends base Context with:
+     - State management via WorkflowEntity.state
+     - Step tracking and replay
+     - Orchestration (task, parallel, gather)
+     - Checkpointing (step)
+     - Memory scoping (session_id, user_id for multi-level memory)
+
+     WorkflowContext delegates state to the underlying WorkflowEntity,
+     which provides durability and state change tracking for AI workflows.
+
+     Memory Scoping:
+     - run_id: Unique workflow run identifier
+     - session_id: For multi-turn conversations (optional)
+     - user_id: For user-scoped long-term memory (optional)
+     These identifiers enable agents to automatically select the appropriate
+     memory scope (run/session/user) via context propagation.
+     """
+
+     def __init__(
+         self,
+         workflow_entity: "WorkflowEntity",  # Forward reference
+         run_id: str,
+         session_id: Optional[str] = None,
+         user_id: Optional[str] = None,
+         attempt: int = 0,
+         runtime_context: Optional[Any] = None,
+         checkpoint_callback: Optional[Callable[[dict], None]] = None,
+         checkpoint_client: Optional[Any] = None,
+         is_streaming: bool = False,
+         tenant_id: Optional[str] = None,
+         delta_callback: Optional[Callable[[str, str, int, int, int], None]] = None,
+     ) -> None:
+         """
+         Initialize workflow context.
+
+         Args:
+             workflow_entity: WorkflowEntity instance managing workflow state
+             run_id: Unique workflow run identifier
+             session_id: Session identifier for multi-turn conversations (default: run_id)
+             user_id: User identifier for user-scoped memory (optional)
+             attempt: Retry attempt number (0-indexed)
+             runtime_context: RuntimeContext for trace correlation
+             checkpoint_callback: Optional callback for sending real-time checkpoints
+             checkpoint_client: Optional CheckpointClient for platform-side memoization
+             is_streaming: Whether this is a streaming request (for real-time SSE log delivery)
+             tenant_id: Tenant identifier for multi-tenant deployments
+             delta_callback: Optional callback for forwarding streaming events from nested components
+                 (event_type, output_data, content_index, sequence, source_timestamp_ns) -> None
+         """
+         super().__init__(
+             run_id=run_id,
+             attempt=attempt,
+             runtime_context=runtime_context,
+             is_streaming=is_streaming,
+             tenant_id=tenant_id,
+             checkpoint_callback=checkpoint_callback,
+             delta_callback=delta_callback,
+         )
+         self._workflow_entity = workflow_entity
+         self._step_counter: int = 0  # Track step sequence
+         self._sequence_number: int = 0  # Global sequence for checkpoints
+         self._checkpoint_client = checkpoint_client
+         self._delta_sequence: int = 0  # Sequence for delta events (separate from checkpoint sequence)
+
+         # Memory scoping identifiers (use private attrs since properties are read-only)
+         self._session_id = session_id or run_id  # Default: session = run (ephemeral)
+         self._user_id = user_id  # Optional: user-scoped memory
+
+         # Step hierarchy tracking - for nested step visualization
+         # Stack of event IDs for currently executing steps
+         self._step_event_stack: List[str] = []
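+
+         # Scope-selection sketch (hypothetical IDs, follows the defaults above):
+         #   WorkflowContext(entity, run_id="run-1")                 -> session_id "run-1" (ephemeral, run-scoped)
+         #   WorkflowContext(entity, run_id="run-1", user_id="u-42") -> user-scoped long-term memory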
+
+     # === State Management ===
+
+     def _forward_delta(self, event_type: str, output_data: str, content_index: int = 0, source_timestamp_ns: int = 0) -> None:
+         """
+         Forward a streaming delta event from a nested component.
+
+         Used by step executors to forward events from streaming agents/functions
+         to the client via the delta queue.
+
+         Args:
+             event_type: Event type (e.g., "agent.started", "lm.message.delta")
+             output_data: JSON-serialized event data
+             content_index: Content index for parallel events (default: 0)
+             source_timestamp_ns: Nanosecond timestamp when the event was created (default: 0, generated if not provided)
+         """
+         if self._delta_callback:
+             self._delta_callback(event_type, output_data, content_index, self._delta_sequence, source_timestamp_ns)
+             self._delta_sequence += 1
+
+     async def _consume_streaming_result(self, async_gen: Any, step_name: str) -> Any:
+         """
+         Consume an async generator while forwarding streaming events to the client.
+
+         This method handles streaming from nested agents and functions within
+         workflow steps. Events are forwarded via the delta queue while the
+         final result is collected and returned for the next step.
+
+         For agents, the final output is extracted from the agent.completed event.
+         For functions, the last yielded value (or collected output) is returned.
+
+         Args:
+             async_gen: Async generator yielding Event objects or raw values
+             step_name: Name of the current step (for logging)
+
+         Returns:
+             The final result to pass to the next step:
+             - For agents: the output from the agent.completed event
+             - For functions: the last yielded value or collected output
+         """
+         import json
+         from .events import Event, EventType
+
+         final_result = None
+         collected_output = []  # For streaming functions that yield chunks
+
+         async for item in async_gen:
+             if isinstance(item, Event):
+                 # Forward typed Event via delta queue
+                 event_data = item.to_response_fields()
+                 output_data = event_data.get("output_data", b"")
+                 output_str = output_data.decode("utf-8") if isinstance(output_data, bytes) else str(output_data or "{}")
+
+                 self._forward_delta(
+                     event_type=event_data.get("event_type", ""),
+                     output_data=output_str,
+                     content_index=event_data.get("content_index", 0),
+                     source_timestamp_ns=item.source_timestamp_ns,
+                 )
+
+                 # Capture final result from specific event types
+                 if item.event_type == EventType.AGENT_COMPLETED:
+                     # For agents, extract the output from the completed event
+                     final_result = item.data.get("output", "")
+                     logger.debug(f"Step '{step_name}': Captured agent output from agent.completed")
+                 elif item.event_type == EventType.OUTPUT_STOP:
+                     # For streaming functions, the collected output is the result
+                     # (already collected from delta events)
+                     pass
+
+             else:
+                 # Raw value (non-Event) - streaming function output
+                 # Forward as output.delta and collect for final result
+                 try:
+                     chunk_json = json.dumps(item)
+                 except (TypeError, ValueError):
+                     chunk_json = str(item)
+
+                 self._forward_delta(
+                     event_type="output.delta",
+                     output_data=chunk_json,
+                     source_timestamp_ns=time.time_ns(),
+                 )
+                 collected_output.append(item)
+
+         # Determine final result
+         if final_result is not None:
+             # Agent result was captured from the agent.completed event
+             return final_result
+         elif collected_output:
+             # Streaming function - return collected chunks
+             # If single item, return it directly; otherwise return list
+             if len(collected_output) == 1:
+                 return collected_output[0]
+             return collected_output
+         else:
+             # Empty generator
+             return None
+
+     def _send_checkpoint(self, checkpoint_type: str, checkpoint_data: dict) -> None:
+         """
+         Send a checkpoint via the unified event emission API.
+
+         This method uses ctx.emit for consistent event routing:
+         - Streaming mode: immediate delivery via delta queue
+         - Non-streaming mode: buffered delivery via checkpoint queue
+
+         Automatically adds parent_event_id from the step event stack if we're
+         currently executing inside a nested step call.
+
+         Args:
+             checkpoint_type: Type of checkpoint (e.g., "workflow.state.changed")
+             checkpoint_data: Checkpoint payload (should include event_id if needed)
+         """
+         # Add parent_event_id if we're in a nested step
+         if self._step_event_stack:
+             checkpoint_data = {
+                 **checkpoint_data,
+                 "parent_event_id": self._step_event_stack[-1],
+             }
+
+         # Use unified emit API for consistent event routing
+         # The emit property handles streaming vs non-streaming routing
+         self.emit.emit(checkpoint_type, checkpoint_data)
+
+     @property
+     def state(self):
+         """
+         Delegate to WorkflowEntity.state for durable state management.
+
+         Returns:
+             WorkflowState instance from the workflow entity
+
+         Example:
+             ctx.state.set("status", "processing")
+             status = ctx.state.get("status")
+         """
+         state = self._workflow_entity.state
+         # Pass checkpoint callback to state for real-time streaming
+         if hasattr(state, "_set_checkpoint_callback"):
+             state._set_checkpoint_callback(self._send_checkpoint)
+         return state
+
+     # === Orchestration ===
+
+     async def step(
+         self,
+         name_or_handler: Union[str, Callable, Awaitable[T]],
+         func_or_awaitable: Union[Callable[..., Awaitable[T]], Awaitable[T], Any] = None,
+         *args: Any,
+         **kwargs: Any,
+     ) -> T:
+         """
+         Execute a durable step with automatic checkpointing.
+
+         Steps are the primary building block for durable workflows. Results are
+         automatically persisted, so if the workflow crashes and restarts, completed
+         steps return their cached result without re-executing.
+
+         Supports multiple calling patterns:
+
+         1. **Call a @function (recommended)**:
+            ```python
+            result = await ctx.step(process_data, arg1, arg2, kwarg=value)
+            ```
+            Auto-generates step name from function. Full IDE support.
+
+         2. **Checkpoint an awaitable with explicit name**:
+            ```python
+            result = await ctx.step("load_data", fetch_expensive_data())
+            ```
+            For arbitrary async operations that aren't @functions.
+
+         3. **Checkpoint a callable with explicit name**:
+            ```python
+            result = await ctx.step("compute", my_function, arg1, arg2)
+            ```
+
+         4. **Legacy string-based @function call**:
+            ```python
+            result = await ctx.step("function_name", input=data)
+            ```
+
+         Args:
+             name_or_handler: Step name (str), @function reference, or awaitable
+             func_or_awaitable: Function/awaitable when name is provided, or first arg
+             *args: Additional arguments for the function
+             **kwargs: Keyword arguments for the function
+
+         Returns:
+             The step result (cached on replay)
+
+         Example (@function call):
+             ```python
+             @function
+             async def process_data(ctx: FunctionContext, data: list, multiplier: int = 2):
+                 return [x * multiplier for x in data]
+
+             @workflow
+             async def my_workflow(ctx: WorkflowContext):
+                 result = await ctx.step(process_data, [1, 2, 3], multiplier=3)
+                 return result
+             ```
+
+         Example (checkpoint awaitable):
+             ```python
+             @workflow
+             async def my_workflow(ctx: WorkflowContext):
+                 # Checkpoint expensive external call
+                 data = await ctx.step("fetch_api", fetch_from_external_api())
+                 return data
+             ```
+         """
+         import inspect
+
+         # Determine which calling pattern is being used
+         if callable(name_or_handler) and hasattr(name_or_handler, "_agnt5_config"):
+             # Pattern 1: step(handler, *args, **kwargs) - @function call
+             return await self._step_function(name_or_handler, func_or_awaitable, *args, **kwargs)
+         elif isinstance(name_or_handler, str):
+             # Check if it's a registered function name (legacy pattern)
+             from .function import FunctionRegistry
+             if FunctionRegistry.get(name_or_handler) is not None:
+                 # Pattern 4: Legacy string-based function call
+                 return await self._step_function(name_or_handler, func_or_awaitable, *args, **kwargs)
+             elif func_or_awaitable is not None:
+                 # Pattern 2/3: step("name", awaitable) or step("name", callable, *args)
+                 return await self._step_checkpoint(name_or_handler, func_or_awaitable, *args, **kwargs)
+             else:
+                 # String without second arg and not a registered function
+                 raise ValueError(
+                     f"Function '{name_or_handler}' not found in registry. "
+                     f"Either register it with @function decorator, or use "
+                     f"ctx.step('{name_or_handler}', awaitable) to checkpoint an arbitrary operation."
+                 )
+         elif inspect.iscoroutine(name_or_handler) or inspect.isawaitable(name_or_handler):
+             # Awaitable passed directly - auto-generate name
+             coro_name = getattr(name_or_handler, '__name__', 'awaitable')
+             return await self._step_checkpoint(coro_name, name_or_handler)
+         elif callable(name_or_handler):
+             # Callable without @function decorator
+             raise ValueError(
+                 f"Function '{name_or_handler.__name__}' is not a registered @function. "
+                 f"Did you forget to add the @function decorator? "
+                 f"Or use ctx.step('name', callable) for non-decorated functions."
+             )
+         else:
+             raise ValueError(
+                 f"step() first argument must be a @function, string name, or awaitable. "
+                 f"Got: {type(name_or_handler)}"
+             )
+
+     async def _step_function(
+         self,
+         handler: Union[str, Callable],
+         first_arg: Any = None,
+         *args: Any,
+         **kwargs: Any,
+     ) -> Any:
+         """
+         Internal: Execute a @function as a durable step.
+
+         This handles both function references and legacy string-based calls.
+         """
+         from .function import FunctionRegistry
+
+         # Reconstruct args tuple (first_arg may have been split out by step())
+         if first_arg is not None:
+             args = (first_arg,) + args
+
+         # Extract handler name from function reference or use string
+         if callable(handler):
+             handler_name = handler.__name__
+             if not hasattr(handler, "_agnt5_config"):
+                 raise ValueError(
+                     f"Function '{handler_name}' is not a registered @function. "
+                     f"Did you forget to add the @function decorator?"
+                 )
+         else:
+             handler_name = handler
+
+         # Generate unique step name for durability
+         step_name = f"{handler_name}_{self._step_counter}"
+         self._step_counter += 1
+
+         # Generate unique event_id for this step (for hierarchy tracking)
+         step_event_id = str(uuid.uuid4())
+
+         # Check if step already completed (for replay)
+         if self._workflow_entity.has_completed_step(step_name):
+             result = self._workflow_entity.get_completed_step(step_name)
+             self._logger.info(f"🔄 Replaying cached step: {step_name}")
+             return result
+
+         # Emit workflow.step.started checkpoint
+         self._send_checkpoint(
+             "workflow.step.started",
+             {
+                 "step_name": step_name,
+                 "handler_name": handler_name,
+                 "input": args or kwargs,
+                 "event_id": step_event_id,  # Include for hierarchy tracking
+             },
+         )
+
+         # Push this step's event_id onto the stack for nested calls
+         self._step_event_stack.append(step_event_id)
+
+         # Execute function with OpenTelemetry span
+         self._logger.info(f"▶️ Executing new step: {step_name}")
+         func_config = FunctionRegistry.get(handler_name)
+         if func_config is None:
+             raise ValueError(f"Function '{handler_name}' not found in registry")
+
+         # Import span creation utility (uses contextvar for async-safe parent-child linking)
+         from .tracing import create_span
+         import json
+
+         # Serialize input data for span attributes
+         input_repr = json.dumps({"args": args, "kwargs": kwargs}) if args or kwargs else "{}"
+
+         # Create span for task execution (contextvar handles parent-child linking)
+         with create_span(
+             f"workflow.task.{handler_name}",
+             "function",
+             self._runtime_context,
+             {
+                 "step_name": step_name,
+                 "handler_name": handler_name,
+                 "run_id": self.run_id,
+                 "input.data": input_repr,
+             },
+         ) as span:
+             # Create FunctionContext for the function execution
+             func_ctx = FunctionContext(
+                 run_id=f"{self.run_id}:task:{handler_name}",
+                 runtime_context=self._runtime_context,
+             )
+
+             try:
+                 # Execute function with arguments
+                 # Support legacy pattern: ctx.task("func_name", input=data) or ctx.task(func_ref, input=data)
+                 if len(args) == 0 and "input" in kwargs:
+                     # Legacy pattern - single input parameter
+                     input_data = kwargs.pop("input")  # Remove from kwargs
+                     handler_result = func_config.handler(func_ctx, input_data, **kwargs)
+                 else:
+                     # Type-safe pattern - pass all args/kwargs
+                     handler_result = func_config.handler(func_ctx, *args, **kwargs)
+
+                 # Check if result is an async generator (streaming function or agent)
+                 # If so, consume it while forwarding events via delta queue
+                 if inspect.isasyncgen(handler_result):
+                     result = await self._consume_streaming_result(handler_result, step_name)
+                 elif inspect.iscoroutine(handler_result):
+                     result = await handler_result
+                 else:
+                     result = handler_result
+
+                 # Add output data to span
+                 try:
+                     output_repr = json.dumps(result)
+                     span.set_attribute("output.data", output_repr)
+                 except (TypeError, ValueError):
+                     # If result is not JSON serializable, use repr
+                     span.set_attribute("output.data", repr(result))
+
+                 # Record step completion in WorkflowEntity
+                 self._workflow_entity.record_step_completion(
+                     step_name, handler_name, args or kwargs, result
+                 )
+
+                 # Pop this step's event_id from the stack (execution complete)
+                 if self._step_event_stack:
+                     popped_id = self._step_event_stack.pop()
+                     if popped_id != step_event_id:
+                         self._logger.warning(
+                             f"Step event stack mismatch in task(): expected {step_event_id}, got {popped_id}"
+                         )
+
+                 # Emit workflow.step.completed checkpoint
+                 self._send_checkpoint(
+                     "workflow.step.completed",
+                     {
+                         "step_name": step_name,
+                         "handler_name": handler_name,
+                         "input": args or kwargs,
+                         "result": result,
+                         "event_id": step_event_id,  # Include for consistency
+                     },
+                 )
+
+                 return result
+
+             except Exception as e:
+                 # Pop this step's event_id from the stack (execution failed)
+                 if self._step_event_stack:
+                     popped_id = self._step_event_stack.pop()
+                     if popped_id != step_event_id:
+                         self._logger.warning(
+                             f"Step event stack mismatch in task() error path: expected {step_event_id}, got {popped_id}"
+                         )
+
+                 # Emit workflow.step.failed checkpoint
+                 self._send_checkpoint(
+                     "workflow.step.failed",
+                     {
+                         "step_name": step_name,
+                         "handler_name": handler_name,
+                         "input": args or kwargs,
+                         "error": str(e),
+                         "error_type": type(e).__name__,
+                         "event_id": step_event_id,  # Include for consistency
+                     },
+                 )
+
+                 # Record error in span
+                 span.set_attribute("error", "true")
+                 span.set_attribute("error.message", str(e))
+                 span.set_attribute("error.type", type(e).__name__)
+
+                 # Re-raise to propagate failure
+                 raise
+
+     async def parallel(self, *tasks: Awaitable[T]) -> List[T]:
+         """
+         Run multiple tasks in parallel.
+
+         Args:
+             *tasks: Async tasks to run in parallel
+
+         Returns:
+             List of results in the same order as tasks
+
+         Example:
+             result1, result2 = await ctx.parallel(
+                 fetch_data(source1),
+                 fetch_data(source2)
+             )
+         """
+         import asyncio
+
+         return list(await asyncio.gather(*tasks))
+
+     async def gather(self, **tasks: Awaitable[T]) -> Dict[str, T]:
+         """
+         Run tasks in parallel with named results.
+
+         Args:
+             **tasks: Named async tasks to run in parallel
+
+         Returns:
+             Dictionary mapping names to results
+
+         Example:
+             results = await ctx.gather(
+                 db=query_database(),
+                 api=fetch_api()
+             )
+         """
+         import asyncio
+
+         keys = list(tasks.keys())
+         values = list(tasks.values())
+         results = await asyncio.gather(*values)
+         return dict(zip(keys, results))
+
+     async def task(
+         self,
+         handler: Union[str, Callable],
+         *args: Any,
+         **kwargs: Any,
+     ) -> Any:
+         """
+         Execute a function and wait for the result.
+
+         .. deprecated::
+             Use :meth:`step` instead. ``task()`` will be removed in a future version.
+
+         This method is an alias for :meth:`step` for backward compatibility.
+         New code should use ``ctx.step()`` directly.
+
+         Args:
+             handler: Either a @function reference or string name
+             *args: Positional arguments to pass to the function
+             **kwargs: Keyword arguments to pass to the function
+
+         Returns:
+             Function result
+         """
+         import warnings
+
+         warnings.warn(
+             "ctx.task() is deprecated, use ctx.step() instead. "
+             "task() will be removed in a future version.",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+         return await self.step(handler, *args, **kwargs)
+
+     async def _step_checkpoint(
+         self,
+         name: str,
+         func_or_awaitable: Union[Callable[..., Awaitable[T]], Awaitable[T]],
+         *args: Any,
+         **kwargs: Any,
+     ) -> T:
+         """
+         Internal: Checkpoint an arbitrary awaitable or callable for durability.
+
+         If the workflow crashes, this step won't re-execute on retry.
+         The step result is persisted to the platform for crash recovery.
+
+         When a CheckpointClient is available, this method uses platform-side
+         memoization via gRPC. The platform stores step results in the run_steps
+         table, enabling replay even after worker crashes.
+
+         Args:
+             name: Unique name for this checkpoint (used as step_key for memoization)
+             func_or_awaitable: Either an async function or awaitable
+             *args: Arguments to pass if func_or_awaitable is callable
+             **kwargs: Keyword arguments to pass if func_or_awaitable is callable
+
+         Returns:
+             The result of the function/awaitable
+         """
+         import inspect
+         import json
+         import time
+
+         # Generate step key for platform memoization
+         step_key = f"step:{name}:{self._step_counter}"
+         self._step_counter += 1
+
+         # Generate unique event_id for this step (for hierarchy tracking)
+         step_event_id = str(uuid.uuid4())
+
+         # Check platform-side memoization first (Phase 3)
+         if self._checkpoint_client:
+             try:
+                 result = await self._checkpoint_client.step_started(
+                     self.run_id,
+                     step_key,
+                     name,
+                     "checkpoint",
+                 )
+                 if result.memoized and result.cached_output:
+                     # Deserialize cached output
+                     cached_value = json.loads(result.cached_output.decode("utf-8"))
+                     self._logger.info(f"🔄 Replaying memoized step from platform: {name}")
+                     # Also record locally for consistency
+                     self._workflow_entity.record_step_completion(name, "checkpoint", None, cached_value)
+                     return cached_value
+             except Exception as e:
+                 self._logger.warning(f"Platform memoization check failed, falling back to local: {e}")
+
+         # Fall back to local memoization (for backward compatibility)
+         if self._workflow_entity.has_completed_step(name):
+             result = self._workflow_entity.get_completed_step(name)
+             self._logger.info(f"🔄 Replaying checkpoint from local cache: {name}")
+             return result
+
+         # Emit workflow.step.started checkpoint for observability
+         self._send_checkpoint(
+             "workflow.step.started",
+             {
+                 "step_name": name,
+                 "handler_name": "checkpoint",
+                 "event_id": step_event_id,  # Include for hierarchy tracking
+             },
+         )
+
+         # Push this step's event_id onto the stack for nested calls
+         self._step_event_stack.append(step_event_id)
+
+         start_time = time.time()
+         try:
+             # Execute and checkpoint
+             if inspect.isasyncgen(func_or_awaitable):
+                 # Direct async generator - consume while forwarding events
+                 result = await self._consume_streaming_result(func_or_awaitable, name)
+             elif inspect.iscoroutine(func_or_awaitable) or inspect.isawaitable(func_or_awaitable):
+                 result = await func_or_awaitable
+             elif callable(func_or_awaitable):
+                 # Call with args/kwargs if provided
+                 call_result = func_or_awaitable(*args, **kwargs)
+                 if inspect.isasyncgen(call_result):
+                     # Callable returned async generator - consume while forwarding events
+                     result = await self._consume_streaming_result(call_result, name)
+                 elif inspect.iscoroutine(call_result) or inspect.isawaitable(call_result):
+                     result = await call_result
+                 else:
+                     result = call_result
+             else:
+                 raise ValueError(f"step() second argument must be awaitable or callable, got {type(func_or_awaitable)}")
+
+             latency_ms = int((time.time() - start_time) * 1000)
+
+             # Record step completion locally for in-memory replay
+             self._workflow_entity.record_step_completion(name, "checkpoint", None, result)
+
+             # Record to platform for persistent memoization (Phase 3)
+             if self._checkpoint_client:
+                 try:
+                     output_bytes = json.dumps(result).encode("utf-8")
+                     await self._checkpoint_client.step_completed(
+                         self.run_id,
+                         step_key,
+                         name,
+                         "checkpoint",
+                         output_bytes,
+                         latency_ms,
+                     )
+                 except Exception as e:
+                     self._logger.warning(f"Failed to record step completion to platform: {e}")
+
+             # Pop this step's event_id from the stack (execution complete)
+             if self._step_event_stack:
+                 popped_id = self._step_event_stack.pop()
+                 if popped_id != step_event_id:
+                     self._logger.warning(
+                         f"Step event stack mismatch in step(): expected {step_event_id}, got {popped_id}"
+                     )
+
+             # Emit workflow.step.completed checkpoint to journal for crash recovery
+             self._send_checkpoint(
+                 "workflow.step.completed",
+                 {
+                     "step_name": name,
+                     "handler_name": "checkpoint",
+                     "result": result,
+                     "event_id": step_event_id,  # Include for consistency
+                 },
+             )
+
+             self._logger.info(f"✅ Checkpoint completed: {name} ({latency_ms}ms)")
+             return result
+
+         except Exception as e:
+             # Pop this step's event_id from the stack (execution failed)
+             if self._step_event_stack:
+                 popped_id = self._step_event_stack.pop()
+                 if popped_id != step_event_id:
+                     self._logger.warning(
+                         f"Step event stack mismatch in step() error path: expected {step_event_id}, got {popped_id}"
+                     )
+
+             # Record failure to platform (Phase 3)
+             if self._checkpoint_client:
+                 try:
+                     await self._checkpoint_client.step_failed(
+                         self.run_id,
+                         step_key,
+                         name,
+                         "checkpoint",
+                         str(e),
+                         type(e).__name__,
+                     )
+                 except Exception as cp_err:
+                     self._logger.warning(f"Failed to record step failure to platform: {cp_err}")
+
+             # Emit workflow.step.failed checkpoint
+             self._send_checkpoint(
+                 "workflow.step.failed",
+                 {
+                     "step_name": name,
+                     "handler_name": "checkpoint",
+                     "error": str(e),
+                     "error_type": type(e).__name__,
+                     "event_id": step_event_id,  # Include for consistency
+                 },
+             )
+             raise
+
+     async def sleep(self, seconds: float, name: Optional[str] = None) -> None:
+         """
+         Durable sleep that survives workflow restarts.
+
+         Unlike regular `asyncio.sleep()`, this sleep is checkpointed. If the
+         workflow crashes and restarts, it will only sleep for the remaining
+         duration (or skip entirely if the sleep period has already elapsed).
+
+         Args:
+             seconds: Duration to sleep in seconds
+             name: Optional name for the sleep checkpoint (auto-generated if not provided)
+
+         Example:
+             ```python
+             @workflow
+             async def delayed_notification(ctx: WorkflowContext, user_id: str):
+                 # Send immediate acknowledgment
+                 await ctx.step(send_ack, user_id)
+
+                 # Wait 24 hours (survives restarts!)
+                 await ctx.sleep(24 * 60 * 60, name="wait_24h")
+
+                 # Send follow-up
+                 await ctx.step(send_followup, user_id)
+             ```
+         """
+         import time
+
+         # Generate unique step name for this sleep
+         sleep_name = name or f"sleep_{self._step_counter}"
+         self._step_counter += 1
+         step_key = f"sleep:{sleep_name}"
+
+         # Check if sleep was already started (replay scenario)
+         if self._workflow_entity.has_completed_step(step_key):
+             sleep_record = self._workflow_entity.get_completed_step(step_key)
+             start_time = sleep_record.get("start_time", 0)
+             duration = sleep_record.get("duration", seconds)
+             elapsed = time.time() - start_time
+
+             if elapsed >= duration:
+                 # Sleep period already elapsed
+                 self._logger.info(f"🔄 Sleep '{sleep_name}' already completed (elapsed: {elapsed:.1f}s)")
+                 return
+
+             # Sleep for remaining duration
+             remaining = duration - elapsed
+             self._logger.info(f"⏰ Resuming sleep '{sleep_name}': {remaining:.1f}s remaining")
+             await asyncio.sleep(remaining)
+             return
+
+         # Record sleep start time for replay
+         sleep_record = {
+             "start_time": time.time(),
+             "duration": seconds,
+         }
+         self._workflow_entity.record_step_completion(step_key, "sleep", None, sleep_record)
+
+         # Emit checkpoint for observability
+         step_event_id = str(uuid.uuid4())
+         self._send_checkpoint(
+             "workflow.step.started",
+             {
+                 "step_name": sleep_name,
+                 "handler_name": "sleep",
+                 "duration_seconds": seconds,
+                 "event_id": step_event_id,
+             },
+         )
+
+         self._logger.info(f"💤 Starting durable sleep '{sleep_name}': {seconds}s")
+         await asyncio.sleep(seconds)
+
+         # Emit completion checkpoint
+         self._send_checkpoint(
+             "workflow.step.completed",
+             {
+                 "step_name": sleep_name,
+                 "handler_name": "sleep",
+                 "duration_seconds": seconds,
+                 "event_id": step_event_id,
+             },
+         )
+         self._logger.info(f"⏰ Sleep '{sleep_name}' completed")
+
+     async def wait_for_user(
+         self, question: str, input_type: str = "text", options: Optional[List[Dict]] = None
+     ) -> str:
+         """
+         Pause workflow execution and wait for user input.
+
+         On replay (even after a worker crash), the workflow resumes from this
+         point with the user's response. This method enables human-in-the-loop
+         workflows by pausing execution and waiting for user interaction.
+
+         Args:
+             question: Question to ask the user
+             input_type: Type of input - "text", "approval", or "choice"
+             options: For approval/choice, list of option dicts with 'id' and 'label'
+
+         Returns:
+             User's response string
+
+         Raises:
+             WaitingForUserInputException: When no cached response exists (first call)
+
+         Example (text input):
+             ```python
+             city = await ctx.wait_for_user("Which city?")
+             ```
+
+         Example (approval):
+             ```python
+             decision = await ctx.wait_for_user(
+                 "Approve this action?",
+                 input_type="approval",
+                 options=[
+                     {"id": "approve", "label": "Approve"},
+                     {"id": "reject", "label": "Reject"}
+                 ]
+             )
+             ```
+
+         Example (choice):
+             ```python
+             model = await ctx.wait_for_user(
+                 "Which model?",
+                 input_type="choice",
+                 options=[
+                     {"id": "gpt4", "label": "GPT-4"},
+                     {"id": "claude", "label": "Claude"}
+                 ]
+             )
+             ```
+         """
+         from .exceptions import WaitingForUserInputException
+
+         # Generate unique step name for this user input request
+         # Each wait_for_user call gets a unique key based on pause_index
+         # This allows multi-step HITL workflows where each pause gets its own response
+         pause_index = self._workflow_entity._pause_index
+         response_key = f"user_response:{self.run_id}:{pause_index}"
+
+         # Increment pause index for next call (whether we replay or pause)
+         self._workflow_entity._pause_index += 1
+
+         # Check if we already have the user's response (replay scenario)
+         if self._workflow_entity.has_completed_step(response_key):
+             response = self._workflow_entity.get_completed_step(response_key)
+             self._logger.info(f"🔄 Replaying user response from checkpoint (pause {pause_index})")
+             return response
+
+         # No response yet - pause execution
+         # Collect current workflow state for checkpoint
+         checkpoint_state = {}
+         if hasattr(self._workflow_entity, "_state") and self._workflow_entity._state is not None:
+             checkpoint_state = self._workflow_entity._state.get_state_snapshot()
+
+         self._logger.info(f"⏸️ Pausing workflow for user input: {question}")
+
+         raise WaitingForUserInputException(
+             question=question,
+             input_type=input_type,
+             options=options,
+             checkpoint_state=checkpoint_state,
+             pause_index=pause_index,  # Pass the pause index for multi-step HITL
+         )
+
+
+ # ============================================================================
+ # Helper functions for workflow execution
+ # ============================================================================
+
+
+ def _sanitize_for_json(obj: Any) -> Any:
+     """
+     Sanitize data for JSON serialization by removing or converting non-serializable objects.
+
+     Specifically handles:
+     - WorkflowContext objects (replaced with placeholder)
+     - Nested structures (recursively sanitized)
+
+     Args:
+         obj: Object to sanitize
+
+     Returns:
+         JSON-serializable version of the object
+     """
+     # Handle None, primitives
+     if obj is None or isinstance(obj, (str, int, float, bool)):
+         return obj
+
+     # Handle WorkflowContext - replace with placeholder
+     if isinstance(obj, WorkflowContext):
+         return "<WorkflowContext>"
+
+     # Handle tuples/lists - recursively sanitize
+     if isinstance(obj, (tuple, list)):
+         sanitized = [_sanitize_for_json(item) for item in obj]
+         return sanitized if isinstance(obj, list) else tuple(sanitized)
+
+     # Handle dicts - recursively sanitize values
+     if isinstance(obj, dict):
+         return {k: _sanitize_for_json(v) for k, v in obj.items()}
+
+     # For other objects, try to serialize or convert to string
+     try:
+         import json
+         json.dumps(obj)
+         return obj
+     except (TypeError, ValueError):
+         # Not JSON serializable, use string representation
+         return repr(obj)
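+
+ # Illustrative sketch of the mapping above (hypothetical values, where ctx is a WorkflowContext):
+ #   _sanitize_for_json({"ctx": ctx, "n": 1})  -> {"ctx": "<WorkflowContext>", "n": 1}
+ #   _sanitize_for_json((ctx, [1, object()])) -> ("<WorkflowContext>", [1, "<object object at 0x...>"])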
+
+
+ # ============================================================================
+ # WorkflowEntity: Entity specialized for workflow execution state
+ # ============================================================================
+
+
+ class WorkflowEntity(Entity):
+     """
+     Entity specialized for workflow execution state.
+
+     Extends Entity with workflow-specific capabilities:
+     - Step tracking for replay and crash recovery
+     - State change tracking for debugging and audit (AI workflows)
+     - Completed step cache for efficient replay
+     - Automatic state persistence after workflow execution
+
+     Workflow state is persisted to the database after successful execution,
+     enabling crash recovery, replay, and cross-invocation state management.
+     The workflow decorator automatically calls _persist_state() to ensure
+     durability.
+     """
+
+     def __init__(
+         self,
+         run_id: str,
+         session_id: Optional[str] = None,
+         user_id: Optional[str] = None,
+         component_name: Optional[str] = None,
+     ):
+         """
+         Initialize workflow entity with memory scope.
+
+         Args:
+             run_id: Unique workflow run identifier
+             session_id: Session identifier for multi-turn conversations (optional)
+             user_id: User identifier for user-scoped memory (optional)
+             component_name: Workflow component name for session-scoped entities (optional)
+
+         Memory Scope Priority:
+         - user_id present → key: user:{user_id} (shared across workflows)
+         - session_id present (and != run_id) → key: workflow:{component_name}:session:{session_id}
+         - else → key: run:{run_id}
+
+         Note: For session scope, component_name enables listing sessions by workflow name.
+         User scope is shared across all workflows (not per-workflow).
+         """
+         # Determine entity key based on memory scope priority
+         if user_id:
+             # User scope: shared across all workflows (not per-workflow)
+             entity_key = f"user:{user_id}"
+             memory_scope = "user"
+         elif session_id and session_id != run_id:
+             # Session scope: include workflow name for queryability
+             if component_name:
+                 entity_key = f"workflow:{component_name}:session:{session_id}"
+             else:
+                 # Fallback for backward compatibility
+                 entity_key = f"session:{session_id}"
+             memory_scope = "session"
+         else:
+             entity_key = f"run:{run_id}"
+             memory_scope = "run"
+
+         # Initialize as entity with scoped key pattern
+         super().__init__(key=entity_key)
+
+         # Store run_id separately for tracking (even if key is session/user scoped)
+         self._run_id = run_id
+         self._memory_scope = memory_scope
+         self._component_name = component_name
+         # Store scope identifiers for proper scope-based persistence
+         self._session_id = session_id
+         self._user_id = user_id
+
+         # Step tracking for replay and recovery
+         self._step_events: list[Dict[str, Any]] = []
+         self._completed_steps: Dict[str, Any] = {}
+
+         # HITL pause tracking - each wait_for_user gets unique index
+         self._pause_index: int = 0
+
+         # State change tracking for debugging/audit (AI workflows)
+         self._state_changes: list[Dict[str, Any]] = []
+
+         logger.debug(f"Created WorkflowEntity: run={run_id}, scope={memory_scope}, key={entity_key}, component={component_name}")
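+
+         # Key-derivation sketch (hypothetical IDs, follows the priority above):
+         #   WorkflowEntity("r1")                                         -> key "run:r1"
+         #   WorkflowEntity("r1", session_id="s1", component_name="chat") -> key "workflow:chat:session:s1"
+         #   WorkflowEntity("r1", user_id="u1")                           -> key "user:u1"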
+
+     @property
+     def run_id(self) -> str:
+         """Get run_id for this workflow execution."""
+         return self._run_id
+
+     def record_step_completion(
+         self, step_name: str, handler_name: str, input_data: Any, result: Any
+     ) -> None:
+         """
+         Record completed step for replay and recovery.
+
+         Args:
+             step_name: Unique step identifier
+             handler_name: Function handler name
+             input_data: Input data passed to function
+             result: Function result
+         """
+         # Sanitize input_data and result to ensure JSON serializability
+         # This removes WorkflowContext objects and other non-serializable types
+         sanitized_input = _sanitize_for_json(input_data)
+         sanitized_result = _sanitize_for_json(result)
+
+         self._step_events.append(
+             {
+                 "step_name": step_name,
+                 "handler_name": handler_name,
+                 "input": sanitized_input,
+                 "result": sanitized_result,
+             }
+         )
+         self._completed_steps[step_name] = result
+         logger.debug(f"Recorded step completion: {step_name}")
+
+     def get_completed_step(self, step_name: str) -> Optional[Any]:
+         """
+         Get result of completed step (for replay).
+
+         Args:
+             step_name: Step identifier
+
+         Returns:
+             Step result if completed, None otherwise
+         """
+         return self._completed_steps.get(step_name)
+
+     def has_completed_step(self, step_name: str) -> bool:
+         """Check if step has been completed."""
+         return step_name in self._completed_steps
+
+     def inject_user_response(self, response: str) -> None:
+         """
+         Inject user response as a completed step for workflow resume.
+
+         This method is called by the worker when resuming a paused workflow
+         with the user's response. It stores the response as if it were a
+         completed step, allowing wait_for_user() to retrieve it on replay.
+
+         The response is injected at the current pause_index (which should be
+         restored from metadata before calling this method for multi-step HITL).
+         This matches the key format used by wait_for_user().
+
+         Args:
+             response: User's response to inject
+
+         Example:
+             # Platform resumes workflow with user response
+             workflow_entity.inject_user_response("yes")
+             # On replay, wait_for_user() returns "yes" from cache
+         """
+         # Inject at current pause_index (restored from metadata for multi-step HITL)
+         response_key = f"user_response:{self.run_id}:{self._pause_index}"
+         self._completed_steps[response_key] = response
+
+         # Also add to step_events so it gets serialized to metadata on next pause
+         # This ensures previous user responses are preserved across resumes
+         self._step_events.append({
+             "step_name": response_key,
+             "handler_name": "user_response",
+             "input": None,
+             "result": response,
+         })
+
+         logger.info(f"Injected user response for {self.run_id} at pause {self._pause_index}: {response}")
+
+     def get_agent_data(self, agent_name: str) -> Dict[str, Any]:
+         """
+         Get agent conversation data from workflow state.
+
+         Args:
+             agent_name: Name of the agent
+
+         Returns:
+             Dictionary containing agent conversation data (messages, metadata)
+             or empty dict if agent has no data yet
+
+         Example:
+             ```python
+             agent_data = workflow_entity.get_agent_data("ResearchAgent")
+             messages = agent_data.get("messages", [])
+             ```
+         """
+         return self.state.get(f"agent.{agent_name}", {})
+
+     def get_agent_messages(self, agent_name: str) -> list[Dict[str, Any]]:
+         """
+         Get agent messages from workflow state.
+
+         Args:
+             agent_name: Name of the agent
+
+         Returns:
+             List of message dictionaries
+
+         Example:
+             ```python
+             messages = workflow_entity.get_agent_messages("ResearchAgent")
+             for msg in messages:
+                 print(f"{msg['role']}: {msg['content']}")
+             ```
+         """
+         agent_data = self.get_agent_data(agent_name)
+         return agent_data.get("messages", [])
+
+     def list_agents(self) -> list[str]:
+         """
+         List all agents with data in this workflow.
+
+         Returns:
+             List of agent names that have stored conversation data
+
+         Example:
+             ```python
+             agents = workflow_entity.list_agents()
+             # ['ResearchAgent', 'AnalysisAgent', 'SynthesisAgent']
+             ```
+         """
+         agents = []
+         for key in self.state._state.keys():
+             if key.startswith("agent."):
+                 agents.append(key.replace("agent.", "", 1))
+         return agents
+
+     async def _persist_state(self) -> None:
+         """
+         Internal method to persist workflow state to entity storage.
+
+         This is prefixed with _ so it won't be wrapped by the entity method wrapper.
+         Called after workflow execution completes to ensure state is durable.
+         """
+         logger.info(f"🔍 DEBUG: _persist_state() CALLED for workflow {self.run_id}")
+
+         try:
+             from .entity import _get_state_adapter
+
+             logger.info("🔍 DEBUG: Getting state adapter...")
+             # Get the state adapter (must be in Worker context)
+             adapter = _get_state_adapter()
+             logger.info(f"🔍 DEBUG: Got state adapter: {type(adapter).__name__}")
+
+             logger.info("🔍 DEBUG: Getting state snapshot...")
+             # Get current state snapshot
+             state_dict = self.state.get_state_snapshot()
+             logger.info(f"🔍 DEBUG: State snapshot has {len(state_dict)} keys: {list(state_dict.keys())}")
+
+             # Determine scope and scope_id based on memory scope
+             scope = self._memory_scope  # "session", "user", or "run"
+             scope_id = ""
+             if self._memory_scope == "session" and self._session_id:
+                 scope_id = self._session_id
+             elif self._memory_scope == "user" and self._user_id:
+                 scope_id = self._user_id
+             elif self._memory_scope == "run":
+                 scope_id = self._run_id
+
+             logger.info(f"🔍 DEBUG: Loading current version for optimistic locking (scope={scope}, scope_id={scope_id})...")
+             # Load current version (for optimistic locking) with proper scope
+             _, current_version = await adapter.load_with_version(
+                 self._entity_type, self._key, scope=scope, scope_id=scope_id
+             )
+             logger.info(f"🔍 DEBUG: Current version: {current_version}")
+
+             logger.info("🔍 DEBUG: Saving state to database...")
+
+             logger.info(f"🔍 DEBUG: Using scope={scope}, scope_id={scope_id}")
+             # Save state with version check and proper scope
+             new_version = await adapter.save_state(
+                 self._entity_type, self._key, state_dict, current_version,
+                 scope=scope, scope_id=scope_id
+             )
+
+             logger.info(
+                 f"✅ SUCCESS: Persisted WorkflowEntity state for {self.run_id} "
+                 f"(version {current_version} -> {new_version}, {len(state_dict)} keys)"
+             )
+         except Exception as e:
+             logger.error(
+                 f"❌ ERROR: Failed to persist workflow state for {self.run_id}: {e}",
+                 exc_info=True
+             )
+             # Re-raise to let caller handle
+             raise
+
+     @property
+     def state(self) -> "WorkflowState":
+         """
+         Get workflow state with change tracking.
+
+         Returns WorkflowState which tracks all state mutations
+         for debugging and replay of AI workflows.
+         """
+         if self._state is None:
+             # Initialize with empty state dict - will be populated by entity system
+             self._state = WorkflowState({}, self)
+         return self._state
+
+
+ class WorkflowState(EntityState):
+     """
+     State interface for WorkflowEntity with change tracking.
+
+     Extends EntityState to track all state mutations for:
+     - AI workflow debugging
+     - Audit trail
+     - Replay capabilities
+     """
+
+     def __init__(self, state_dict: Dict[str, Any], workflow_entity: WorkflowEntity):
+         """
+         Initialize workflow state.
+
+         Args:
+             state_dict: Dictionary to use for state storage
+             workflow_entity: Parent workflow entity for tracking
+         """
+         super().__init__(state_dict)
+         self._workflow_entity = workflow_entity
+         self._checkpoint_callback: Optional[Callable[[str, dict], None]] = None
+
+     def _set_checkpoint_callback(self, callback: Callable[[str, dict], None]) -> None:
+         """
+         Set the checkpoint callback for real-time state change streaming.
+
+         Args:
+             callback: Function to call when state changes
+         """
+         self._checkpoint_callback = callback
+
+     def set(self, key: str, value: Any) -> None:
+         """Set value and track change."""
+         super().set(key, value)
+         # Track change for debugging/audit
+         import time
+
+         change_record = {"key": key, "value": value, "timestamp": time.time(), "deleted": False}
+         self._workflow_entity._state_changes.append(change_record)
+
+         # Emit checkpoint for real-time state streaming
+         if self._checkpoint_callback:
+             self._checkpoint_callback(
+                 "workflow.state.changed", {"key": key, "value": value, "operation": "set"}
+             )
+
+     def delete(self, key: str) -> None:
+         """Delete key and track change."""
+         super().delete(key)
+         # Track deletion
+         import time
+
+         change_record = {"key": key, "value": None, "timestamp": time.time(), "deleted": True}
+         self._workflow_entity._state_changes.append(change_record)
+
+         # Emit checkpoint for real-time state streaming
+         if self._checkpoint_callback:
+             self._checkpoint_callback("workflow.state.changed", {"key": key, "operation": "delete"})
+
+     def clear(self) -> None:
+         """Clear all state and track change."""
+         super().clear()
+         # Track clear operation
+         import time
+
+         change_record = {
+             "key": "__clear__",
+             "value": None,
+             "timestamp": time.time(),
+             "deleted": True,
+         }
+         self._workflow_entity._state_changes.append(change_record)
+
+         # Emit checkpoint for real-time state streaming
+         if self._checkpoint_callback:
+             self._checkpoint_callback("workflow.state.changed", {"operation": "clear"})
+
+     def has_changes(self) -> bool:
+         """Check if any state changes have been tracked."""
+         return len(self._workflow_entity._state_changes) > 0
+
+     def get_state_snapshot(self) -> Dict[str, Any]:
+         """Get current state as a snapshot dictionary."""
+         return dict(self._state)
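+
+     # Mutation-tracking sketch: state.set("status", "done") appends
+     #   {"key": "status", "value": "done", "timestamp": <time.time()>, "deleted": False}
+     # to the parent entity's _state_changes and, when a checkpoint callback is set,
+     # emits ("workflow.state.changed", {"key": "status", "value": "done", "operation": "set"}).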
+
+
+ class WorkflowRegistry:
+     """Registry for workflow handlers."""
+
+     @staticmethod
+     def register(config: WorkflowConfig) -> None:
+         """
+         Register a workflow handler.
+
+         Raises:
+             ValueError: If a workflow with this name is already registered
+         """
+         if config.name in _WORKFLOW_REGISTRY:
+             existing_workflow = _WORKFLOW_REGISTRY[config.name]
+             logger.error(
+                 f"Workflow name collision detected: '{config.name}'\n"
+                 f"  First defined in: {existing_workflow.handler.__module__}\n"
+                 f"  Also defined in: {config.handler.__module__}\n"
+                 f"  This is a bug - workflows must have unique names."
+             )
+             raise ValueError(
+                 f"Workflow '{config.name}' is already registered. "
+                 f"Use @workflow(name='unique_name') to specify a different name."
+             )
+
+         _WORKFLOW_REGISTRY[config.name] = config
+         logger.debug(f"Registered workflow '{config.name}'")
+
+     @staticmethod
+     def get(name: str) -> Optional[WorkflowConfig]:
+         """Get workflow configuration by name."""
+         return _WORKFLOW_REGISTRY.get(name)
+
+     @staticmethod
+     def all() -> Dict[str, WorkflowConfig]:
+         """Get all registered workflows."""
+         return _WORKFLOW_REGISTRY.copy()
+
+     @staticmethod
+     def list_names() -> list[str]:
+         """List all registered workflow names."""
+         return list(_WORKFLOW_REGISTRY.keys())
+
+     @staticmethod
+     def clear() -> None:
+         """Clear all registered workflows."""
+         _WORKFLOW_REGISTRY.clear()
+
+
+ def workflow(
+     _func: Optional[Callable[..., Any]] = None,
+     *,
+     name: Optional[str] = None,
+     chat: bool = False,
+     cron: Optional[str] = None,
+     webhook: bool = False,
+     webhook_secret: Optional[str] = None,
+ ) -> Callable[..., Any]:
+     """
+     Decorator to mark a function as an AGNT5 durable workflow.
+
+     Workflows use WorkflowEntity for state management and WorkflowContext
+     for orchestration. State changes are automatically tracked for replay.
+
+     Args:
+         name: Custom workflow name (default: function's __name__)
+         chat: Enable chat mode for multi-turn conversation workflows (default: False)
+         cron: Cron expression for scheduled execution (e.g., "0 9 * * *" for daily at 9am)
+         webhook: Enable webhook triggering for this workflow (default: False)
+         webhook_secret: Optional secret for HMAC-SHA256 signature verification
+
+     Example (standard workflow):
+         @workflow
+         async def process_order(ctx: WorkflowContext, order_id: str) -> dict:
+             # Durable state - survives crashes
+             ctx.state.set("status", "processing")
+             ctx.state.set("order_id", order_id)
+
+             # Validate order
+             order = await ctx.step(validate_order, input={"order_id": order_id})
+
+             # Process payment (checkpointed - won't re-execute on crash)
+             payment = await ctx.step("payment", process_payment(order["total"]))
+
+             # Fulfill order
+             await ctx.step(ship_order, input={"order_id": order_id})
+
+             ctx.state.set("status", "completed")
+             return {"status": ctx.state.get("status")}
+
+     Example (chat workflow):
+         @workflow(chat=True)
+         async def customer_support(ctx: WorkflowContext, message: str) -> dict:
+             # Initialize conversation state
+             if not ctx.state.get("messages"):
+                 ctx.state.set("messages", [])
+
+             # Add user message
+             messages = ctx.state.get("messages")
+             messages.append({"role": "user", "content": message})
+             ctx.state.set("messages", messages)
+
+             # Generate AI response
+             response = await ctx.step(generate_response, messages=messages)
+
+             # Add assistant response
+             messages.append({"role": "assistant", "content": response})
+             ctx.state.set("messages", messages)
+
+             return {"response": response, "turn_count": len(messages) // 2}
+
+     Example (scheduled workflow):
+         @workflow(name="daily_report", cron="0 9 * * *")
+         async def daily_report(ctx: WorkflowContext) -> dict:
+             # Runs automatically every day at 9am
+             sales = await ctx.step(get_sales_data, report_type="sales")
+             report = await ctx.step(generate_pdf, input=sales)
+             await ctx.step(send_email, to="team@company.com", attachment=report)
+             return {"status": "sent", "report_id": report["id"]}
+
+     Example (webhook workflow):
+         @workflow(name="on_payment", webhook=True, webhook_secret="your_secret_key")
+         async def on_payment(ctx: WorkflowContext, webhook_data: dict) -> dict:
+             # Triggered by webhook POST /v1/webhooks/on_payment
+             # webhook_data contains: payload, headers, source_ip, timestamp
+             payment = webhook_data["payload"]
+
+             if payment.get("status") == "succeeded":
+                 await ctx.step(fulfill_order, order_id=payment["order_id"])
+                 await ctx.step(send_receipt, customer_email=payment["email"])
+                 return {"status": "processed", "order_id": payment["order_id"]}
+
+             return {"status": "skipped", "reason": "payment not successful"}
+     """
+
+     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+         # Get workflow name
+         workflow_name = name or func.__name__
+
+         # Validate function signature
+         sig = inspect.signature(func)
+         params = list(sig.parameters.values())
+
+         if not params or params[0].name != "ctx":
+             raise ValueError(
+                 f"Workflow '{workflow_name}' must have 'ctx: WorkflowContext' as first parameter"
+             )
+
+         # Convert sync to async if needed
+         if inspect.iscoroutinefunction(func):
+             handler_func = cast(HandlerFunc, func)
+         else:
+             # Wrap sync function in async
+             @functools.wraps(func)
+             async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
+                 return func(*args, **kwargs)
+
+             handler_func = cast(HandlerFunc, async_wrapper)
+
+         # Extract schemas from type hints
+         input_schema, output_schema = extract_function_schemas(func)
+
+         # Extract metadata (description, etc.)
+         metadata = extract_function_metadata(func)
+
+         # Add chat metadata if chat mode is enabled
+         if chat:
+             metadata["chat"] = "true"
+
+         # Add cron metadata if cron schedule is provided
+         if cron:
+             metadata["cron"] = cron
+
+         # Add webhook metadata if webhook is enabled
+         if webhook:
+             metadata["webhook"] = "true"
+             if webhook_secret:
+                 metadata["webhook_secret"] = webhook_secret
+
+         # Register workflow
+         config = WorkflowConfig(
+             name=workflow_name,
+             handler=handler_func,
+             input_schema=input_schema,
+             output_schema=output_schema,
+             metadata=metadata,
+         )
+         WorkflowRegistry.register(config)
+
+         # Create wrapper that provides context
+         @functools.wraps(func)
+         async def wrapper(*args: Any, **kwargs: Any) -> Any:
+             # Create WorkflowEntity and WorkflowContext if not provided
+             if not args or not isinstance(args[0], WorkflowContext):
+                 # Auto-create workflow entity and context for direct workflow calls
+                 run_id = f"workflow-{uuid.uuid4().hex[:8]}"
+
+                 # Create WorkflowEntity to manage state
+                 workflow_entity = WorkflowEntity(run_id=run_id)
+
+                 # Create WorkflowContext that wraps the entity
+                 ctx = WorkflowContext(
+                     workflow_entity=workflow_entity,
+                     run_id=run_id,
+                 )
+
+                 # Set context in task-local storage for automatic propagation
+                 token = set_current_context(ctx)
+                 try:
+                     # Execute workflow
+                     result = await handler_func(ctx, *args, **kwargs)
+
+                     # Persist workflow state after successful execution
+                     try:
+                         await workflow_entity._persist_state()
+                     except Exception as e:
+                         logger.error(f"Failed to persist workflow state (non-fatal): {e}", exc_info=True)
+                         # Don't fail the workflow - persistence failure shouldn't break execution
+
+                     return result
+                 finally:
+                     # Always reset context to prevent leakage
+                     from .context import _current_context
+
+                     _current_context.reset(token)
+             else:
+                 # WorkflowContext provided - use it and set in contextvar
+                 ctx = args[0]
+                 token = set_current_context(ctx)
+                 try:
+                     result = await handler_func(*args, **kwargs)
+
+                     # Persist workflow state after successful execution
+                     try:
+                         await ctx._workflow_entity._persist_state()
+                     except Exception as e:
+                         logger.error(f"Failed to persist workflow state (non-fatal): {e}", exc_info=True)
+                         # Don't fail the workflow - persistence failure shouldn't break execution
+
+                     return result
+                 finally:
+                     # Always reset context to prevent leakage
+                     from .context import _current_context
+
+                     _current_context.reset(token)
+
+         # Store config on wrapper for introspection
+         wrapper._agnt5_config = config  # type: ignore
+         return wrapper
+
+     # Handle both @workflow and @workflow(...) syntax
+     if _func is None:
+         return decorator
+     else:
+         return decorator(_func)
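
Taken together, `@function`-decorated handlers become checkpointable steps, while `@workflow` wires up a WorkflowEntity plus WorkflowContext around the handler. A minimal usage sketch based on the docstrings above; the `agnt5.function` import path and the `function` decorator are assumptions taken from this file's imports, not verified against the rest of the package:

    import asyncio

    from agnt5.function import function  # assumed: registers the handler in FunctionRegistry
    from agnt5.workflow import WorkflowContext, workflow

    @function
    async def double(ctx, data: list) -> list:
        # Runs as a durable step; the result is cached for replay.
        return [x * 2 for x in data]

    @workflow(name="demo")
    async def demo(ctx: WorkflowContext, data: list) -> dict:
        ctx.state.set("status", "processing")                 # tracked, durable state
        doubled = await ctx.step(double, data)                # pattern 1: @function step
        total = await ctx.step("sum", lambda: sum(doubled))   # pattern 3: named callable checkpoint
        ctx.state.set("status", "done")
        return {"doubled": doubled, "total": total}

    # Calling the wrapped workflow directly auto-creates a run_id, entity, and context.
    result = asyncio.run(demo([1, 2, 3]))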