polos-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. polos/__init__.py +105 -0
  2. polos/agents/__init__.py +7 -0
  3. polos/agents/agent.py +746 -0
  4. polos/agents/conversation_history.py +121 -0
  5. polos/agents/stop_conditions.py +280 -0
  6. polos/agents/stream.py +635 -0
  7. polos/core/__init__.py +0 -0
  8. polos/core/context.py +143 -0
  9. polos/core/state.py +26 -0
  10. polos/core/step.py +1380 -0
  11. polos/core/workflow.py +1192 -0
  12. polos/features/__init__.py +0 -0
  13. polos/features/events.py +456 -0
  14. polos/features/schedules.py +110 -0
  15. polos/features/tracing.py +605 -0
  16. polos/features/wait.py +82 -0
  17. polos/llm/__init__.py +9 -0
  18. polos/llm/generate.py +152 -0
  19. polos/llm/providers/__init__.py +5 -0
  20. polos/llm/providers/anthropic.py +615 -0
  21. polos/llm/providers/azure.py +42 -0
  22. polos/llm/providers/base.py +196 -0
  23. polos/llm/providers/fireworks.py +41 -0
  24. polos/llm/providers/gemini.py +40 -0
  25. polos/llm/providers/groq.py +40 -0
  26. polos/llm/providers/openai.py +1021 -0
  27. polos/llm/providers/together.py +40 -0
  28. polos/llm/stream.py +183 -0
  29. polos/middleware/__init__.py +0 -0
  30. polos/middleware/guardrail.py +148 -0
  31. polos/middleware/guardrail_executor.py +253 -0
  32. polos/middleware/hook.py +164 -0
  33. polos/middleware/hook_executor.py +104 -0
  34. polos/runtime/__init__.py +0 -0
  35. polos/runtime/batch.py +87 -0
  36. polos/runtime/client.py +841 -0
  37. polos/runtime/queue.py +42 -0
  38. polos/runtime/worker.py +1365 -0
  39. polos/runtime/worker_server.py +249 -0
  40. polos/tools/__init__.py +0 -0
  41. polos/tools/tool.py +587 -0
  42. polos/types/__init__.py +23 -0
  43. polos/types/types.py +116 -0
  44. polos/utils/__init__.py +27 -0
  45. polos/utils/agent.py +27 -0
  46. polos/utils/client_context.py +41 -0
  47. polos/utils/config.py +12 -0
  48. polos/utils/output_schema.py +311 -0
  49. polos/utils/retry.py +47 -0
  50. polos/utils/serializer.py +167 -0
  51. polos/utils/tracing.py +27 -0
  52. polos/utils/worker_singleton.py +40 -0
  53. polos_sdk-0.1.0.dist-info/METADATA +650 -0
  54. polos_sdk-0.1.0.dist-info/RECORD +55 -0
  55. polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/core/step.py ADDED
@@ -0,0 +1,1380 @@
1
+ """Step execution helper for durable execution within workflows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextvars
7
+ import json
8
+ import logging
9
+ import os
10
+ import random as random_module
11
+ import time
12
+ import uuid as uuid_module
13
+ from collections.abc import Callable
14
+ from contextlib import contextmanager
15
+ from datetime import datetime, timedelta, timezone
16
+ from typing import Any
17
+
18
+ from opentelemetry.trace import Status, StatusCode
19
+ from pydantic import BaseModel
20
+
21
+ from ..features.events import EventData, EventPayload, batch_publish
22
+ from ..features.tracing import extract_traceparent, get_current_span, get_tracer
23
+ from ..features.wait import WaitException, _get_wait_time, _set_waiting
24
+ from ..runtime.client import ExecutionHandle, get_step_output, store_step_output
25
+ from ..types.types import BatchStepResult, BatchWorkflowInput
26
+ from ..utils.client_context import get_client_or_raise
27
+ from ..utils.retry import retry_with_backoff
28
+ from ..utils.serializer import deserialize, safe_serialize, serialize
29
+ from ..utils.tracing import (
30
+ get_parent_span_context_from_execution_context,
31
+ get_span_context_from_execution_context,
32
+ set_span_context_in_execution_context,
33
+ )
34
+ from .context import WorkflowContext
35
+ from .workflow import (
36
+ StepExecutionError,
37
+ Workflow,
38
+ _execution_context,
39
+ get_workflow,
40
+ )
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ class Step:
46
+ """Step execution helper - provides durable execution primitives.
47
+
48
+ Steps are executed within a workflow context and their outputs are
49
+ saved to avoid re-execution on workflow resume/replay.
50
+ """
51
+
52
+ def __init__(self, ctx: WorkflowContext):
53
+ """Initialize Step with a WorkflowContext.
54
+
55
+ Args:
56
+ ctx: The workflow execution context
57
+ """
58
+ self.ctx = ctx
59
+
60
+ async def _check_existing_step(self, step_key: str) -> dict[str, Any] | None:
61
+ """
62
+ Check for existing step output using step_key.
63
+
64
+ Args:
65
+ step_key: Step key identifier (must be unique per execution)
67
+
68
+ Returns:
69
+ existing_step_output or None
70
+ - If existing_step_output is not None, it means the step was already executed
71
+ - The caller should handle returning cached results or raising errors
72
+ """
73
+ return await get_step_output(self.ctx.execution_id, step_key)
74
+
75
+ async def _handle_existing_step(
76
+ self,
77
+ existing_step: dict[str, Any],
78
+ ) -> Any:
79
+ """
80
+ Handle existing step output - either return cached result or raise error.
81
+
82
+ Args:
83
+ existing_step: The existing step output from _check_existing_step
85
+
86
+ Returns:
87
+            Cached result deserialized from the stored step outputs (or None)
88
+
89
+ Raises:
90
+ StepExecutionError: If step previously failed
91
+ """
92
+ if existing_step.get("success", False):
93
+ outputs = existing_step.get("outputs")
94
+ if outputs:
95
+ result = await deserialize(outputs, existing_step.get("output_schema_name"))
96
+ return result
97
+ else:
98
+ return None
99
+ else:
100
+ error = existing_step.get("error", {})
101
+ error_message = (
102
+ error.get("message", "Step execution failed")
103
+ if isinstance(error, dict)
104
+ else str(error)
105
+ )
106
+ raise StepExecutionError(error_message)
107
+
108
+ async def _save_step_output(
109
+ self,
110
+ step_key: str,
111
+ result: Any,
112
+ source_execution_id: str | None = None,
113
+ ) -> None:
114
+ """Save step output to database using step_key as the unique identifier.
115
+
116
+ Args:
117
+ step_key: Step key identifier (must be unique per execution)
118
+ result: Result to save
119
+ source_execution_id: Optional source execution ID
120
+
121
+ If result is a Pydantic BaseModel, converts it to dict using model_dump(mode="json")
122
+ to ensure only valid Pydantic models are stored. model_dump(mode="json") automatically
123
+ handles nested Pydantic models within the model.
124
+
125
+ Also extracts and stores the full module path of Pydantic classes for
126
+ automatic deserialization when reading from the database.
127
+
128
+ If result is not a Pydantic model, validates that it's JSON serializable
129
+ by attempting json.dumps(). Raises StepExecutionError if not serializable.
130
+ """
131
+ output_schema_name = None
132
+ if isinstance(result, BaseModel):
133
+ outputs = result.model_dump(mode="json")
134
+ # Extract full module path for Pydantic class
135
+ # (e.g., "polos.llm.providers.base.LLMResponse")
136
+ output_schema_name = f"{result.__class__.__module__}.{result.__class__.__name__}"
137
+ else:
138
+ outputs = result
139
+
140
+ await store_step_output(
141
+ execution_id=self.ctx.execution_id,
142
+ step_key=step_key,
143
+ outputs=outputs,
144
+ error=None,
145
+ success=True,
146
+ source_execution_id=source_execution_id,
147
+ output_schema_name=output_schema_name,
148
+ )
149
+
150
+ async def _save_step_output_with_error(
151
+ self,
152
+ step_key: str,
153
+ error: str,
154
+ ) -> None:
155
+ """Save step output with error to database."""
156
+ await store_step_output(
157
+ execution_id=self.ctx.execution_id,
158
+ step_key=step_key,
159
+ outputs=None,
160
+ error={"message": error},
161
+ success=False,
162
+ source_execution_id=None,
163
+ )
164
+
165
+ async def _raise_step_execution_error(self, step_key: str, error: str) -> None:
166
+ """Raise a step execution error."""
167
+ await self._save_step_output_with_error(
168
+ step_key,
169
+ error,
170
+ )
171
+ raise StepExecutionError(error)
172
+
173
+ async def _publish_step_event(
174
+ self,
175
+ event_type: str,
176
+ step_key: str,
177
+ step_type: str,
178
+ input_params: dict[str, Any],
179
+ ) -> None:
180
+ """Publish a step event for the current workflow (fire-and-forget).
181
+
182
+ Args:
183
+ event_type: Type of event (e.g., "step_start", "step_finish")
184
+ step_key: Step key/identifier
185
+ step_type: Type of step (e.g., "run", "wait_for", "invoke", etc.)
186
+ input_params: Input parameters for the step
187
+ """
188
+ events = [
189
+ EventData(
190
+ data={
191
+ "step_key": step_key,
192
+ "step_type": step_type,
193
+ "data": safe_serialize(input_params) if input_params else {},
194
+ "_metadata": {
195
+ "execution_id": self.ctx.execution_id,
196
+ "workflow_id": self.ctx.workflow_id,
197
+ },
198
+ },
199
+ event_type=event_type,
200
+ )
201
+ ]
202
+ # Fire-and-forget: spawn task without awaiting to reduce latency
203
+ client = get_client_or_raise()
204
+ asyncio.create_task(
205
+ batch_publish(
206
+ client=client,
207
+ topic=f"workflow:{self.ctx.root_execution_id or self.ctx.execution_id}",
208
+ events=events,
209
+ execution_id=self.ctx.execution_id,
210
+ root_execution_id=self.ctx.root_execution_id,
211
+ )
212
+ )
213
+
214
+ async def run(
215
+ self,
216
+ step_key: str,
217
+ func: Callable,
218
+ *args,
219
+ max_retries: int = 2,
220
+ base_delay: float = 1.0,
221
+ max_delay: float = 10.0,
222
+ **kwargs,
223
+ ) -> Any:
224
+ """
225
+ Execute a callable as a durable step with retry support.
226
+
227
+ Checks step_outputs for existing result. If found, returns cached result.
228
+ Otherwise, executes function with retries, saves output, and returns result.
229
+
230
+ Args:
231
+ step_key: Step key identifier (must be unique per execution)
232
+ func: Callable to execute (sync or async)
233
+ *args: Positional arguments to pass to function
234
+ max_retries: Maximum number of retries on failure (default: 2)
235
+ base_delay: Base delay in seconds for exponential backoff (default: 1.0)
236
+ max_delay: Maximum delay in seconds (default: 10.0)
237
+ **kwargs: Keyword arguments to pass to function
238
+
239
+ Returns:
240
+ Result of function execution
241
+
242
+ Raises:
243
+ StepExecutionError: If function fails after all retries
244
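+        Example (illustrative sketch; fetch_user is a hypothetical async helper):
+            async def my_workflow(ctx):
+                # On replay/resume, the cached output is returned instead of re-running
+                user = await ctx.step.run("fetch-user", fetch_user, "user-123")
+                return user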
+ """
245
+ # Check for existing step output
246
+ existing_step = await self._check_existing_step(step_key)
247
+ if existing_step:
248
+ return await self._handle_existing_step(existing_step)
249
+
250
+ # Get parent span context from execution context
251
+ exec_context = _execution_context.get()
252
+ parent_context = get_parent_span_context_from_execution_context(exec_context)
253
+ tracer = get_tracer()
254
+
255
+ # Create span for step execution using context manager
256
+ with tracer.start_as_current_span(
257
+ name=f"step.{step_key}",
258
+ context=parent_context, # None for root, or parent context for child
259
+ attributes={
260
+ "step.key": step_key,
261
+ "step.function": func.__name__ if hasattr(func, "__name__") else str(func),
262
+ "step.execution_id": self.ctx.execution_id,
263
+ "step.max_retries": max_retries,
264
+ },
265
+ ) as step_span:
266
+ # Update execution context with current span for nested spans
267
+ # Save old values to restore later
268
+ old_span_context = get_span_context_from_execution_context(exec_context)
269
+ set_span_context_in_execution_context(exec_context, step_span.get_span_context())
270
+ try:
271
+ # Use safe_serialize here because function arguments may be complex objects
272
+ # and may not be JSON serializable
273
+ safe_args = [safe_serialize(arg) for arg in args]
274
+ safe_kwargs = {k: safe_serialize(v) for k, v in kwargs.items()}
275
+
276
+ input_params = {
277
+ "func": func.__name__ if hasattr(func, "__name__") else str(func),
278
+ "args": safe_args,
279
+ "kwargs": safe_kwargs,
280
+ "max_retries": max_retries,
281
+ "base_delay": base_delay,
282
+ "max_delay": max_delay,
283
+ }
284
+
285
+ await self._publish_step_event(
286
+ "step_start",
287
+ step_key,
288
+ "run",
289
+ input_params,
290
+ )
291
+
292
+ # Store input in span attributes as JSON string
293
+ step_span.set_attributes(
294
+ {
295
+ "step.input": json.dumps(
296
+ {
297
+ "args": safe_args,
298
+ "kwargs": safe_kwargs,
299
+ }
300
+ ),
301
+ "step.function": func.__name__ if hasattr(func, "__name__") else str(func),
302
+ "step.max_retries": max_retries,
303
+ "step.base_delay": base_delay,
304
+ "step.max_delay": max_delay,
305
+ }
306
+ )
307
+
308
+ # Execute the function with retries
309
+ async def _execute_func() -> Any:
310
+ is_async = asyncio.iscoroutinefunction(func)
311
+ if is_async:
312
+ return await func(*args, **kwargs)
313
+ else:
314
+ # Run sync function in executor
315
+ # IMPORTANT: Capture the current context (including ContextVar values)
316
+ # so they can be restored in the executor thread
317
+ func_ctx = contextvars.copy_context()
318
+ loop = asyncio.get_event_loop()
319
+
320
+ # Execute in executor with context restored
321
+ def run_with_context():
322
+ return func_ctx.run(func, *args, **kwargs)
323
+
324
+ return await loop.run_in_executor(None, run_with_context)
325
+
326
+ try:
327
+ result = await retry_with_backoff(
328
+ _execute_func,
329
+ max_retries=max_retries,
330
+ base_delay=base_delay,
331
+ max_delay=max_delay,
332
+ )
333
+ serialized_result = serialize(result)
334
+
335
+ # Set span status to success
336
+ step_span.set_status(Status(StatusCode.OK))
337
+ step_span.set_attributes(
338
+ {
339
+ "step.status": "completed",
340
+ "step.output": json.dumps(serialized_result),
341
+ }
342
+ )
343
+
344
+ # Publish step_finish event
345
+ await self._publish_step_event(
346
+ "step_finish",
347
+ step_key,
348
+ "run",
349
+ {"result": serialized_result},
350
+ )
351
+
352
+ # Save step output on success
353
+ await self._save_step_output(
354
+ step_key,
355
+ result,
356
+ )
357
+
358
+ # Span automatically ended and stored by DatabaseSpanExporter
359
+ return result
360
+ except Exception as e:
361
+ # Set span status to error
362
+ step_span.set_status(Status(StatusCode.ERROR, str(e)))
363
+ step_span.record_exception(e)
364
+
365
+ # Store error in span attributes as JSON string
366
+ error_message = str(e)
367
+ step_error = {
368
+ "message": error_message,
369
+ "type": type(e).__name__,
370
+ }
371
+ step_span.set_attributes(
372
+ {
373
+ "step.error": json.dumps(safe_serialize(step_error)),
374
+ "step.status": "failed",
375
+ }
376
+ )
377
+
378
+ # Save error to step output
379
+ await self._save_step_output_with_error(
380
+ step_key,
381
+ error_message,
382
+ )
383
+
384
+ # Span automatically ended and stored by DatabaseSpanExporter
385
+ raise StepExecutionError(
386
+ f"Step execution failed after {max_retries} retries: {error_message}"
387
+ ) from e
388
+ finally:
389
+ # Restore previous span context values
390
+ set_span_context_in_execution_context(exec_context, old_span_context)
391
+
392
+ async def wait_for(
393
+ self,
394
+ step_key: str,
395
+ seconds: float | None = None,
396
+ minutes: float | None = None,
397
+ hours: float | None = None,
398
+ days: float | None = None,
399
+ weeks: float | None = None,
400
+ ) -> None:
401
+ """Wait for a time duration.
402
+
403
+ Args:
404
+ step_key: Step key identifier (must be unique per execution)
405
+ seconds: Optional seconds to wait
406
+ minutes: Optional minutes to wait
407
+ hours: Optional hours to wait
408
+ days: Optional days to wait
409
+ weeks: Optional weeks to wait
410
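+        Example (illustrative sketch):
+            async def my_workflow(ctx):
+                # Short waits sleep in-process; longer waits pause the execution
+                # until the orchestrator resumes it
+                await ctx.step.wait_for("cooldown", minutes=30)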
+ """
411
+ # Check for existing step output
412
+ existing_step = await self._check_existing_step(step_key)
413
+ if existing_step:
414
+ return await self._handle_existing_step(existing_step)
415
+
416
+ wait_seconds, wait_until = await _get_wait_time(seconds, minutes, hours, days, weeks)
417
+ if wait_seconds <= 0:
418
+ await self._raise_step_execution_error(step_key, error="Wait duration must be positive")
419
+
420
+ # Add span event for wait
421
+ current_span = get_current_span()
422
+ if current_span and hasattr(current_span, "add_event"):
423
+ # Build attributes dict, filtering out None values (OpenTelemetry doesn't accept None)
424
+ attributes = {
425
+ "step.key": step_key,
426
+ "wait.seconds": wait_seconds,
427
+ }
428
+ if wait_until:
429
+ attributes["wait.until"] = wait_until.isoformat()
430
+ if seconds is not None:
431
+ attributes["wait.seconds_param"] = seconds
432
+ if minutes is not None:
433
+ attributes["wait.minutes_param"] = minutes
434
+ if hours is not None:
435
+ attributes["wait.hours_param"] = hours
436
+ if days is not None:
437
+ attributes["wait.days_param"] = days
438
+ if weeks is not None:
439
+ attributes["wait.weeks_param"] = weeks
440
+
441
+ current_span.add_event("step.wait_for", attributes=attributes)
442
+
443
+ # Get wait threshold from environment (default 10 seconds)
444
+ wait_threshold = float(os.getenv("POLOS_WAIT_THRESHOLD_SECONDS", "10.0"))
445
+
446
+ if wait_seconds <= wait_threshold:
447
+ # Short wait - just sleep without raising WaitException
448
+ await asyncio.sleep(wait_seconds)
449
+ result = {"wait_until": wait_until.isoformat()}
450
+ await self._save_step_output(
451
+ step_key,
452
+ result,
453
+ )
454
+ return
455
+
456
+ # Long wait - pause execution atomically
457
+ await _set_waiting(
458
+ self.ctx.execution_id,
459
+ wait_until,
460
+ "time",
461
+ step_key,
462
+ )
463
+
464
+ # Raise a special exception to pause execution
465
+ # The orchestrator will resume it when wait_until is reached
466
+ raise WaitException(f"Waiting until {wait_until.isoformat()}")
467
+
468
+ async def wait_until(self, step_key: str, timestamp: datetime) -> None:
469
+ """Wait until a timestamp.
470
+
471
+ Args:
472
+ step_key: Step key identifier (must be unique per execution)
473
+ timestamp: Timestamp to wait until
474
+ """
475
+ # Check for existing step output
476
+ existing_step = await self._check_existing_step(step_key)
477
+ if existing_step:
478
+ return await self._handle_existing_step(existing_step)
479
+
480
+ # Ensure date is timezone-aware (use UTC if naive)
481
+ date = timestamp
482
+ if date.tzinfo is None:
483
+ date = date.replace(tzinfo=timezone.utc)
484
+
485
+ # Convert to UTC
486
+ wait_until = date.astimezone(timezone.utc)
487
+ now = datetime.now(timezone.utc)
488
+
489
+ if wait_until < now:
490
+ # Date is in the past, raise error
491
+ await self._raise_step_execution_error(
492
+ step_key, error=f"Wait date {timestamp} is in the past"
493
+ )
494
+
495
+ # Calculate wait duration
496
+ wait_seconds = (wait_until - now).total_seconds()
497
+ if wait_seconds < 0:
498
+ await self._raise_step_execution_error(
499
+ step_key, error=f"Wait date {timestamp} is in the past"
500
+ )
501
+
502
+ # Add span event for wait
503
+ current_span = get_current_span()
504
+ if current_span and hasattr(current_span, "add_event"):
505
+ current_span.add_event(
506
+ "step.wait_until",
507
+ attributes={
508
+ "step.key": step_key,
509
+ "wait.timestamp": timestamp.isoformat(),
510
+ "wait.until": wait_until.isoformat(),
511
+ "wait.seconds": wait_seconds,
512
+ },
513
+ )
514
+
515
+ # Get wait threshold from environment (default 10 seconds)
516
+ wait_threshold = float(os.getenv("POLOS_WAIT_THRESHOLD_SECONDS", "10.0"))
517
+
518
+ if wait_seconds <= wait_threshold:
519
+ # Short wait - just sleep without raising WaitException
520
+ await asyncio.sleep(wait_seconds)
521
+ result = {"wait_until": wait_until.isoformat()}
522
+ await self._save_step_output(
523
+ step_key,
524
+ result,
525
+ )
526
+ return
527
+
528
+ # Long wait - pause execution atomically
529
+ await _set_waiting(
530
+ self.ctx.execution_id,
531
+ wait_until,
532
+ "time",
533
+ step_key,
534
+ )
535
+
536
+ # Raise a special exception to pause execution
537
+ # The orchestrator will resume it when wait_until is reached
538
+ raise WaitException(f"Waiting until {wait_until.isoformat()}")
539
+
540
+ async def wait_for_event(
541
+ self, step_key: str, topic: str, timeout: int | None = None
542
+ ) -> EventPayload:
543
+ """Wait for an event on a topic.
544
+
545
+ Args:
546
+ step_key: Step key identifier (must be unique per execution)
547
+ topic: Event topic to wait for
548
+ timeout: Optional timeout in seconds. If provided, wait will expire after this duration.
549
+
550
+ Returns:
551
+ EventPayload: Event payload with sequence_id, topic, event_type, data, and created_at
552
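+        Example (illustrative sketch; the topic name is arbitrary):
+            async def my_workflow(ctx):
+                event = await ctx.step.wait_for_event(
+                    "await-approval", topic="orders/approved", timeout=3600
+                )
+                return event.data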
+ """
553
+ # Check for existing step output
554
+ existing_step = await self._check_existing_step(step_key)
555
+ if existing_step:
556
+ result = await self._handle_existing_step(existing_step)
557
+ # Convert dict to EventPayload
558
+ if isinstance(result, dict):
559
+ return EventPayload.model_validate(result)
560
+ # If not event data, return as-is (shouldn't happen for wait_for_event)
561
+ return result
562
+
563
+ # Calculate expires_at if timeout is provided
564
+ expires_at = None
565
+ if timeout is not None:
566
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout)
567
+
568
+ # Add span event for wait
569
+ current_span = get_current_span()
570
+ if current_span and hasattr(current_span, "add_event"):
571
+ wait_attributes = {
572
+ "step.key": step_key,
573
+ "wait.topic": topic,
574
+ }
575
+ if timeout is not None:
576
+ wait_attributes["wait.timeout"] = timeout
577
+ if expires_at is not None:
578
+ wait_attributes["wait.expires_at"] = expires_at.isoformat()
579
+ current_span.add_event("step.wait_for_event", attributes=wait_attributes)
580
+
581
+ # Add row in wait_steps with wait_type="event", wait_topic=topic, expires_at=timeout
582
+ await _set_waiting(
583
+ self.ctx.execution_id,
584
+ wait_until=expires_at,
585
+ wait_type="event",
586
+ step_key=step_key,
587
+ wait_topic=topic,
588
+ expires_at=expires_at,
589
+ )
590
+
591
+ # Raise WaitException to pause execution
592
+ # When resumed, execution continues from here and will check execution_step_outputs again
593
+ raise WaitException(f"Waiting for event on topic: {topic}")
594
+
595
+ async def publish_event(
596
+ self,
597
+ step_key: str,
598
+ topic: str,
599
+ data: dict[str, Any],
600
+ event_type: str | None = None,
601
+ ) -> None:
602
+ """Publish an event as a step.
603
+
604
+ Args:
605
+ step_key: Step key identifier (must be unique per execution)
606
+ topic: Event topic
607
+ data: Event data
608
+ event_type: Optional event type
609
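+        Example (illustrative sketch; topic and data are arbitrary):
+            async def my_workflow(ctx):
+                await ctx.step.publish_event(
+                    "notify",
+                    topic="orders/approved",
+                    data={"order_id": 42},
+                    event_type="approved",
+                )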
+ """
610
+ # Check for existing step output
611
+ existing_step = await self._check_existing_step(step_key)
612
+ if existing_step:
613
+ return await self._handle_existing_step(existing_step)
614
+
615
+ events = [EventData(data=data, event_type=event_type)]
616
+ # Publish event
617
+ client = get_client_or_raise()
618
+ await batch_publish(
619
+ client=client,
620
+ topic=topic,
621
+ events=events,
622
+ execution_id=self.ctx.execution_id,
623
+ root_execution_id=self.ctx.root_execution_id,
624
+ )
625
+
626
+ await self._save_step_output(
627
+ step_key,
628
+ None,
629
+ )
630
+
631
+ async def publish_workflow_event(
632
+ self,
633
+ step_key: str,
634
+ data: dict[str, Any],
635
+ event_type: str | None = None,
636
+ ) -> None:
637
+ """Publish an event for the current workflow as a step.
638
+
639
+ Args:
640
+ step_key: Step key identifier (must be unique per execution)
641
+ data: Event data
642
+ event_type: Optional event type
643
+ """
644
+ topic = f"workflow:{self.ctx.root_execution_id or self.ctx.execution_id}"
645
+ return await self.publish_event(step_key, topic, data, event_type)
646
+
647
+ async def suspend(
648
+ self, step_key: str, data: dict[str, Any] | BaseModel, timeout: int | None = None
649
+ ) -> Any:
650
+ """Suspend execution and wait for a resume event.
651
+
652
+ This internally uses wait_for_event and waits for an event on topic
653
+ f"{step_key}/{ctx.root_execution_id}/resume".
654
+
655
+ The data passed to suspend will be included in the suspend event that is
656
+ published internally. When resumed, the event data from the resume event is returned.
657
+
658
+ Args:
659
+ step_key: Step key identifier (must be unique per execution)
660
+ data: Data to associate with the suspend (can be dict or Pydantic BaseModel)
661
+ timeout: Optional timeout in seconds. If provided, wait will expire after this duration.
662
+
663
+ Returns:
664
+ Event data from the resume event
665
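+        Example (illustrative sketch of a human-in-the-loop pause):
+            async def my_workflow(ctx):
+                # Pauses here until a matching resume event arrives (or the timeout expires)
+                answer = await ctx.step.suspend(
+                    "approval", data={"question": "Approve order 42?"}, timeout=86400
+                )
+                return answer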
+ """
666
+ # Check for existing step output
667
+ existing_step = await self._check_existing_step(step_key)
668
+ if existing_step:
669
+ return await self._handle_existing_step(existing_step)
670
+
671
+ serialized_data = serialize(data)
672
+ topic = f"{step_key}/{self.ctx.root_execution_id or self.ctx.execution_id}"
673
+ # Publish suspend event
674
+ client = get_client_or_raise()
675
+ await batch_publish(
676
+ client=client,
677
+ topic=topic,
678
+ events=[EventData(data=serialized_data, event_type="suspend")],
679
+ execution_id=self.ctx.execution_id,
680
+ root_execution_id=self.ctx.root_execution_id,
681
+ )
682
+
683
+ # Calculate expires_at if timeout is provided
684
+ expires_at = None
685
+ if timeout is not None:
686
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout)
687
+
688
+ # Add row in wait_steps with wait_type="event", wait_topic=topic, expires_at=timeout
689
+ await _set_waiting(
690
+ self.ctx.execution_id,
691
+ wait_until=expires_at,
692
+ wait_type="event",
693
+ step_key=step_key,
694
+ wait_topic=topic,
695
+ expires_at=expires_at,
696
+ )
697
+
698
+ # Resume event will be added to step outputs by the orchestrator when it is received
699
+
700
+ # Raise WaitException to pause execution
701
+ # When resumed, execution continues from here and will check execution_step_outputs again
702
+ raise WaitException(f"Waiting for resume event: {topic}")
703
+
704
+ async def resume(
705
+ self,
706
+ step_key: str,
707
+ suspend_step_key: str,
708
+ suspend_execution_id: str,
709
+ data: dict[str, Any] | BaseModel,
710
+ ) -> None:
711
+ """Resume a suspended execution by publishing a resume event.
712
+
713
+        This publishes an event on topic f"{suspend_step_key}/{suspend_execution_id}/resume"
714
+        with event_type="resume" and the given data.
715
+
716
+ Args:
717
+            step_key: Step key identifier (must be unique per execution)
+            suspend_step_key: Step key used by the suspend() call in the suspended execution
+            suspend_execution_id: Execution ID of the suspended execution
718
+            data: Data to pass in the resume event (can be dict or Pydantic BaseModel)
719
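+        Example (illustrative sketch; the IDs refer to a previously suspended execution):
+            async def approver_workflow(ctx):
+                await ctx.step.resume(
+                    "send-approval",
+                    suspend_step_key="approval",
+                    suspend_execution_id="<execution id of the suspended workflow>",
+                    data={"approved": True},
+                )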
+ """
720
+ serialized_data = serialize(data)
721
+
722
+ topic = f"{suspend_step_key}/{suspend_execution_id}/resume"
723
+
724
+ # Publish event with event_type="resume"
725
+ await self.publish_event(
726
+ step_key=step_key,
727
+ topic=topic,
728
+ data=serialized_data,
729
+ event_type="resume",
730
+ )
731
+
732
+ async def _invoke(
733
+ self,
734
+ step_key: str,
735
+ workflow: str | Workflow,
736
+ payload: Any,
737
+ initial_state: dict[str, Any] | BaseModel | None = None,
738
+ queue: str | None = None,
739
+ concurrency_key: str | None = None,
740
+ run_timeout_seconds: int | None = None,
741
+ wait_for_subworkflow: bool = False,
742
+ ) -> Any:
743
+ """
744
+ Invoke another workflow as a step.
745
+
746
+ Args:
747
+ step_key: Step key identifier (must be unique per execution)
748
+ workflow: Workflow ID or Workflow instance
749
+            payload: Payload for the workflow
+            initial_state: Optional initial workflow state (dict or Pydantic model)
750
+            queue: Optional queue name
751
+            concurrency_key: Optional concurrency key
+            run_timeout_seconds: Optional timeout in seconds
752
+            wait_for_subworkflow: Whether to wait for sub-workflow completion
753
+
754
+ Returns:
755
+            A (result, found) tuple: found is True if a cached step output was returned;
756
+            otherwise result is the sub-workflow's ExecutionHandle (or None when
+            wait_for_subworkflow is True, since the orchestrator saves the output)
757
+ """
758
+ # Get workflow ID
759
+ if isinstance(workflow, Workflow):
760
+ workflow_id = workflow.id
761
+ workflow_instance = workflow
762
+ else:
763
+ workflow_id = workflow
764
+ workflow_instance = get_workflow(workflow_id)
765
+
766
+ if not workflow_instance:
767
+ raise StepExecutionError(f"Workflow {workflow_id} not found")
768
+
769
+ # Check for existing step output
770
+ existing_step = await self._check_existing_step(step_key)
771
+ if existing_step:
772
+ result = await self._handle_existing_step(existing_step)
773
+ return result, True
774
+
775
+ # Extract trace context for propagation to sub-workflow
776
+ exec_context = _execution_context.get()
777
+ traceparent = None
778
+ if exec_context:
779
+ # Get current span and extract traceparent
780
+ current_span = get_current_span()
781
+ if current_span:
782
+ traceparent = extract_traceparent(current_span)
783
+
784
+ # Invoke workflow
785
+ client = get_client_or_raise()
786
+ handle = await workflow_instance._invoke(
787
+ client,
788
+ payload,
789
+ initial_state=initial_state,
790
+ queue=queue,
791
+ concurrency_key=concurrency_key,
792
+ session_id=self.ctx.session_id,
793
+ user_id=self.ctx.user_id,
794
+ deployment_id=self.ctx.deployment_id,
795
+ parent_execution_id=self.ctx.execution_id,
796
+ root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
797
+ step_key=step_key if wait_for_subworkflow else None,
798
+ wait_for_subworkflow=wait_for_subworkflow,
799
+ otel_traceparent=traceparent,
800
+ run_timeout_seconds=run_timeout_seconds,
801
+ )
802
+
803
+ if wait_for_subworkflow:
804
+ # No need to save the handle
805
+ # Orchestrator will save the output on completion
806
+ return None, False
807
+ else:
808
+ await self._save_step_output(
809
+ step_key,
810
+ handle,
811
+ )
812
+ return handle, False
813
+
814
+ async def invoke(
815
+ self,
816
+ step_key: str,
817
+ workflow: str | Workflow,
818
+ payload: Any,
819
+ initial_state: BaseModel | dict[str, Any] | None = None,
820
+ queue: str | None = None,
821
+ concurrency_key: str | None = None,
822
+ run_timeout_seconds: int | None = None,
823
+ ) -> Any:
824
+ """
825
+ Invoke another workflow as a step.
826
+
827
+ Args:
828
+ step_key: Step key identifier (must be unique per execution)
829
+ workflow: Workflow ID or Workflow instance
830
+            payload: Payload for the workflow
+            initial_state: Optional initial workflow state (dict or Pydantic model)
831
+ queue: Optional queue name
832
+ concurrency_key: Optional concurrency key
833
+ run_timeout_seconds: Optional timeout in seconds
834
+
835
+ Returns:
836
+ ExecutionHandle of the sub-workflow
837
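+        Example (illustrative sketch; "billing" is a hypothetical workflow ID):
+            async def my_workflow(ctx):
+                # Fire-and-forget: returns an ExecutionHandle without waiting
+                handle = await ctx.step.invoke("start-billing", "billing", {"order_id": 42})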
+ """
838
+ # Note: step_finish will be emitted by the orchestrator when it saves the step output
839
+ result, found = await self._invoke(
840
+ step_key,
841
+ workflow,
842
+ payload,
843
+ initial_state,
844
+ queue,
845
+ concurrency_key,
846
+ wait_for_subworkflow=False,
847
+ run_timeout_seconds=run_timeout_seconds,
848
+ )
849
+ return result
850
+
851
+ async def invoke_and_wait(
852
+ self,
853
+ step_key: str,
854
+ workflow: str | Workflow,
855
+ payload: Any,
856
+ initial_state: BaseModel | dict[str, Any] | None = None,
857
+ queue: str | None = None,
858
+ concurrency_key: str | None = None,
859
+ run_timeout_seconds: int | None = None,
860
+ ) -> Any:
861
+ """
862
+        Invoke another workflow as a step and wait for it to complete.
863
+
864
+ This creates a child workflow execution and waits for it.
865
+ Note that this will raise WaitException to pause execution until the
866
+ child workflow completes.
867
+
868
+ Args:
869
+ step_key: Step key identifier (must be unique per execution)
870
+ workflow: Workflow ID or Workflow instance
871
+            payload: Payload for the workflow
+            initial_state: Optional initial workflow state (dict or Pydantic model)
872
+            queue: Optional queue name
873
+            concurrency_key: Optional concurrency key
+            run_timeout_seconds: Optional timeout in seconds
874
+
875
+ Returns:
876
+ Any: Result of the child workflow
877
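+        Example (illustrative sketch; "billing" is a hypothetical workflow ID):
+            async def my_workflow(ctx):
+                # Pauses this execution until the child workflow completes;
+                # on resume the child's result is returned from the cached step output
+                result = await ctx.step.invoke_and_wait("run-billing", "billing", {"order_id": 42})
+                return result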
+ """
878
+ result, found = await self._invoke(
879
+ step_key,
880
+ workflow,
881
+ payload,
882
+ initial_state,
883
+ queue,
884
+ concurrency_key,
885
+ wait_for_subworkflow=True,
886
+ run_timeout_seconds=run_timeout_seconds,
887
+ )
888
+ if found:
889
+ # Step is complete, return result
890
+ return result
891
+
892
+ # Step did not exist, execute wait
893
+ from ..features.wait import WaitException
894
+
895
+ workflow_id = workflow.id if isinstance(workflow, Workflow) else workflow
896
+ raise WaitException(f"Waiting for sub-workflow {workflow_id} to complete")
897
+
898
+ async def batch_invoke(
899
+ self,
900
+ step_key: str,
901
+ workflows: list[BatchWorkflowInput],
902
+ ) -> list[ExecutionHandle]:
903
+ """
904
+ Invoke multiple workflows as a single step using the batch endpoint.
905
+
906
+ Args:
907
+            step_key: Step key identifier (must be unique per execution)
+            workflows: List of BatchWorkflowInput objects with 'id'
908
+ (workflow_id string) and 'payload' (dict or Pydantic model)
909
+
910
+ Returns:
911
+ List of ExecutionHandle objects for the submitted workflows
912
+ """
913
+ if not workflows:
914
+ return []
915
+
916
+ # Check for existing step output
917
+ existing_step = await self._check_existing_step(step_key)
918
+ if existing_step:
919
+ # Extract handles from existing step output
920
+ existing_output = await self._handle_existing_step(existing_step)
921
+ if existing_output and isinstance(existing_output, list):
922
+ handles = [
923
+ ExecutionHandle.model_validate(handle_data) for handle_data in existing_output
924
+ ]
925
+ return handles
926
+ return []
927
+
928
+ # Extract trace context for propagation to sub-workflow
929
+ exec_context = _execution_context.get()
930
+ traceparent = None
931
+ if exec_context:
932
+ # Get current span and extract traceparent
933
+ current_span = get_current_span()
934
+ if current_span:
935
+ traceparent = extract_traceparent(current_span)
936
+
937
+ # Build workflow requests for batch submission
938
+ workflow_requests = []
939
+ for workflow_input in workflows:
940
+ workflow_id = workflow_input.id
941
+ payload = serialize(workflow_input.payload)
942
+
943
+ workflow_obj = get_workflow(workflow_id)
944
+ if not workflow_obj:
945
+ raise StepExecutionError(f"Workflow '{workflow_id}' not found")
946
+
947
+ workflow_req = {
948
+ "workflow_id": workflow_id,
949
+ "payload": payload,
950
+ "initial_state": serialize(workflow_input.initial_state),
951
+ "run_timeout_seconds": workflow_input.run_timeout_seconds,
952
+ }
953
+
954
+ # Per-workflow properties (queue_name, concurrency_key, etc.)
955
+ if workflow_obj.queue_name is not None:
956
+ workflow_req["queue_name"] = workflow_obj.queue_name
957
+
958
+ if workflow_obj.queue_concurrency_limit is not None:
959
+ workflow_req["queue_concurrency_limit"] = workflow_obj.queue_concurrency_limit
960
+
961
+ workflow_requests.append(workflow_req)
962
+
963
+ # Submit all workflows in a single batch using the batch endpoint
964
+ client = get_client_or_raise()
965
+ handles = await client._submit_workflows(
966
+ workflows=workflow_requests,
967
+ deployment_id=self.ctx.deployment_id,
968
+ parent_execution_id=self.ctx.execution_id,
969
+ root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
970
+ step_key=None, # Don't set step_key since we don't want to wait
971
+ # for the batch to complete
972
+ session_id=self.ctx.session_id,
973
+ user_id=self.ctx.user_id,
974
+ wait_for_subworkflow=False, # batch_invoke is fire-and-forget
975
+ otel_traceparent=traceparent,
976
+ )
977
+
978
+ await self._save_step_output(
979
+ step_key,
980
+ handles,
981
+ )
982
+ return handles
983
+
984
+ async def batch_invoke_and_wait(
985
+ self,
986
+ step_key: str,
987
+ workflows: list[BatchWorkflowInput],
988
+ ) -> list[BatchStepResult]:
989
+ """
990
+ Invoke multiple workflows as a single step using the batch endpoint
991
+ and wait for all to complete.
992
+
993
+ Args:
994
+            step_key: Step key identifier (must be unique per execution)
+            workflows: List of BatchWorkflowInput objects with 'id'
995
+ (workflow_id string) and 'payload' (dict or Pydantic model)
996
+
997
+ Returns:
998
+ List of BatchStepResult objects, one for each workflow
999
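+        Example (illustrative sketch; workflow IDs are hypothetical):
+            async def my_workflow(ctx):
+                results = await ctx.step.batch_invoke_and_wait(
+                    "fan-out",
+                    [
+                        BatchWorkflowInput(id="billing", payload={"order_id": 1}),
+                        BatchWorkflowInput(id="billing", payload={"order_id": 2}),
+                    ],
+                )
+                return [r.result for r in results]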
+ """
1000
+ if not workflows:
1001
+ return []
1002
+
1003
+ # Check for existing step output
1004
+ existing_step = await self._check_existing_step(step_key)
1005
+ if existing_step:
1006
+ # Extract results from existing step output
1007
+ existing_output = None
1008
+ try:
1009
+ existing_output = await self._handle_existing_step(existing_step)
1010
+ except Exception:
1011
+ existing_output = existing_step.get("outputs")
1012
+
1013
+ if existing_output and isinstance(existing_output, list):
1014
+ # Reconstruct BatchStepResult objects from stored dicts
1015
+ batch_results = []
1016
+ for item in existing_output:
1017
+ result = item.get("result")
1018
+ if result:
1019
+ deserialized_result = await deserialize(
1020
+ result, item.get("result_schema_name")
1021
+ )
1022
+ item["result"] = deserialized_result
1023
+
1024
+ batch_results.append(BatchStepResult.model_validate(item))
1025
+ return batch_results
1026
+ return []
1027
+
1028
+ # Extract trace context for propagation to sub-workflow
1029
+ exec_context = _execution_context.get()
1030
+ traceparent = None
1031
+ if exec_context:
1032
+ # Get current span and extract traceparent
1033
+ current_span = get_current_span()
1034
+ if current_span:
1035
+ traceparent = extract_traceparent(current_span)
1036
+
1037
+ # Build workflow requests for batch submission
1038
+ workflow_requests = []
1039
+ for _i, workflow_input in enumerate(workflows):
1040
+ workflow_id = workflow_input.id
1041
+ payload = serialize(workflow_input.payload)
1042
+
1043
+ workflow_obj = get_workflow(workflow_id)
1044
+ if not workflow_obj:
1045
+ raise StepExecutionError(f"Workflow '{workflow_id}' not found")
1046
+
1047
+ workflow_req = {
1048
+ "workflow_id": workflow_id,
1049
+ "payload": payload,
1050
+ "initial_state": serialize(workflow_input.initial_state),
1051
+ "run_timeout_seconds": workflow_input.run_timeout_seconds,
1052
+ }
1053
+
1054
+ # Per-workflow properties (queue_name, concurrency_key, etc.)
1055
+ if workflow_obj.queue_name is not None:
1056
+ workflow_req["queue_name"] = workflow_obj.queue_name
1057
+
1058
+ if workflow_obj.queue_concurrency_limit is not None:
1059
+ workflow_req["queue_concurrency_limit"] = workflow_obj.queue_concurrency_limit
1060
+
1061
+ workflow_requests.append(workflow_req)
1062
+
1063
+ # Submit all workflows in a single batch using the batch endpoint
1064
+ # with wait_for_subworkflow=True
1065
+ client = get_client_or_raise()
1066
+ await client._submit_workflows(
1067
+ workflows=workflow_requests,
1068
+ deployment_id=self.ctx.deployment_id,
1069
+ parent_execution_id=self.ctx.execution_id,
1070
+ root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
1071
+ step_key=step_key,
1072
+ session_id=self.ctx.session_id,
1073
+ user_id=self.ctx.user_id,
1074
+ wait_for_subworkflow=True, # This will set parent to waiting until all complete
1075
+ otel_traceparent=traceparent,
1076
+ )
1077
+
1078
+ # Raise WaitException to pause execution
1079
+ # The orchestrator will resume the execution when all sub-workflows complete
1080
+ # When resumed, this function will be called again and all workflows should have completed
1081
+ from ..features.wait import WaitException
1082
+
1083
+ workflow_ids = [w.id for w in workflows]
1084
+ raise WaitException(f"Waiting for sub-workflows {workflow_ids} to complete")
1085
+
1086
+ async def agent_invoke(
1087
+ self,
1088
+ step_key: str,
1089
+ config: Any,
1090
+ ) -> ExecutionHandle:
1091
+ """
1092
+ Invoke a single agent as a workflow step without waiting for completion.
1093
+
1094
+ This is designed to be used with Agent.with_input(), which returns
1095
+ an AgentRunConfig instance.
1096
+
1097
+ Args:
1098
+ step_key: Step key identifier (must be unique per execution)
1099
+ config: AgentRunConfig instance
1100
+ """
1101
+ from ..agents.agent import AgentRunConfig # Local import to avoid circular dependency
1102
+
1103
+ if not isinstance(config, AgentRunConfig):
1104
+ raise StepExecutionError(
1105
+ f"agent_invoke expects an AgentRunConfig, got {type(config).__name__}"
1106
+ )
1107
+
1108
+ payload = {
1109
+ "input": config.input,
1110
+ "streaming": config.streaming,
1111
+ "session_id": self.ctx.session_id,
1112
+ "user_id": self.ctx.user_id,
1113
+            "conversation_id": config.conversation_id,
+            **config.kwargs,
1114
+ }
1115
+
1116
+ workflow_obj = get_workflow(config.agent.id)
1117
+ if not workflow_obj:
1118
+ raise StepExecutionError(f"Agent workflow '{config.agent.id}' not found")
1119
+
1120
+ result, found = await self._invoke(
1121
+ step_key,
1122
+ workflow_obj,
1123
+ serialize(payload),
1124
+ initial_state=serialize(config.initial_state),
1125
+ wait_for_subworkflow=False,
1126
+ run_timeout_seconds=config.run_timeout_seconds,
1127
+ )
1128
+ return result
1129
+
1130
+ async def agent_invoke_and_wait(
1131
+ self,
1132
+ step_key: str,
1133
+ config: Any,
1134
+ ) -> Any:
1135
+ """
1136
+ Invoke a single agent as a workflow step and wait for completion.
1137
+
1138
+ This is designed to be used with Agent.with_input(), which returns
1139
+ an AgentRunConfig instance. The returned value is whatever the
1140
+ underlying agent workflow returns (typically an AgentResult).
1141
+
1142
+ Args:
1143
+ step_key: Step key identifier (must be unique per execution)
1144
+ config: AgentRunConfig instance
1145
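+        Example (illustrative sketch; assumes Agent.with_input takes the user input):
+            async def my_workflow(ctx):
+                config = support_agent.with_input("Summarize ticket #42")
+                result = await ctx.step.agent_invoke_and_wait("summarize", config)
+                return result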
+ """
1146
+ from ..agents.agent import AgentRunConfig # Local import to avoid circular dependency
1147
+
1148
+ if not isinstance(config, AgentRunConfig):
1149
+ raise StepExecutionError(
1150
+ f"agent_invoke_and_wait expects an AgentRunConfig, got {type(config).__name__}"
1151
+ )
1152
+
1153
+ payload = {
1154
+ "input": config.input,
1155
+ "streaming": config.streaming,
1156
+ "session_id": self.ctx.session_id,
1157
+ "user_id": self.ctx.user_id,
1158
+ "conversation_id": config.conversation_id,
1159
+ **config.kwargs,
1160
+ }
1161
+
1162
+ workflow_obj = get_workflow(config.agent.id)
1163
+ if not workflow_obj:
1164
+ raise StepExecutionError(f"Agent workflow '{config.agent.id}' not found")
1165
+
1166
+ result, found = await self._invoke(
1167
+ step_key,
1168
+ workflow_obj,
1169
+ serialize(payload),
1170
+ initial_state=serialize(config.initial_state),
1171
+ wait_for_subworkflow=True,
1172
+ run_timeout_seconds=config.run_timeout_seconds,
1173
+ )
1174
+ if found:
1175
+ # Step is complete, return result
1176
+ return result
1177
+
1178
+ # Step did not exist yet; raise WaitException to pause execution
1179
+ from ..features.wait import WaitException
1180
+
1181
+ raise WaitException(f"Waiting for agent workflow '{config.agent.id}' to complete")
1182
+
1183
+ async def batch_agent_invoke(
1184
+ self,
1185
+ step_key: str,
1186
+ configs: list[Any],
1187
+ ) -> list[ExecutionHandle]:
1188
+ """
1189
+ Invoke multiple agents in parallel as a single workflow step.
1190
+
1191
+ This is designed to be used with Agent.with_input(), which returns
1192
+ AgentRunConfig instances.
1193
+ """
1194
+ from ..agents.agent import AgentRunConfig # Local import to avoid circular dependency
1195
+
1196
+ workflows: list[BatchWorkflowInput] = []
1197
+ for config in configs:
1198
+ if not isinstance(config, AgentRunConfig):
1199
+ raise StepExecutionError(
1200
+ f"batch_agent_invoke expects AgentRunConfig instances, "
1201
+ f"got {type(config).__name__}"
1202
+ )
1203
+ payload = {
1204
+ "input": config.input,
1205
+ "streaming": config.streaming,
1206
+ "session_id": self.ctx.session_id,
1207
+ "user_id": self.ctx.user_id,
1208
+ "conversation_id": config.conversation_id,
1209
+ **config.kwargs,
1210
+ }
1211
+ workflows.append(
1212
+ BatchWorkflowInput(
1213
+ id=config.agent.id,
1214
+ payload=payload,
1215
+ initial_state=config.initial_state,
1216
+ run_timeout_seconds=config.run_timeout_seconds,
1217
+ )
1218
+ )
1219
+
1220
+ return await self.batch_invoke(step_key, workflows)
1221
+
1222
+ async def batch_agent_invoke_and_wait(
1223
+ self,
1224
+ step_key: str,
1225
+ configs: list[Any],
1226
+ ) -> list[BatchStepResult]:
1227
+ """
1228
+ Invoke multiple agents in parallel as a single workflow step and wait for all to complete.
1229
+
1230
+ This is designed to be used with Agent.with_input(), which returns
1231
+ AgentRunConfig instances. The results are returned as BatchStepResult,
1232
+ where each .result is the AgentResult from the corresponding agent.
1233
+ """
1234
+ from ..agents.agent import AgentRunConfig # Local import to avoid circular dependency
1235
+
1236
+ workflows: list[BatchWorkflowInput] = []
1237
+ for config in configs:
1238
+ if not isinstance(config, AgentRunConfig):
1239
+ raise StepExecutionError(
1240
+ f"batch_agent_invoke_and_wait expects AgentRunConfig instances, "
1241
+ f"got {type(config).__name__}"
1242
+ )
1243
+ payload = {
1244
+ "input": config.input,
1245
+ "streaming": config.streaming,
1246
+ "session_id": self.ctx.session_id,
1247
+ "user_id": self.ctx.user_id,
1248
+ "conversation_id": config.conversation_id,
1249
+ **config.kwargs,
1250
+ }
1251
+ workflows.append(
1252
+ BatchWorkflowInput(
1253
+ id=config.agent.id,
1254
+ payload=payload,
1255
+ initial_state=config.initial_state,
1256
+ run_timeout_seconds=config.run_timeout_seconds,
1257
+ )
1258
+ )
1259
+
1260
+ return await self.batch_invoke_and_wait(step_key, workflows)
1261
+
1262
+ async def uuid(self, step_key: str) -> str:
1263
+ """Get a UUID that is persisted across workflow runs.
1264
+
1265
+ On the first execution, generates a new UUID and saves it.
1266
+ On subsequent executions (replay/resume), returns the same UUID.
1267
+
1268
+ Args:
1269
+ step_key: Step key identifier (must be unique per execution)
1270
+
1271
+ Returns:
1272
+ UUID string (persisted across runs)
1273
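+        Example (illustrative sketch):
+            async def my_workflow(ctx):
+                # The same UUID is returned if the workflow is replayed or resumed
+                request_id = await ctx.step.uuid("request-id")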
+ """
1274
+ # Check for existing step output
1275
+ existing_step = await self._check_existing_step(step_key)
1276
+ if existing_step:
1277
+ return await self._handle_existing_step(existing_step)
1278
+
1279
+ # Generate new UUID
1280
+ generated_uuid = str(uuid_module.uuid4())
1281
+
1282
+ # Save the UUID
1283
+ await self._save_step_output(step_key, generated_uuid)
1284
+
1285
+ return generated_uuid
1286
+
1287
+ async def now(self, step_key: str) -> int:
1288
+ """Get current timestamp in milliseconds (durable across runs).
1289
+
1290
+ Args:
1291
+ step_key: Step key identifier (must be unique per execution)
1292
+
1293
+ Returns:
1294
+ Current timestamp in milliseconds since epoch
1295
+ """
1296
+ # Check for existing step output
1297
+ existing_step = await self._check_existing_step(step_key)
1298
+ if existing_step:
1299
+ return existing_step.get("outputs", int(time.time() * 1000))
1300
+
1301
+ # Generate new timestamp
1302
+ timestamp = int(time.time() * 1000)
1303
+
1304
+ # Save step output
1305
+ await self._save_step_output(step_key, timestamp)
1306
+
1307
+ return timestamp
1308
+
1309
+ async def random(self, step_key: str) -> float:
1310
+ """Get a random float between 0.0 and 1.0 that is persisted across workflow runs.
1311
+
1312
+ On the first execution, generates a new random number and saves it.
1313
+ On subsequent executions (replay/resume), returns the same random number.
1314
+
1315
+ Args:
1316
+ step_key: Step key identifier (must be unique per execution)
1317
+
1318
+ Returns:
1319
+ Random float between 0.0 and 1.0 (persisted across runs)
1320
+ """
1321
+ # Check for existing step output
1322
+ existing_step = await self._check_existing_step(step_key)
1323
+ if existing_step:
1324
+ return await self._handle_existing_step(existing_step)
1325
+
1326
+ # Generate new random number
1327
+ random_value = random_module.random()
1328
+
1329
+ # Save the random number
1330
+ await self._save_step_output(step_key, random_value)
1331
+
1332
+ return random_value
1333
+
1334
+ @contextmanager
1335
+ def trace(self, name: str, attributes: dict[str, Any] | None = None):
1336
+ """Create a custom span within the current step execution.
1337
+
1338
+ Args:
1339
+ name: Name of the span
1340
+ attributes: Optional dictionary of span attributes
1341
+
1342
+ Returns:
1343
+ Context manager for use in 'with' statement
1344
+
1345
+ Example:
1346
+ async def my_step(ctx):
1347
+ with ctx.step.trace("database_query", {"table": "users"}):
1348
+ result = await db.query("SELECT * FROM users")
1349
+ return result
1350
+ """
1351
+ # Use the current OpenTelemetry span as parent (which should be the step/workflow span)
1352
+ # This ensures proper parent-child relationship within the same trace
1353
+ exec_context = _execution_context.get()
1354
+ tracer = get_tracer()
1355
+ parent_context = get_parent_span_context_from_execution_context(exec_context)
1356
+
1357
+ # Create span using context manager
1358
+ with tracer.start_as_current_span(
1359
+ name=name, context=parent_context, attributes=attributes or {}
1360
+ ) as span:
1361
+ # Update execution context with current span for nested spans
1362
+ # Save old values to restore later
1363
+ old_span_context = get_span_context_from_execution_context(exec_context)
1364
+ set_span_context_in_execution_context(exec_context, span.get_span_context())
1365
+
1366
+ try:
1367
+ yield span
1368
+ # Set status to success
1369
+ span.set_status(Status(StatusCode.OK))
1370
+ # Span automatically ended and stored by DatabaseSpanExporter
1371
+ except Exception as e:
1372
+ # Set status to error
1373
+ span.set_status(Status(StatusCode.ERROR, str(e)))
1374
+ span.record_exception(e)
1375
+ # Span automatically ended and stored by DatabaseSpanExporter
1376
+ raise
1377
+ finally:
1378
+ # Restore previous span context
1379
+ set_span_context_in_execution_context(exec_context, old_span_context)
1380
+ # Span context automatically cleaned up by context manager