polos-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/core/step.py
ADDED
|
@@ -0,0 +1,1380 @@
|
|
|
1
|
+
"""Step execution helper for durable execution within workflows."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextvars
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import random as random_module
|
|
11
|
+
import time
|
|
12
|
+
import uuid as uuid_module
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from contextlib import contextmanager
|
|
15
|
+
from datetime import datetime, timedelta, timezone
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from opentelemetry.trace import Status, StatusCode
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
from ..features.events import EventData, EventPayload, batch_publish
|
|
22
|
+
from ..features.tracing import extract_traceparent, get_current_span, get_tracer
|
|
23
|
+
from ..features.wait import WaitException, _get_wait_time, _set_waiting
|
|
24
|
+
from ..runtime.client import ExecutionHandle, get_step_output, store_step_output
|
|
25
|
+
from ..types.types import BatchStepResult, BatchWorkflowInput
|
|
26
|
+
from ..utils.client_context import get_client_or_raise
|
|
27
|
+
from ..utils.retry import retry_with_backoff
|
|
28
|
+
from ..utils.serializer import deserialize, safe_serialize, serialize
|
|
29
|
+
from ..utils.tracing import (
|
|
30
|
+
get_parent_span_context_from_execution_context,
|
|
31
|
+
get_span_context_from_execution_context,
|
|
32
|
+
set_span_context_in_execution_context,
|
|
33
|
+
)
|
|
34
|
+
from .context import WorkflowContext
|
|
35
|
+
from .workflow import (
|
|
36
|
+
StepExecutionError,
|
|
37
|
+
Workflow,
|
|
38
|
+
_execution_context,
|
|
39
|
+
get_workflow,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Step:
|
|
46
|
+
"""Step execution helper - provides durable execution primitives.
|
|
47
|
+
|
|
48
|
+
Steps are executed within a workflow context and their outputs are
|
|
49
|
+
saved to avoid re-execution on workflow resume/replay.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(self, ctx: WorkflowContext):
    """Bind this step helper to one workflow execution.

    Args:
        ctx: Context of the workflow execution these steps belong to; its
            execution identifiers key every stored step output.
    """
    # All durable primitives read execution ids from this context.
    self.ctx = ctx
|
|
59
|
+
|
|
60
|
+
async def _check_existing_step(self, step_key: str) -> dict[str, Any] | None:
    """
    Look up a previously recorded output for ``step_key``.

    Args:
        step_key: Step key identifier (must be unique per execution)

    Returns:
        The stored step-output record for this execution and key, or None.
        - A non-None value means the step already ran (successfully or not).
        - The caller should pass it to ``_handle_existing_step`` to return the
          cached result or re-raise the recorded failure.
    """
    # Step outputs are looked up by (execution_id, step_key).
    return await get_step_output(self.ctx.execution_id, step_key)
|
|
74
|
+
|
|
75
|
+
async def _handle_existing_step(
    self,
    existing_step: dict[str, Any],
) -> Any:
    """
    Replay a recorded step: return its cached output or re-raise its failure.

    Args:
        existing_step: The record returned by ``_check_existing_step``

    Returns:
        The deserialized ``outputs`` of a successful step. The recorded
        ``output_schema_name`` is passed to ``deserialize`` so Pydantic
        results can be rebuilt. Returns None when no output was stored.

    Raises:
        StepExecutionError: If the step previously failed; the message comes
            from the recorded error, falling back to "Step execution failed".
    """
    if existing_step.get("success", False):
        outputs = existing_step.get("outputs")
        # Compare with None rather than truthiness: a step that legitimately
        # returned 0, False, "", [] or {} must replay that value, not None.
        if outputs is None:
            return None
        return await deserialize(outputs, existing_step.get("output_schema_name"))

    # A missing *or null* "error" field falls back to the generic message
    # instead of producing the string "None".
    error = existing_step.get("error") or {}
    if isinstance(error, dict):
        error_message = error.get("message") or "Step execution failed"
    else:
        error_message = str(error)
    raise StepExecutionError(error_message)
|
|
107
|
+
|
|
108
|
+
async def _save_step_output(
    self,
    step_key: str,
    result: Any,
    source_execution_id: str | None = None,
) -> None:
    """Save a successful step result to the database under ``step_key``.

    Args:
        step_key: Step key identifier (must be unique per execution)
        result: Result to save
        source_execution_id: Optional source execution ID

    A Pydantic ``BaseModel`` result is stored as ``model_dump(mode="json")``,
    which also dumps nested models. The model's fully qualified class path is
    stored alongside it as ``output_schema_name`` (for example
    ``"polos.llm.providers.base.LLMResponse"``). ``_handle_existing_step``
    passes that name to ``deserialize`` on replay.

    Any other result is passed to ``store_step_output`` unchanged. This method
    does not itself check that the value is JSON serializable.
    NOTE(review): whether a non-JSON-serializable value is rejected depends on
    ``store_step_output``; confirm.
    """
    if isinstance(result, BaseModel):
        outputs = result.model_dump(mode="json")
        model_cls = result.__class__
        output_schema_name = f"{model_cls.__module__}.{model_cls.__name__}"
    else:
        outputs = result
        output_schema_name = None

    await store_step_output(
        execution_id=self.ctx.execution_id,
        step_key=step_key,
        outputs=outputs,
        error=None,
        success=True,
        source_execution_id=source_execution_id,
        output_schema_name=output_schema_name,
    )
|
|
149
|
+
|
|
150
|
+
async def _save_step_output_with_error(
    self,
    step_key: str,
    error: str,
) -> None:
    """Record a failed step so that replaying it re-raises instead of re-running.

    Args:
        step_key: Step key identifier (must be unique per execution)
        error: Failure message; stored as ``{"message": error}``
    """
    failure_record = {"message": error}
    await store_step_output(
        execution_id=self.ctx.execution_id,
        step_key=step_key,
        outputs=None,
        error=failure_record,
        success=False,
        source_execution_id=None,
    )
|
|
164
|
+
|
|
165
|
+
async def _raise_step_execution_error(self, step_key: str, error: str) -> None:
    """Record ``error`` as this step's failure, then raise it.

    The failure is recorded first. A later replay of ``step_key`` then
    re-raises the same message through ``_handle_existing_step`` instead of
    running the step again.

    Args:
        step_key: Step key identifier (must be unique per execution)
        error: Failure message to store and raise

    Raises:
        StepExecutionError: Always, with ``error`` as its message.
    """
    await self._save_step_output_with_error(step_key, error)
    raise StepExecutionError(error)
|
|
172
|
+
|
|
173
|
+
# Strong references to in-flight fire-and-forget publish tasks created by
# ``_publish_step_event``. The event loop keeps only weak references to tasks,
# so an otherwise unreferenced task may be garbage-collected before it
# finishes. Shared by all Step instances; each entry removes itself when its
# task completes.
_pending_event_tasks: set[asyncio.Task[Any]] = set()

async def _publish_step_event(
    self,
    event_type: str,
    step_key: str,
    step_type: str,
    input_params: dict[str, Any],
) -> None:
    """Publish a step event for the current workflow (fire-and-forget).

    The event is published on the topic
    ``workflow:<root_execution_id or execution_id>`` from a background task.
    This coroutine returns without waiting for that task. A failure in the
    task is logged as a warning, not raised to the caller.

    Args:
        event_type: Type of event (e.g., "step_start", "step_finish")
        step_key: Step key/identifier
        step_type: Type of step (e.g., "run", "wait_for", "invoke", etc.)
        input_params: Input parameters for the step
    """
    events = [
        EventData(
            data={
                "step_key": step_key,
                "step_type": step_type,
                "data": safe_serialize(input_params) if input_params else {},
                "_metadata": {
                    "execution_id": self.ctx.execution_id,
                    "workflow_id": self.ctx.workflow_id,
                },
            },
            event_type=event_type,
        )
    ]
    # Fire-and-forget: spawn task without awaiting to reduce latency
    client = get_client_or_raise()
    task = asyncio.create_task(
        batch_publish(
            client=client,
            topic=f"workflow:{self.ctx.root_execution_id or self.ctx.execution_id}",
            events=events,
            execution_id=self.ctx.execution_id,
            root_execution_id=self.ctx.root_execution_id,
        )
    )

    # Keep the task alive until it finishes.
    pending = self._pending_event_tasks
    pending.add(task)

    def _on_done(done: asyncio.Task[Any]) -> None:
        pending.discard(done)
        # Retrieve the outcome so a failed publish is logged here, rather than
        # as "Task exception was never retrieved" when the task is collected.
        if not done.cancelled() and done.exception() is not None:
            logger.warning(
                "Failed to publish %s event for step %s",
                event_type,
                step_key,
                exc_info=done.exception(),
            )

    task.add_done_callback(_on_done)
|
|
213
|
+
|
|
214
|
+
async def run(
    self,
    step_key: str,
    func: Callable,
    *args,
    max_retries: int = 2,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
    **kwargs,
) -> Any:
    """
    Execute a callable as a durable step with retry support.

    If an output for ``step_key`` is already recorded for this execution, the
    cached result is returned (or the recorded failure re-raised) and ``func``
    is not called. Otherwise ``func`` runs with retries inside a
    ``step.<step_key>`` tracing span. ``step_start`` and ``step_finish``
    events are published, and the result or failure is recorded under
    ``step_key``.

    Args:
        step_key: Step key identifier (must be unique per execution)
        func: Callable to execute (sync or async). Sync callables run in the
            running loop's default executor.
        *args: Positional arguments to pass to function
        max_retries: Maximum number of retries on failure (default: 2)
        base_delay: Base delay in seconds for exponential backoff (default: 1.0)
        max_delay: Maximum delay in seconds (default: 10.0)
        **kwargs: Keyword arguments to pass to function

    Returns:
        Result of function execution (the cached result on replay)

    Raises:
        StepExecutionError: Raised in three cases: the function fails after
            all retries; serializing or recording its result fails; or a
            previous run of this step recorded a failure.
    """
    # Replay: a recorded output short-circuits execution entirely.
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        return await self._handle_existing_step(existing_step)

    # Computed once; reused for span attributes and the step_start event.
    func_name = func.__name__ if hasattr(func, "__name__") else str(func)

    # Get parent span context from execution context
    exec_context = _execution_context.get()
    parent_context = get_parent_span_context_from_execution_context(exec_context)
    tracer = get_tracer()

    with tracer.start_as_current_span(
        name=f"step.{step_key}",
        context=parent_context,  # None for root, or parent context for child
        attributes={
            "step.key": step_key,
            "step.function": func_name,
            "step.execution_id": self.ctx.execution_id,
            "step.max_retries": max_retries,
        },
    ) as step_span:
        # Record this span in the execution context so nested spans attach to
        # it; the previous value is restored in ``finally``.
        old_span_context = get_span_context_from_execution_context(exec_context)
        set_span_context_in_execution_context(exec_context, step_span.get_span_context())
        try:
            # safe_serialize: arguments may be arbitrary objects that are not
            # JSON serializable.
            safe_args = [safe_serialize(arg) for arg in args]
            safe_kwargs = {k: safe_serialize(v) for k, v in kwargs.items()}

            input_params = {
                "func": func_name,
                "args": safe_args,
                "kwargs": safe_kwargs,
                "max_retries": max_retries,
                "base_delay": base_delay,
                "max_delay": max_delay,
            }

            await self._publish_step_event(
                "step_start",
                step_key,
                "run",
                input_params,
            )

            # Store input in span attributes as JSON string
            step_span.set_attributes(
                {
                    "step.input": json.dumps(
                        {
                            "args": safe_args,
                            "kwargs": safe_kwargs,
                        }
                    ),
                    "step.function": func_name,
                    "step.max_retries": max_retries,
                    "step.base_delay": base_delay,
                    "step.max_delay": max_delay,
                }
            )

            async def _execute_func() -> Any:
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                # Sync callables run in the default executor. Copy the current
                # context (including ContextVar values) on every attempt so the
                # worker thread sees the same values as this coroutine.
                func_ctx = contextvars.copy_context()
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(
                    None, lambda: func_ctx.run(func, *args, **kwargs)
                )

            try:
                result = await retry_with_backoff(
                    _execute_func,
                    max_retries=max_retries,
                    base_delay=base_delay,
                    max_delay=max_delay,
                )
                serialized_result = serialize(result)

                step_span.set_status(Status(StatusCode.OK))
                step_span.set_attributes(
                    {
                        "step.status": "completed",
                        "step.output": json.dumps(serialized_result),
                    }
                )

                await self._publish_step_event(
                    "step_finish",
                    step_key,
                    "run",
                    {"result": serialized_result},
                )

                # Record the result so a replay returns it without re-running.
                await self._save_step_output(step_key, result)

                # Span automatically ended and stored by DatabaseSpanExporter
                return result
            except Exception as e:
                step_span.set_status(Status(StatusCode.ERROR, str(e)))
                step_span.record_exception(e)

                # Store error in span attributes as JSON string
                error_message = str(e)
                step_error = {
                    "message": error_message,
                    "type": type(e).__name__,
                }
                step_span.set_attributes(
                    {
                        "step.error": json.dumps(safe_serialize(step_error)),
                        "step.status": "failed",
                    }
                )

                # Record the failure so a replay re-raises instead of re-running.
                await self._save_step_output_with_error(step_key, error_message)

                # Span automatically ended and stored by DatabaseSpanExporter
                raise StepExecutionError(
                    f"Step execution failed after {max_retries} retries: {error_message}"
                ) from e
        finally:
            # Restore previous span context values
            set_span_context_in_execution_context(exec_context, old_span_context)
|
|
391
|
+
|
|
392
|
+
async def wait_for(
    self,
    step_key: str,
    seconds: float | None = None,
    minutes: float | None = None,
    hours: float | None = None,
    days: float | None = None,
    weeks: float | None = None,
) -> None:
    """Wait for a time duration.

    A wait of at most ``POLOS_WAIT_THRESHOLD_SECONDS`` (default 10) seconds is
    slept in-process and recorded as a completed step. A longer wait is
    registered with ``_set_waiting``, and ``WaitException`` is raised to pause
    the workflow. On replay (an output already recorded for ``step_key``) the
    call returns immediately, or re-raises a recorded failure.

    Args:
        step_key: Step key identifier (must be unique per execution)
        seconds: Optional seconds to wait
        minutes: Optional minutes to wait
        hours: Optional hours to wait
        days: Optional days to wait
        weeks: Optional weeks to wait

    Raises:
        StepExecutionError: If the wait duration is not positive, or if a
            previous run of this step recorded a failure.
        WaitException: For waits longer than the threshold, to pause execution.
    """
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        # Called only to re-raise a recorded failure. The cached
        # {"wait_until": ...} dict is discarded so that replay returns None,
        # as a first run does.
        await self._handle_existing_step(existing_step)
        return

    wait_seconds, wait_until = await _get_wait_time(seconds, minutes, hours, days, weeks)
    if wait_seconds <= 0:
        await self._raise_step_execution_error(step_key, error="Wait duration must be positive")

    # Add span event for wait
    current_span = get_current_span()
    if current_span and hasattr(current_span, "add_event"):
        # Build attributes dict, filtering out None values (OpenTelemetry doesn't accept None)
        attributes = {
            "step.key": step_key,
            "wait.seconds": wait_seconds,
        }
        if wait_until:
            attributes["wait.until"] = wait_until.isoformat()
        if seconds is not None:
            attributes["wait.seconds_param"] = seconds
        if minutes is not None:
            attributes["wait.minutes_param"] = minutes
        if hours is not None:
            attributes["wait.hours_param"] = hours
        if days is not None:
            attributes["wait.days_param"] = days
        if weeks is not None:
            attributes["wait.weeks_param"] = weeks

        current_span.add_event("step.wait_for", attributes=attributes)

    # Get wait threshold from environment (default 10 seconds)
    wait_threshold = float(os.getenv("POLOS_WAIT_THRESHOLD_SECONDS", "10.0"))

    if wait_seconds <= wait_threshold:
        # Short wait - just sleep without raising WaitException
        await asyncio.sleep(wait_seconds)
        result = {"wait_until": wait_until.isoformat()}
        await self._save_step_output(
            step_key,
            result,
        )
        return

    # Long wait - pause execution atomically
    await _set_waiting(
        self.ctx.execution_id,
        wait_until,
        "time",
        step_key,
    )

    # Raise a special exception to pause execution
    # The orchestrator will resume it when wait_until is reached
    raise WaitException(f"Waiting until {wait_until.isoformat()}")
|
|
467
|
+
|
|
468
|
+
async def wait_until(self, step_key: str, timestamp: datetime) -> None:
    """Wait until a timestamp.

    A naive ``timestamp`` is treated as UTC. A wait of at most
    ``POLOS_WAIT_THRESHOLD_SECONDS`` (default 10) seconds is slept in-process
    and recorded as a completed step. A longer wait is registered with
    ``_set_waiting``, and ``WaitException`` is raised to pause the workflow.
    On replay (an output already recorded for ``step_key``) the call returns
    immediately, or re-raises a recorded failure.

    Args:
        step_key: Step key identifier (must be unique per execution)
        timestamp: Timestamp to wait until

    Raises:
        StepExecutionError: If ``timestamp`` is in the past, or if a previous
            run of this step recorded a failure.
        WaitException: For waits longer than the threshold, to pause execution.
    """
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        # Called only to re-raise a recorded failure. The cached
        # {"wait_until": ...} dict is discarded so that replay returns None,
        # as a first run does.
        await self._handle_existing_step(existing_step)
        return

    # Ensure date is timezone-aware (use UTC if naive), then convert to UTC
    date = timestamp
    if date.tzinfo is None:
        date = date.replace(tzinfo=timezone.utc)
    wait_until = date.astimezone(timezone.utc)
    now = datetime.now(timezone.utc)

    if wait_until < now:
        # Date is in the past, raise error
        await self._raise_step_execution_error(
            step_key, error=f"Wait date {timestamp} is in the past"
        )

    # Non-negative here: a past date has already raised above.
    wait_seconds = (wait_until - now).total_seconds()

    # Add span event for wait
    current_span = get_current_span()
    if current_span and hasattr(current_span, "add_event"):
        current_span.add_event(
            "step.wait_until",
            attributes={
                "step.key": step_key,
                "wait.timestamp": timestamp.isoformat(),
                "wait.until": wait_until.isoformat(),
                "wait.seconds": wait_seconds,
            },
        )

    # Get wait threshold from environment (default 10 seconds)
    wait_threshold = float(os.getenv("POLOS_WAIT_THRESHOLD_SECONDS", "10.0"))

    if wait_seconds <= wait_threshold:
        # Short wait - just sleep without raising WaitException
        await asyncio.sleep(wait_seconds)
        result = {"wait_until": wait_until.isoformat()}
        await self._save_step_output(
            step_key,
            result,
        )
        return

    # Long wait - pause execution atomically
    await _set_waiting(
        self.ctx.execution_id,
        wait_until,
        "time",
        step_key,
    )

    # Raise a special exception to pause execution
    # The orchestrator will resume it when wait_until is reached
    raise WaitException(f"Waiting until {wait_until.isoformat()}")
|
|
539
|
+
|
|
540
|
+
async def wait_for_event(
    self, step_key: str, topic: str, timeout: int | None = None
) -> EventPayload:
    """Pause the workflow until an event is published on ``topic``.

    On first execution, an event wait on ``topic`` is registered through
    ``_set_waiting`` and ``WaitException`` is raised. When a timeout is given,
    the wait expires ``timeout`` seconds from now. When the step is replayed
    with a recorded output, that output is returned instead; a dict output is
    first converted to ``EventPayload``.

    Args:
        step_key: Step key identifier (must be unique per execution)
        topic: Event topic to wait for
        timeout: Optional timeout in seconds. If provided, wait will expire after this duration.

    Returns:
        EventPayload: Event payload with sequence_id, topic, event_type, data, and created_at

    Raises:
        WaitException: On first execution, to pause until the event arrives.
        StepExecutionError: If a previous run of this step recorded a failure.
    """
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        cached = await self._handle_existing_step(existing_step)
        # Recorded events come back as plain dicts; rebuild the payload model.
        # Anything else is passed through untouched.
        return EventPayload.model_validate(cached) if isinstance(cached, dict) else cached

    expires_at = (
        None if timeout is None else datetime.now(timezone.utc) + timedelta(seconds=timeout)
    )

    span = get_current_span()
    if span and hasattr(span, "add_event"):
        event_attrs = {"step.key": step_key, "wait.topic": topic}
        if timeout is not None:
            # expires_at is always set when a timeout was supplied.
            event_attrs["wait.timeout"] = timeout
            event_attrs["wait.expires_at"] = expires_at.isoformat()
        span.add_event("step.wait_for_event", attributes=event_attrs)

    # Register an event-type wait on this topic; with no timeout it never expires.
    await _set_waiting(
        self.ctx.execution_id,
        wait_until=expires_at,
        wait_type="event",
        step_key=step_key,
        wait_topic=topic,
        expires_at=expires_at,
    )

    # Pause here; on resume this method runs again and takes the replay path above.
    raise WaitException(f"Waiting for event on topic: {topic}")
|
|
594
|
+
|
|
595
|
+
async def publish_event(
    self,
    step_key: str,
    topic: str,
    data: dict[str, Any],
    event_type: str | None = None,
) -> None:
    """Publish a single event to ``topic`` as a durable step.

    If an output is already recorded for ``step_key``, nothing is published
    and the recorded result is returned (or a recorded failure re-raised).
    Otherwise the event is published and a ``None`` output is recorded, so
    later replays skip the publish.

    Args:
        step_key: Step key identifier (must be unique per execution)
        topic: Event topic
        data: Event data
        event_type: Optional event type
    """
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        return await self._handle_existing_step(existing_step)

    outgoing = [EventData(data=data, event_type=event_type)]
    publisher = get_client_or_raise()
    await batch_publish(
        client=publisher,
        topic=topic,
        events=outgoing,
        execution_id=self.ctx.execution_id,
        root_execution_id=self.ctx.root_execution_id,
    )

    # Mark the step done (no output) so a replay does not publish twice.
    await self._save_step_output(step_key, None)
|
|
630
|
+
|
|
631
|
+
async def publish_workflow_event(
|
|
632
|
+
self,
|
|
633
|
+
step_key: str,
|
|
634
|
+
data: dict[str, Any],
|
|
635
|
+
event_type: str | None = None,
|
|
636
|
+
) -> None:
|
|
637
|
+
"""Publish an event for the current workflow as a step.
|
|
638
|
+
|
|
639
|
+
Args:
|
|
640
|
+
step_key: Step key identifier (must be unique per execution)
|
|
641
|
+
data: Event data
|
|
642
|
+
event_type: Optional event type
|
|
643
|
+
"""
|
|
644
|
+
topic = f"workflow:{self.ctx.root_execution_id or self.ctx.execution_id}"
|
|
645
|
+
return await self.publish_event(step_key, topic, data, event_type)
|
|
646
|
+
|
|
647
|
+
async def suspend(
    self, step_key: str, data: dict[str, Any] | BaseModel, timeout: int | None = None
) -> Any:
    """Suspend execution and wait for a resume event.

    On first execution, a ``"suspend"`` event carrying the serialized
    ``data`` is published on the topic
    ``f"{step_key}/{ctx.root_execution_id or ctx.execution_id}"``. An event
    wait is then registered on that same topic, optionally expiring
    ``timeout`` seconds from now, and ``WaitException`` is raised to pause
    the workflow. When the step is replayed with a recorded output, that
    output is returned instead.

    NOTE(review): ``resume()`` publishes on
    ``f"{suspend_step_key}/{suspend_execution_id}/resume"`` (with a
    ``/resume`` suffix). This method waits on the topic without that suffix;
    confirm the orchestrator routes one to the other.

    Args:
        step_key: Step key identifier (must be unique per execution)
        data: Data to associate with the suspend (can be dict or Pydantic BaseModel)
        timeout: Optional timeout in seconds. If provided, wait will expire after this duration.

    Returns:
        On replay, the recorded output for this step. Per the comment below,
        the orchestrator adds the resume event to step outputs when it arrives.

    Raises:
        WaitException: On first execution, to pause until resumed.
        StepExecutionError: If a previous run of this step recorded a failure.
    """
    # Check for existing step output
    existing_step = await self._check_existing_step(step_key)
    if existing_step:
        return await self._handle_existing_step(existing_step)

    serialized_data = serialize(data)
    topic = f"{step_key}/{self.ctx.root_execution_id or self.ctx.execution_id}"
    # Publish suspend event
    client = get_client_or_raise()
    await batch_publish(
        client=client,
        topic=topic,
        events=[EventData(data=serialized_data, event_type="suspend")],
        execution_id=self.ctx.execution_id,
        root_execution_id=self.ctx.root_execution_id,
    )

    # Calculate expires_at if timeout is provided
    expires_at = None
    if timeout is not None:
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout)

    # Register an event-type wait on ``topic``, expiring at ``expires_at`` (if any)
    await _set_waiting(
        self.ctx.execution_id,
        wait_until=expires_at,
        wait_type="event",
        step_key=step_key,
        wait_topic=topic,
        expires_at=expires_at,
    )

    # Resume event will be added to step outputs by the orchestrator when it is received

    # Raise WaitException to pause execution
    # When resumed, execution continues from here and will check execution_step_outputs again
    raise WaitException(f"Waiting for resume event: {topic}")
|
|
703
|
+
|
|
704
|
+
async def resume(
    self,
    step_key: str,
    suspend_step_key: str,
    suspend_execution_id: str,
    data: dict[str, Any] | BaseModel,
) -> None:
    """Resume a suspended execution by publishing a resume event.

    Publishes the serialized ``data`` with ``event_type="resume"`` on the topic
    ``f"{suspend_step_key}/{suspend_execution_id}/resume"``. Publishing goes
    through ``publish_event``, so this is itself a durable step: replaying
    ``step_key`` does not publish again.

    NOTE(review): ``suspend()`` waits on
    ``f"{step_key}/{root_execution_id or execution_id}"`` with no ``/resume``
    suffix; confirm the orchestrator routes this topic to that wait.

    Args:
        step_key: Step key identifier for this publish (must be unique per execution)
        suspend_step_key: The ``step_key`` that was passed to the ``suspend`` call being resumed
        suspend_execution_id: Execution ID used in the suspended step's topic.
            This is the suspended workflow's root execution ID, or its own
            execution ID when it has no root.
        data: Data to pass in the resume event (can be dict or Pydantic BaseModel)
    """
    serialized_data = serialize(data)

    topic = f"{suspend_step_key}/{suspend_execution_id}/resume"

    # Publish event with event_type="resume"
    await self.publish_event(
        step_key=step_key,
        topic=topic,
        data=serialized_data,
        event_type="resume",
    )
|
|
731
|
+
|
|
732
|
+
async def _invoke(
    self,
    step_key: str,
    workflow: str | Workflow,
    payload: Any,
    initial_state: dict[str, Any] | BaseModel | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    run_timeout_seconds: int | None = None,
    wait_for_subworkflow: bool = False,
) -> Any:
    """
    Shared implementation of the ``invoke*`` helpers: start a child workflow as a step.

    If ``step_key`` already has recorded output, that output (as returned by
    ``_handle_existing_step``) is returned instead of starting the workflow again.

    Args:
        step_key: Step key identifier (must be unique per execution)
        workflow: Workflow ID or Workflow instance
        payload: Payload for the workflow
        initial_state: Optional initial state for the child workflow
        queue: Optional queue name
        concurrency_key: Optional concurrency key
        run_timeout_seconds: Optional run timeout for the child, in seconds
        wait_for_subworkflow: Whether the parent should wait for the child

    Returns:
        A ``(value, found)`` tuple. ``found`` is True when the step output
        already existed, and ``value`` is that output. Otherwise ``found`` is
        False and ``value`` is the child's ExecutionHandle. When
        ``wait_for_subworkflow`` is set, ``value`` is None because the
        orchestrator records the output later.

    Raises:
        StepExecutionError: If the workflow cannot be resolved.
    """
    if isinstance(workflow, Workflow):
        workflow_id, workflow_instance = workflow.id, workflow
    else:
        workflow_id = workflow
        workflow_instance = get_workflow(workflow_id)

    if not workflow_instance:
        raise StepExecutionError(f"Workflow {workflow_id} not found")

    # Replay path: this step already ran, hand back what was recorded.
    previous = await self._check_existing_step(step_key)
    if previous:
        return await self._handle_existing_step(previous), True

    # Forward the active trace to the child when running inside an execution.
    traceparent = None
    if _execution_context.get():
        active_span = get_current_span()
        if active_span:
            traceparent = extract_traceparent(active_span)

    client = get_client_or_raise()
    handle = await workflow_instance._invoke(
        client,
        payload,
        initial_state=initial_state,
        queue=queue,
        concurrency_key=concurrency_key,
        session_id=self.ctx.session_id,
        user_id=self.ctx.user_id,
        deployment_id=self.ctx.deployment_id,
        parent_execution_id=self.ctx.execution_id,
        root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
        # The step key is only attached when the parent waits on the child.
        step_key=step_key if wait_for_subworkflow else None,
        wait_for_subworkflow=wait_for_subworkflow,
        otel_traceparent=traceparent,
        run_timeout_seconds=run_timeout_seconds,
    )

    if wait_for_subworkflow:
        # Nothing to persist here: the orchestrator saves the output on completion.
        return None, False

    await self._save_step_output(step_key, handle)
    return handle, False
|
|
813
|
+
|
|
814
|
+
async def invoke(
    self,
    step_key: str,
    workflow: str | Workflow,
    payload: Any,
    initial_state: BaseModel | dict[str, Any] | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    run_timeout_seconds: int | None = None,
) -> Any:
    """
    Start another workflow as a step without waiting for it to finish.

    If ``step_key`` already has recorded output, that output is returned
    instead of starting the workflow again.

    Args:
        step_key: Step key identifier (must be unique per execution)
        workflow: Workflow ID or Workflow instance
        payload: Payload for the workflow
        initial_state: Optional initial state for the child workflow
        queue: Optional queue name
        concurrency_key: Optional concurrency key
        run_timeout_seconds: Optional timeout in seconds

    Returns:
        ExecutionHandle of the sub-workflow
    """
    # step_finish is emitted by the orchestrator when it saves the step output.
    handle, _found = await self._invoke(
        step_key,
        workflow,
        payload,
        initial_state=initial_state,
        queue=queue,
        concurrency_key=concurrency_key,
        run_timeout_seconds=run_timeout_seconds,
        wait_for_subworkflow=False,
    )
    return handle
|
|
850
|
+
|
|
851
|
+
async def invoke_and_wait(
    self,
    step_key: str,
    workflow: str | Workflow,
    payload: Any,
    initial_state: BaseModel | dict[str, Any] | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    run_timeout_seconds: int | None = None,
) -> Any:
    """
    Start another workflow as a step and suspend until it completes.

    On the first pass the child is submitted with ``wait_for_subworkflow=True``
    and WaitException is raised to pause this execution. On a later pass,
    once the step output has been recorded, that output is returned.

    Args:
        step_key: Step key identifier (must be unique per execution)
        workflow: Workflow ID or Workflow instance
        payload: Payload for the workflow
        initial_state: Optional initial state for the child workflow
        queue: Optional queue name
        concurrency_key: Optional concurrency key
        run_timeout_seconds: Optional timeout in seconds

    Returns:
        Any: Result of the child workflow

    Raises:
        WaitException: While the child's result has not been recorded yet.
    """
    outcome, already_done = await self._invoke(
        step_key,
        workflow,
        payload,
        initial_state=initial_state,
        queue=queue,
        concurrency_key=concurrency_key,
        run_timeout_seconds=run_timeout_seconds,
        wait_for_subworkflow=True,
    )
    if already_done:
        return outcome

    # Child was just started: pause this execution until it finishes.
    from ..features.wait import WaitException

    target_id = workflow.id if isinstance(workflow, Workflow) else workflow
    raise WaitException(f"Waiting for sub-workflow {target_id} to complete")
|
|
897
|
+
|
|
898
|
+
async def batch_invoke(
    self,
    step_key: str,
    workflows: list[BatchWorkflowInput],
) -> list[ExecutionHandle]:
    """
    Start several workflows as one step through the batch endpoint, without waiting.

    Each entry is looked up in the workflow registry. The registered
    workflow's queue name and queue concurrency limit (when set) are sent
    along with the serialized payload, initial state and run timeout. The
    returned handles are recorded as the step output, so a replay rebuilds
    them from storage instead of resubmitting.

    Args:
        step_key: Step key identifier (must be unique per execution)
        workflows: BatchWorkflowInput entries, each with a workflow ``id``,
            a ``payload`` (dict or Pydantic model), plus ``initial_state``
            and ``run_timeout_seconds``

    Returns:
        ExecutionHandle objects for the submitted workflows (empty when
        ``workflows`` is empty).

    Raises:
        StepExecutionError: If any workflow id is not registered.
    """
    if not workflows:
        return []

    previous = await self._check_existing_step(step_key)
    if previous:
        # Replay: rebuild the handles recorded on the first run.
        stored = await self._handle_existing_step(previous)
        if not (stored and isinstance(stored, list)):
            return []
        return [ExecutionHandle.model_validate(entry) for entry in stored]

    # Forward the active trace to the children when running inside an execution.
    traceparent = None
    if _execution_context.get():
        active_span = get_current_span()
        if active_span:
            traceparent = extract_traceparent(active_span)

    requests = []
    for item in workflows:
        wf_id = item.id
        serialized_payload = serialize(item.payload)

        registered = get_workflow(wf_id)
        if not registered:
            raise StepExecutionError(f"Workflow '{wf_id}' not found")

        request = {
            "workflow_id": wf_id,
            "payload": serialized_payload,
            "initial_state": serialize(item.initial_state),
            "run_timeout_seconds": item.run_timeout_seconds,
        }
        # Queue settings come from the registered workflow definition.
        if registered.queue_name is not None:
            request["queue_name"] = registered.queue_name
        if registered.queue_concurrency_limit is not None:
            request["queue_concurrency_limit"] = registered.queue_concurrency_limit
        requests.append(request)

    client = get_client_or_raise()
    handles = await client._submit_workflows(
        workflows=requests,
        deployment_id=self.ctx.deployment_id,
        parent_execution_id=self.ctx.execution_id,
        root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
        # No step_key: the parent does not wait for this batch.
        step_key=None,
        session_id=self.ctx.session_id,
        user_id=self.ctx.user_id,
        wait_for_subworkflow=False,  # fire-and-forget
        otel_traceparent=traceparent,
    )

    await self._save_step_output(step_key, handles)
    return handles
|
|
983
|
+
|
|
984
|
+
async def batch_invoke_and_wait(
    self,
    step_key: str,
    workflows: list[BatchWorkflowInput],
) -> list[BatchStepResult]:
    """
    Start several workflows as one step through the batch endpoint and
    suspend until all of them complete.

    On the first pass the batch is submitted with this step's key and
    ``wait_for_subworkflow=True``, then WaitException is raised to pause the
    execution. On a later pass, once the step output is recorded, the stored
    per-workflow entries are rebuilt as BatchStepResult objects. Each truthy
    ``result`` is deserialized using its ``result_schema_name``.

    Args:
        step_key: Step key identifier (must be unique per execution)
        workflows: BatchWorkflowInput entries, each with a workflow ``id``,
            a ``payload`` (dict or Pydantic model), plus ``initial_state``
            and ``run_timeout_seconds``

    Returns:
        One BatchStepResult per workflow (empty when ``workflows`` is empty).

    Raises:
        StepExecutionError: If any workflow id is not registered.
        WaitException: After submitting, to pause until the batch completes.
    """
    if not workflows:
        return []

    previous = await self._check_existing_step(step_key)
    if previous:
        try:
            stored = await self._handle_existing_step(previous)
        except Exception:
            # Fall back to the raw recorded outputs if normal handling fails.
            stored = previous.get("outputs")

        if not (stored and isinstance(stored, list)):
            return []

        results = []
        for entry in stored:
            raw_result = entry.get("result")
            if raw_result:
                # Note: entries are updated in place with the deserialized value.
                entry["result"] = await deserialize(
                    raw_result, entry.get("result_schema_name")
                )
            results.append(BatchStepResult.model_validate(entry))
        return results

    # Forward the active trace to the children when running inside an execution.
    traceparent = None
    if _execution_context.get():
        active_span = get_current_span()
        if active_span:
            traceparent = extract_traceparent(active_span)

    requests = []
    for item in workflows:
        wf_id = item.id
        serialized_payload = serialize(item.payload)

        registered = get_workflow(wf_id)
        if not registered:
            raise StepExecutionError(f"Workflow '{wf_id}' not found")

        request = {
            "workflow_id": wf_id,
            "payload": serialized_payload,
            "initial_state": serialize(item.initial_state),
            "run_timeout_seconds": item.run_timeout_seconds,
        }
        # Queue settings come from the registered workflow definition.
        if registered.queue_name is not None:
            request["queue_name"] = registered.queue_name
        if registered.queue_concurrency_limit is not None:
            request["queue_concurrency_limit"] = registered.queue_concurrency_limit
        requests.append(request)

    client = get_client_or_raise()
    await client._submit_workflows(
        workflows=requests,
        deployment_id=self.ctx.deployment_id,
        parent_execution_id=self.ctx.execution_id,
        root_execution_id=self.ctx.root_execution_id or self.ctx.execution_id,
        step_key=step_key,
        session_id=self.ctx.session_id,
        user_id=self.ctx.user_id,
        # Marks the parent as waiting until every child completes.
        wait_for_subworkflow=True,
        otel_traceparent=traceparent,
    )

    # Pause here; the orchestrator resumes the execution once the batch is done.
    from ..features.wait import WaitException

    pending_ids = [entry.id for entry in workflows]
    raise WaitException(f"Waiting for sub-workflows {pending_ids} to complete")
|
|
1085
|
+
|
|
1086
|
+
async def agent_invoke(
    self,
    step_key: str,
    config: Any,
) -> ExecutionHandle:
    """
    Invoke a single agent as a workflow step without waiting for completion.

    Designed for use with Agent.with_input(), which returns an AgentRunConfig
    instance. The agent's workflow is started with a payload built from the
    config plus this execution's session and user IDs. Any extra
    ``config.kwargs`` are merged into that payload.

    Args:
        step_key: Step key identifier (must be unique per execution)
        config: AgentRunConfig instance

    Returns:
        ExecutionHandle of the started agent workflow, or the recorded step
        output if this step already ran.

    Raises:
        StepExecutionError: If ``config`` is not an AgentRunConfig, or the
            agent's workflow is not registered.
    """
    from ..agents.agent import AgentRunConfig  # Local import to avoid circular dependency

    if not isinstance(config, AgentRunConfig):
        raise StepExecutionError(
            f"agent_invoke expects an AgentRunConfig, got {type(config).__name__}"
        )

    payload = {
        "input": config.input,
        "streaming": config.streaming,
        "session_id": self.ctx.session_id,
        "user_id": self.ctx.user_id,
        "conversation_id": config.conversation_id,
        # Extra agent kwargs are merged last, so they override the keys above.
        **config.kwargs,
    }

    workflow_obj = get_workflow(config.agent.id)
    if not workflow_obj:
        raise StepExecutionError(f"Agent workflow '{config.agent.id}' not found")

    result, _found = await self._invoke(
        step_key,
        workflow_obj,
        serialize(payload),
        initial_state=serialize(config.initial_state),
        wait_for_subworkflow=False,
        run_timeout_seconds=config.run_timeout_seconds,
    )
    return result
|
|
1129
|
+
|
|
1130
|
+
async def agent_invoke_and_wait(
    self,
    step_key: str,
    config: Any,
) -> Any:
    """
    Run a single agent as a workflow step and suspend until it completes.

    Designed for use with Agent.with_input(), which returns an AgentRunConfig.
    The value returned is whatever the agent's workflow produced (typically
    an AgentResult), taken from the recorded step output.

    Args:
        step_key: Step key identifier (must be unique per execution)
        config: AgentRunConfig instance

    Raises:
        StepExecutionError: If ``config`` is not an AgentRunConfig, or the
            agent's workflow is not registered.
        WaitException: While the agent's result has not been recorded yet.
    """
    from ..agents.agent import AgentRunConfig  # Local import to avoid circular dependency

    if not isinstance(config, AgentRunConfig):
        raise StepExecutionError(
            f"agent_invoke_and_wait expects an AgentRunConfig, got {type(config).__name__}"
        )

    agent_payload = {
        "input": config.input,
        "streaming": config.streaming,
        "session_id": self.ctx.session_id,
        "user_id": self.ctx.user_id,
        "conversation_id": config.conversation_id,
        **config.kwargs,
    }

    agent_id = config.agent.id
    agent_workflow = get_workflow(agent_id)
    if not agent_workflow:
        raise StepExecutionError(f"Agent workflow '{agent_id}' not found")

    outcome, already_done = await self._invoke(
        step_key,
        agent_workflow,
        serialize(agent_payload),
        initial_state=serialize(config.initial_state),
        wait_for_subworkflow=True,
        run_timeout_seconds=config.run_timeout_seconds,
    )
    if already_done:
        return outcome

    # Agent was just started: pause this execution until it finishes.
    from ..features.wait import WaitException

    raise WaitException(f"Waiting for agent workflow '{agent_id}' to complete")
|
|
1182
|
+
|
|
1183
|
+
async def batch_agent_invoke(
    self,
    step_key: str,
    configs: list[Any],
) -> list[ExecutionHandle]:
    """
    Start several agents in parallel as a single workflow step, without waiting.

    Each AgentRunConfig (as produced by Agent.with_input()) is turned into a
    BatchWorkflowInput for its agent's workflow and handed to ``batch_invoke``.

    Args:
        step_key: Step key identifier (must be unique per execution)
        configs: AgentRunConfig instances

    Returns:
        ExecutionHandle objects as returned by ``batch_invoke``.

    Raises:
        StepExecutionError: If any entry is not an AgentRunConfig.
    """
    from ..agents.agent import AgentRunConfig  # Local import to avoid circular dependency

    def to_batch_input(cfg: Any) -> BatchWorkflowInput:
        # Validate one config and wrap it as a batch entry for its agent workflow.
        if not isinstance(cfg, AgentRunConfig):
            raise StepExecutionError(
                "batch_agent_invoke expects AgentRunConfig instances, "
                f"got {type(cfg).__name__}"
            )
        agent_payload = {
            "input": cfg.input,
            "streaming": cfg.streaming,
            "session_id": self.ctx.session_id,
            "user_id": self.ctx.user_id,
            "conversation_id": cfg.conversation_id,
            **cfg.kwargs,
        }
        return BatchWorkflowInput(
            id=cfg.agent.id,
            payload=agent_payload,
            initial_state=cfg.initial_state,
            run_timeout_seconds=cfg.run_timeout_seconds,
        )

    entries: list[BatchWorkflowInput] = [to_batch_input(cfg) for cfg in configs]
    return await self.batch_invoke(step_key, entries)
|
|
1221
|
+
|
|
1222
|
+
async def batch_agent_invoke_and_wait(
    self,
    step_key: str,
    configs: list[Any],
) -> list[BatchStepResult]:
    """
    Run several agents in parallel as a single workflow step and wait for all of them.

    Each AgentRunConfig (as produced by Agent.with_input()) is turned into a
    BatchWorkflowInput for its agent's workflow and handed to
    ``batch_invoke_and_wait``. Each returned BatchStepResult's ``.result``
    holds the corresponding agent's output (typically an AgentResult).

    Args:
        step_key: Step key identifier (must be unique per execution)
        configs: AgentRunConfig instances

    Raises:
        StepExecutionError: If any entry is not an AgentRunConfig.
    """
    from ..agents.agent import AgentRunConfig  # Local import to avoid circular dependency

    def to_batch_input(cfg: Any) -> BatchWorkflowInput:
        # Validate one config and wrap it as a batch entry for its agent workflow.
        if not isinstance(cfg, AgentRunConfig):
            raise StepExecutionError(
                "batch_agent_invoke_and_wait expects AgentRunConfig instances, "
                f"got {type(cfg).__name__}"
            )
        agent_payload = {
            "input": cfg.input,
            "streaming": cfg.streaming,
            "session_id": self.ctx.session_id,
            "user_id": self.ctx.user_id,
            "conversation_id": cfg.conversation_id,
            **cfg.kwargs,
        }
        return BatchWorkflowInput(
            id=cfg.agent.id,
            payload=agent_payload,
            initial_state=cfg.initial_state,
            run_timeout_seconds=cfg.run_timeout_seconds,
        )

    entries: list[BatchWorkflowInput] = [to_batch_input(cfg) for cfg in configs]
    return await self.batch_invoke_and_wait(step_key, entries)
|
|
1261
|
+
|
|
1262
|
+
async def uuid(self, step_key: str) -> str:
    """Return a UUID string that stays the same across runs of this execution.

    The first time ``step_key`` is seen, a fresh ``uuid4`` is generated and
    saved as the step output. Later runs (replay/resume) return the saved value.

    Args:
        step_key: Step key identifier (must be unique per execution)

    Returns:
        UUID string (persisted across runs)
    """
    recorded = await self._check_existing_step(step_key)
    if not recorded:
        value = str(uuid_module.uuid4())
        await self._save_step_output(step_key, value)
        return value

    # Replay: hand back the value recorded on the first run.
    return await self._handle_existing_step(recorded)
|
|
1286
|
+
|
|
1287
|
+
async def now(self, step_key: str) -> int:
    """Return the current time in epoch milliseconds, durable across runs.

    The first time ``step_key`` is seen, the current time is captured and
    saved as the step output. Later runs return the saved value.

    Args:
        step_key: Step key identifier (must be unique per execution)

    Returns:
        Current timestamp in milliseconds since epoch
    """
    recorded = await self._check_existing_step(step_key)
    if not recorded:
        millis = int(time.time() * 1000)
        await self._save_step_output(step_key, millis)
        return millis

    # Replay: read the raw "outputs" field directly (not via
    # _handle_existing_step). If it is missing, fall back to the current time.
    fallback_millis = int(time.time() * 1000)
    return recorded.get("outputs", fallback_millis)
|
|
1308
|
+
|
|
1309
|
+
async def random(self, step_key: str) -> float:
    """Return a float in [0.0, 1.0) that stays the same across runs of this execution.

    The first time ``step_key`` is seen, ``random.random()`` is drawn and
    saved as the step output. Later runs (replay/resume) return the saved value.

    Args:
        step_key: Step key identifier (must be unique per execution)

    Returns:
        Random float (persisted across runs)
    """
    recorded = await self._check_existing_step(step_key)
    if not recorded:
        draw = random_module.random()
        await self._save_step_output(step_key, draw)
        return draw

    # Replay: hand back the value recorded on the first run.
    return await self._handle_existing_step(recorded)
|
|
1333
|
+
|
|
1334
|
+
@contextmanager
def trace(self, name: str, attributes: dict[str, Any] | None = None):
    """Open a custom span nested under the current step/workflow span.

    The span's parent is taken from the current execution context. While the
    ``with`` body runs, the execution context points at the new span, so
    spans opened inside it nest beneath it. The previous span context is
    restored on exit. The span status is set to OK on normal exit. If an
    exception escapes the body, the span is marked ERROR, the exception is
    recorded on it, and it is re-raised.

    Args:
        name: Name of the span
        attributes: Optional dictionary of span attributes

    Yields:
        The active span.

    Example:
        async def my_step(ctx):
            with ctx.step.trace("database_query", {"table": "users"}):
                result = await db.query("SELECT * FROM users")
                return result
    """
    execution_ctx = _execution_context.get()
    tracer = get_tracer()
    parent_ctx = get_parent_span_context_from_execution_context(execution_ctx)
    span_attributes = attributes or {}

    with tracer.start_as_current_span(
        name=name, context=parent_ctx, attributes=span_attributes
    ) as span:
        # Point the execution context at this span, remembering the enclosing one.
        previous_span_ctx = get_span_context_from_execution_context(execution_ctx)
        set_span_context_in_execution_context(execution_ctx, span.get_span_context())

        try:
            yield span
            span.set_status(Status(StatusCode.OK))
        except Exception as exc:
            span.set_status(Status(StatusCode.ERROR, str(exc)))
            span.record_exception(exc)
            raise
        finally:
            # Always put the enclosing span context back.
            set_span_context_in_execution_context(execution_ctx, previous_span_ctx)
|