polos-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/core/workflow.py
ADDED
|
@@ -0,0 +1,1192 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from contextvars import ContextVar
|
|
10
|
+
from typing import Any, Union
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
from ..features.tracing import (
|
|
15
|
+
create_context_from_traceparent,
|
|
16
|
+
create_context_with_trace_id,
|
|
17
|
+
generate_trace_id_from_execution_id,
|
|
18
|
+
get_tracer,
|
|
19
|
+
)
|
|
20
|
+
from ..features.wait import WaitException
|
|
21
|
+
from ..runtime.client import ExecutionHandle, PolosClient
|
|
22
|
+
from ..runtime.queue import Queue
|
|
23
|
+
from ..utils.serializer import safe_serialize, serialize
|
|
24
|
+
from .context import AgentContext, WorkflowContext
|
|
25
|
+
|
|
26
|
+
# OpenTelemetry is optional. When it cannot be imported, minimal stand-ins keep
# the `trace`, `Status` and `StatusCode` names used by this module defined.
try:
    from opentelemetry import trace
    from opentelemetry.trace import Status, StatusCode
except ImportError:
    trace = None

    class StatusCode:
        """Stand-in for ``opentelemetry.trace.StatusCode`` (string constants)."""

        OK = "OK"
        ERROR = "ERROR"

    class Status:
        """Stand-in for ``opentelemetry.trace.Status`` holding a code and description."""

        def __init__(self, code, description=None):
            self.code = code
            self.description = description


logger = logging.getLogger(__name__)

# Global registry of workflows (str -> Workflow).
_WORKFLOW_REGISTRY: dict[str, Workflow] = {}

# Context-local execution state: set by Workflow._execute_internal for the
# duration of an execution and None otherwise (Workflow.invoke relies on this).
_execution_context: ContextVar[dict[str, Any] | None] = ContextVar("execution_context", default=None)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StepExecutionError(Exception):
    """Raised when a step fails and the workflow must therefore fail.

    Attributes:
        reason: Human-readable failure reason, or ``None`` when none was given.
    """

    def __init__(self, reason: str | None = None):
        # The reason doubles as the exception's single positional arg.
        super().__init__(reason)
        self.reason = reason
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class WorkflowTimeoutError(Exception):
    """Raised when a workflow/agent/tool execution times out.

    Attributes:
        execution_id: ID of the execution that timed out, if known.
        timeout_seconds: The timeout that was exceeded, if known.
    """

    def __init__(self, execution_id: str | None = None, timeout_seconds: float | None = None):
        self.execution_id = execution_id
        self.timeout_seconds = timeout_seconds
        # Only mention a duration when one was supplied; otherwise the message
        # would read "after None seconds".
        if timeout_seconds is not None:
            message = f"Execution timed out after {timeout_seconds} seconds"
        else:
            message = "Execution timed out"
        if execution_id:
            message += f" (execution_id: {execution_id})"
        super().__init__(message)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Workflow:
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
id: str,
|
|
82
|
+
func: Callable,
|
|
83
|
+
workflow_type: str | None = "workflow",
|
|
84
|
+
queue_name: str | None = None,
|
|
85
|
+
queue_concurrency_limit: int | None = None,
|
|
86
|
+
trigger_on_event: str | None = None,
|
|
87
|
+
batch_size: int = 1,
|
|
88
|
+
batch_timeout_seconds: int | None = None,
|
|
89
|
+
schedule: bool | str | dict[str, Any] | None = None,
|
|
90
|
+
on_start: str | list[str] | Any | list[Any] | None = None,
|
|
91
|
+
on_end: str | list[str] | Any | list[Any] | None = None,
|
|
92
|
+
payload_schema_class: type[BaseModel] | None = None,
|
|
93
|
+
output_schema: type[BaseModel] | None = None,
|
|
94
|
+
state_schema: type[BaseModel] | None = None,
|
|
95
|
+
):
|
|
96
|
+
self.id = id
|
|
97
|
+
self.func = func
|
|
98
|
+
self.is_async = asyncio.iscoroutinefunction(func)
|
|
99
|
+
# Check if function has a payload parameter
|
|
100
|
+
sig = inspect.signature(func)
|
|
101
|
+
params = list(sig.parameters.values())
|
|
102
|
+
self.has_payload_param = len(params) >= 2
|
|
103
|
+
|
|
104
|
+
# Store payload schema class (extracted during decorator validation)
|
|
105
|
+
self._payload_schema_class = payload_schema_class
|
|
106
|
+
|
|
107
|
+
# If not provided by decorator, try to extract it (fallback for manual Workflow creation)
|
|
108
|
+
if self.has_payload_param and self._payload_schema_class is None:
|
|
109
|
+
second_param = params[1]
|
|
110
|
+
second_annotation = second_param.annotation
|
|
111
|
+
|
|
112
|
+
# Handle string annotations (forward references)
|
|
113
|
+
if isinstance(second_annotation, str):
|
|
114
|
+
# Try to resolve the annotation
|
|
115
|
+
try:
|
|
116
|
+
# Import from the function's module
|
|
117
|
+
func_module = inspect.getmodule(func)
|
|
118
|
+
if func_module:
|
|
119
|
+
# Try to evaluate the annotation in the function's module context
|
|
120
|
+
second_annotation = eval(second_annotation, func_module.__dict__)
|
|
121
|
+
except (NameError, AttributeError, SyntaxError):
|
|
122
|
+
pass
|
|
123
|
+
|
|
124
|
+
# Check if it's a Pydantic BaseModel subclass
|
|
125
|
+
if inspect.isclass(second_annotation) and issubclass(second_annotation, BaseModel):
|
|
126
|
+
self._payload_schema_class = second_annotation
|
|
127
|
+
elif hasattr(second_annotation, "__origin__"):
|
|
128
|
+
# Check if it's a Union type containing a BaseModel
|
|
129
|
+
origin = getattr(second_annotation, "__origin__", None)
|
|
130
|
+
if origin is Union or (hasattr(origin, "__name__") and origin.__name__ == "Union"):
|
|
131
|
+
args = getattr(second_annotation, "__args__", ())
|
|
132
|
+
for arg in args:
|
|
133
|
+
if inspect.isclass(arg) and issubclass(arg, BaseModel):
|
|
134
|
+
# Prefer the BaseModel over dict if both are present
|
|
135
|
+
self._payload_schema_class = arg
|
|
136
|
+
break
|
|
137
|
+
|
|
138
|
+
self.workflow_type = workflow_type
|
|
139
|
+
self.state_schema = state_schema
|
|
140
|
+
self.queue_name = queue_name # None means use workflow_id as queue name
|
|
141
|
+
self.queue_concurrency_limit = queue_concurrency_limit
|
|
142
|
+
self.trigger_on_event = trigger_on_event # Event topic that triggers this workflow
|
|
143
|
+
self.batch_size = batch_size # Max number of events to batch together
|
|
144
|
+
self.batch_timeout_seconds = batch_timeout_seconds # Max time to wait for batching
|
|
145
|
+
# Parse schedule configuration
|
|
146
|
+
# schedule can be: True (schedulable), False (not schedulable),
|
|
147
|
+
# cron string, or dict with cron/timezone
|
|
148
|
+
self.schedule = schedule
|
|
149
|
+
self.is_schedulable = (
|
|
150
|
+
False # Whether this workflow can be scheduled (schedule=True or has cron)
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
if schedule is True:
|
|
154
|
+
self.is_schedulable = True
|
|
155
|
+
elif schedule is False:
|
|
156
|
+
self.is_schedulable = False
|
|
157
|
+
elif schedule is not None:
|
|
158
|
+
# schedule is a cron string or dict - workflow is schedulable and has a default schedule
|
|
159
|
+
self.is_schedulable = True
|
|
160
|
+
|
|
161
|
+
# Scheduled workflows cannot be event-triggered
|
|
162
|
+
if schedule and trigger_on_event:
|
|
163
|
+
raise ValueError("Workflows cannot be both scheduled and event-triggered")
|
|
164
|
+
|
|
165
|
+
# Scheduled workflows cannot specify queues - they get their own queue with concurrency=1
|
|
166
|
+
if self.is_schedulable and (queue_name is not None or queue_concurrency_limit is not None):
|
|
167
|
+
raise ValueError("Scheduled workflows cannot specify a queue or concurrency limit")
|
|
168
|
+
|
|
169
|
+
# Lifecycle hooks - normalize to list of callables
|
|
170
|
+
self.on_start = self._normalize_hooks(on_start)
|
|
171
|
+
self.on_end = self._normalize_hooks(on_end)
|
|
172
|
+
|
|
173
|
+
# Output schema class (can be set in constructor or after execution
|
|
174
|
+
# if result is a Pydantic model)
|
|
175
|
+
self.output_schema: type[BaseModel] | None = output_schema
|
|
176
|
+
|
|
177
|
+
def _prepare_payload(self, payload: BaseModel | dict[str, Any] | None) -> dict[str, Any] | None:
|
|
178
|
+
"""
|
|
179
|
+
Validate and normalize payload for submission to the orchestrator.
|
|
180
|
+
|
|
181
|
+
Rules:
|
|
182
|
+
- If payload is None, return None
|
|
183
|
+
- If payload is a Pydantic BaseModel instance, convert to dict via model_dump(mode="json")
|
|
184
|
+
- If payload is a dict, ensure it is JSON serializable via json.dumps()
|
|
185
|
+
- Otherwise, raise TypeError
|
|
186
|
+
"""
|
|
187
|
+
if payload is None:
|
|
188
|
+
return None
|
|
189
|
+
|
|
190
|
+
# Pydantic model → dict
|
|
191
|
+
if isinstance(payload, BaseModel):
|
|
192
|
+
return payload.model_dump(mode="json")
|
|
193
|
+
|
|
194
|
+
# Require dict[str, Any] for non-Pydantic payloads
|
|
195
|
+
if isinstance(payload, dict):
|
|
196
|
+
try:
|
|
197
|
+
# Validate JSON serializability of the dict
|
|
198
|
+
json.dumps(payload)
|
|
199
|
+
except (TypeError, ValueError) as e:
|
|
200
|
+
raise TypeError(
|
|
201
|
+
f"Workflow '{self.id}' payload dict is not JSON serializable: {e}. "
|
|
202
|
+
f"Consider using a Pydantic BaseModel for structured data."
|
|
203
|
+
) from e
|
|
204
|
+
return payload
|
|
205
|
+
|
|
206
|
+
raise TypeError(
|
|
207
|
+
f"Workflow '{self.id}' payload must be a dict or Pydantic BaseModel instance, "
|
|
208
|
+
f"got {type(payload).__name__}."
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
def _normalize_hooks(self, hooks: Callable | list[Callable] | None) -> list[Callable]:
|
|
212
|
+
"""Normalize hooks to a list of callables.
|
|
213
|
+
|
|
214
|
+
Accepts:
|
|
215
|
+
- None: Returns empty list
|
|
216
|
+
- Callable: Single hook callable
|
|
217
|
+
- List[Callable]: List of hook callables
|
|
218
|
+
"""
|
|
219
|
+
|
|
220
|
+
if hooks is None:
|
|
221
|
+
return []
|
|
222
|
+
if callable(hooks):
|
|
223
|
+
return [hooks]
|
|
224
|
+
if isinstance(hooks, list):
|
|
225
|
+
result = []
|
|
226
|
+
for hook in hooks:
|
|
227
|
+
if callable(hook):
|
|
228
|
+
result.append(hook)
|
|
229
|
+
else:
|
|
230
|
+
raise TypeError(
|
|
231
|
+
f"Invalid hook type: {type(hook)}. Expected a callable "
|
|
232
|
+
f"(function decorated with @hook)."
|
|
233
|
+
)
|
|
234
|
+
return result
|
|
235
|
+
raise TypeError(f"Invalid hooks type: {type(hooks)}. Expected callable or List[callable].")
|
|
236
|
+
|
|
237
|
+
async def _execute(self, context: dict[str, Any], payload: Any) -> Any:
|
|
238
|
+
"""Execute the workflow with the given payload and checkpointing."""
|
|
239
|
+
execution_id = context.get("execution_id")
|
|
240
|
+
deployment_id = context.get("deployment_id")
|
|
241
|
+
parent_execution_id = context.get("parent_execution_id")
|
|
242
|
+
root_execution_id = context.get("root_execution_id")
|
|
243
|
+
retry_count = context.get("retry_count", 0)
|
|
244
|
+
created_at = context.get("created_at")
|
|
245
|
+
session_id = context.get("session_id")
|
|
246
|
+
user_id = context.get("user_id")
|
|
247
|
+
otel_traceparent = context.get("otel_traceparent")
|
|
248
|
+
otel_span_id = context.get("otel_span_id")
|
|
249
|
+
|
|
250
|
+
# Ensure execution_id is a string
|
|
251
|
+
if execution_id:
|
|
252
|
+
execution_id = str(execution_id)
|
|
253
|
+
if deployment_id:
|
|
254
|
+
deployment_id = str(deployment_id)
|
|
255
|
+
if parent_execution_id:
|
|
256
|
+
parent_execution_id = str(parent_execution_id)
|
|
257
|
+
if root_execution_id:
|
|
258
|
+
root_execution_id = str(root_execution_id)
|
|
259
|
+
if retry_count:
|
|
260
|
+
retry_count = int(retry_count)
|
|
261
|
+
if session_id:
|
|
262
|
+
session_id = str(session_id)
|
|
263
|
+
if user_id:
|
|
264
|
+
user_id = str(user_id)
|
|
265
|
+
|
|
266
|
+
# Set execution context for this workflow execution
|
|
267
|
+
# If we have root_execution_id, use it; otherwise, we are the root
|
|
268
|
+
effective_root_execution_id = root_execution_id if root_execution_id else execution_id
|
|
269
|
+
|
|
270
|
+
# Get initial_state from context if provided
|
|
271
|
+
initial_state = context.get("initial_state")
|
|
272
|
+
if initial_state:
|
|
273
|
+
initial_state = self.state_schema.model_validate(initial_state)
|
|
274
|
+
|
|
275
|
+
# Check if this is an Agent and create appropriate context
|
|
276
|
+
from ..agents.agent import Agent
|
|
277
|
+
|
|
278
|
+
if isinstance(self, Agent):
|
|
279
|
+
# Extract conversation_id from payload if provided
|
|
280
|
+
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
|
|
281
|
+
# Create AgentContext for agents
|
|
282
|
+
workflow_ctx = AgentContext(
|
|
283
|
+
agent_id=self.id,
|
|
284
|
+
execution_id=execution_id,
|
|
285
|
+
deployment_id=deployment_id,
|
|
286
|
+
parent_execution_id=parent_execution_id,
|
|
287
|
+
root_execution_id=effective_root_execution_id,
|
|
288
|
+
retry_count=retry_count,
|
|
289
|
+
model=self.model,
|
|
290
|
+
provider=self.provider,
|
|
291
|
+
system_prompt=self.system_prompt,
|
|
292
|
+
tools=self.tools,
|
|
293
|
+
temperature=self.temperature,
|
|
294
|
+
max_tokens=self.max_output_tokens,
|
|
295
|
+
session_id=session_id,
|
|
296
|
+
conversation_id=conversation_id,
|
|
297
|
+
user_id=user_id,
|
|
298
|
+
created_at=created_at,
|
|
299
|
+
otel_traceparent=otel_traceparent,
|
|
300
|
+
otel_span_id=otel_span_id,
|
|
301
|
+
state_schema=self.state_schema,
|
|
302
|
+
initial_state=initial_state,
|
|
303
|
+
)
|
|
304
|
+
else:
|
|
305
|
+
# Create WorkflowContext for regular workflows or tools
|
|
306
|
+
from ..tools.tool import Tool
|
|
307
|
+
|
|
308
|
+
workflow_type = "tool" if isinstance(self, Tool) else "workflow"
|
|
309
|
+
|
|
310
|
+
workflow_ctx = WorkflowContext(
|
|
311
|
+
workflow_id=self.id,
|
|
312
|
+
execution_id=execution_id,
|
|
313
|
+
deployment_id=deployment_id,
|
|
314
|
+
parent_execution_id=parent_execution_id,
|
|
315
|
+
root_execution_id=effective_root_execution_id,
|
|
316
|
+
retry_count=retry_count,
|
|
317
|
+
created_at=created_at,
|
|
318
|
+
session_id=session_id,
|
|
319
|
+
user_id=user_id,
|
|
320
|
+
workflow_type=workflow_type,
|
|
321
|
+
otel_traceparent=otel_traceparent,
|
|
322
|
+
otel_span_id=otel_span_id,
|
|
323
|
+
state_schema=self.state_schema,
|
|
324
|
+
initial_state=initial_state,
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
# Convert dict payload to Pydantic model if needed
|
|
328
|
+
prepared_payload = payload
|
|
329
|
+
if self.has_payload_param and self._payload_schema_class is not None:
|
|
330
|
+
if isinstance(payload, dict):
|
|
331
|
+
# Convert dict to Pydantic model
|
|
332
|
+
try:
|
|
333
|
+
prepared_payload = self._payload_schema_class.model_validate(payload)
|
|
334
|
+
except Exception as e:
|
|
335
|
+
raise ValueError(
|
|
336
|
+
f"Invalid payload for workflow '{self.id}': failed to "
|
|
337
|
+
f"validate against {self._payload_schema_class.__name__}: {e}"
|
|
338
|
+
) from e
|
|
339
|
+
elif payload is not None and not isinstance(payload, self._payload_schema_class):
|
|
340
|
+
# Payload is provided but not the right type
|
|
341
|
+
raise ValueError(
|
|
342
|
+
f"Invalid payload for workflow '{self.id}': "
|
|
343
|
+
f"expected {self._payload_schema_class.__name__} or dict, "
|
|
344
|
+
f"got {type(payload).__name__}"
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
return await self._execute_internal(workflow_ctx, prepared_payload)
|
|
348
|
+
|
|
349
|
+
async def _execute_internal(self, ctx: WorkflowContext, payload: Any) -> Any:
    """Run the workflow/agent function inside a tracing span with hooks and events.

    Shared by workflows, agents and tools. In order it:

    - picks the OpenTelemetry parent context (the parent's traceparent for
      sub-workflows, else a trace ID derived from the root execution ID),
    - opens a ``{workflow_type}.{workflow_id}`` span and publishes the execution
      context (with span/trace IDs) through ``_execution_context``,
    - publishes a ``{workflow_type}_start`` event and runs ``on_start`` hooks,
    - calls the function (sync functions run in the default executor with a copy
      of the current contextvars),
    - runs ``on_end`` hooks and publishes a ``{workflow_type}_finish`` event.

    Returns:
        A ``(result, final_state)`` tuple. ``final_state`` is ``ctx.state`` dumped
        to JSON-compatible data when a ``state_schema`` is set, else ``None``.

    Raises:
        WaitException: Re-raised when the execution pauses; no finish event is
            published and the span ID is saved for resume linkage.
        StepExecutionError: When an on_start/on_end hook returns ``HookAction.FAIL``.
        Exception: Any other error, after it has been recorded on the span.
    """
    from contextvars import copy_context

    from ..middleware.hook import HookAction, HookContext
    from ..middleware.hook_executor import execute_hooks

    workflow_type = ctx.workflow_type or "workflow"
    span_name = f"{workflow_type}.{ctx.workflow_id}"
    traceparent = ctx.otel_traceparent
    # Without a root_execution_id this execution is its own root.
    trace_root_execution_id = ctx.root_execution_id or ctx.execution_id

    # Sub-workflows continue the parent's trace. Root workflows, or a traceparent
    # that cannot be parsed, get a deterministic trace ID from the root execution.
    parent_context = None
    if traceparent:
        parent_context = create_context_from_traceparent(traceparent)
        if parent_context is None:
            logger.warning(
                "Failed to extract trace context from traceparent: %s. Creating new trace.",
                traceparent,
            )
    if parent_context is None:
        trace_id = generate_trace_id_from_execution_id(trace_root_execution_id)
        parent_context = create_context_with_trace_id(trace_id)

    tracer = get_tracer()
    span_attributes = {
        f"{workflow_type}.id": ctx.workflow_id,
        f"{workflow_type}.execution_id": ctx.execution_id,
        f"{workflow_type}.parent_execution_id": ctx.parent_execution_id or "",
        f"{workflow_type}.root_execution_id": trace_root_execution_id,
        f"{workflow_type}.deployment_id": ctx.deployment_id,
        f"{workflow_type}.type": workflow_type,
        f"{workflow_type}.session_id": ctx.session_id or "",
        f"{workflow_type}.user_id": ctx.user_id or "",
        f"{workflow_type}.retry_count": ctx.retry_count or 0,
    }
    # Resumed executions link back to the span of the previous run.
    if ctx.otel_span_id:
        span_attributes[f"{workflow_type}.previous_span_id"] = ctx.otel_span_id

    # Root workflows attach the deterministic trace context so the IdGenerator can
    # read the trace_id; sub-workflows pass it via the `context=` argument instead.
    context_token = None
    if parent_context and not traceparent:
        from opentelemetry import context as otel_context

        context_token = otel_context.attach(parent_context)

    try:
        with tracer.start_as_current_span(
            name=span_name,
            context=parent_context if traceparent else None,
            attributes=span_attributes,
        ) as workflow_span:
            if ctx.otel_span_id:
                workflow_span.add_event(f"{workflow_type}.resumed")

            # Expose the execution context (plus current span IDs for nested spans);
            # "state" lets the worker access the final state.
            exec_context = ctx.to_dict()
            span_context = workflow_span.get_span_context()
            exec_context["_otel_span_context"] = span_context
            exec_context["_otel_trace_id"] = format(span_context.trace_id, "032x")
            exec_context["_otel_span_id"] = format(span_context.span_id, "016x")
            exec_context["state"] = ctx.state
            token = _execution_context.set(exec_context)

            def record_final_state(warning_template: str) -> Any:
                """Dump ctx.state onto the span; return the dump (None if unavailable)."""
                if ctx.state is None or not self.state_schema:
                    return None
                state_dump = None
                try:
                    state_dump = ctx.state.model_dump(mode="json")
                    workflow_span.set_attribute(
                        f"{workflow_type}.final_state", json.dumps(state_dump)
                    )
                except Exception as exc:
                    logger.warning(warning_template, exc)
                return state_dump

            # Everything after the ContextVar is set runs inside this try so the
            # token is always reset, even if payload serialization fails.
            try:
                topic = f"workflow:{trace_root_execution_id}"
                serialized_payload = serialize(payload) if payload is not None else None

                if payload is not None:
                    workflow_span.set_attribute(
                        f"{workflow_type}.input", json.dumps(serialized_payload)
                    )

                if ctx.state is not None and self.state_schema:
                    try:
                        initial_state_dict = ctx.state.model_dump(mode="json")
                        workflow_span.set_attribute(
                            f"{workflow_type}.initial_state", json.dumps(initial_state_dict)
                        )
                    except Exception as exc:
                        logger.warning("Failed to serialize initial state for span: %s", exc)

                await ctx.step.publish_event(
                    "publish_start",
                    topic=topic,
                    event_type=f"{workflow_type}_start",
                    data={
                        "payload": serialized_payload,
                        "_metadata": {
                            "execution_id": ctx.execution_id,
                            "workflow_id": ctx.workflow_id,
                        },
                    },
                )

                if self.on_start:
                    hook_result = await execute_hooks(
                        "on_start",
                        self.on_start,
                        HookContext(workflow_id=self.id, current_payload=payload),
                        ctx,
                    )
                    if hook_result.modified_payload is not None:
                        payload = hook_result.modified_payload
                    if hook_result.action == HookAction.FAIL:
                        raise StepExecutionError(
                            hook_result.error_message or "Hook execution failed"
                        )

                call_args = (ctx, payload) if self.has_payload_param else (ctx,)
                if self.is_async:
                    result = await self.func(*call_args)
                else:
                    # run_in_executor does not propagate contextvars by itself;
                    # run inside a copy so the sync function still sees
                    # _execution_context and the current OTel span.
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(
                        None, copy_context().run, self.func, *call_args
                    )

                if self.on_end:
                    hook_result = await execute_hooks(
                        "on_end",
                        self.on_end,
                        HookContext(
                            workflow_id=self.id,
                            current_payload=payload,
                            current_output=result,
                        ),
                        ctx,
                    )
                    if hook_result.modified_output is not None:
                        result = hook_result.modified_output
                    if hook_result.action == HookAction.FAIL:
                        raise StepExecutionError(
                            hook_result.error_message or "Hook execution failed"
                        )

                serialized_result = serialize(result) if result is not None else None
                # Only reached when no WaitException interrupted the run.
                await ctx.step.publish_event(
                    "publish_finish",
                    topic=topic,
                    event_type=f"{workflow_type}_finish",
                    data={
                        "result": serialized_result,
                        "_metadata": {
                            "execution_id": ctx.execution_id,
                            "workflow_id": ctx.workflow_id,
                        },
                    },
                )

                workflow_span.set_status(Status(StatusCode.OK))
                workflow_span.set_attribute(f"{workflow_type}.status", "completed")
                workflow_span.set_attribute(
                    f"{workflow_type}.result_size", len(str(result)) if result else 0
                )
                if result is not None:
                    workflow_span.set_attribute(
                        f"{workflow_type}.output", json.dumps(serialized_result)
                    )

                final_state = record_final_state(
                    "Failed to serialize final state for span: %s"
                )
                return result, final_state

            except WaitException:
                # Paused, not finished: the orchestrator resumes the execution
                # later, so no finish event is published.
                workflow_span.set_status(Status(StatusCode.OK))
                workflow_span.set_attribute(f"{workflow_type}.status", "waiting")
                workflow_span.add_event(f"{workflow_type}.waiting")

                # Save this span's ID so the resumed run can link to it.
                span_id_hex = format(workflow_span.get_span_context().span_id, "016x")
                from ..runtime.client import update_execution_otel_span_id

                try:
                    update_task = asyncio.create_task(
                        update_execution_otel_span_id(ctx.execution_id, span_id_hex)
                    )
                except Exception as e:
                    # Don't fail the wait on span_id update failure.
                    logger.warning(
                        "Failed to update otel_span_id for execution %s: %s",
                        ctx.execution_id,
                        e,
                    )
                else:
                    # The event loop only keeps weak references to tasks: hold a
                    # strong one until completion, and log failures instead of
                    # leaving them as never-retrieved task exceptions.
                    pending_tasks = getattr(self, "_pending_span_id_updates", None)
                    if pending_tasks is None:
                        pending_tasks = set()
                        self._pending_span_id_updates = pending_tasks
                    pending_tasks.add(update_task)

                    def _on_span_id_update_done(
                        task, execution_id=ctx.execution_id, tasks=pending_tasks
                    ):
                        tasks.discard(task)
                        if not task.cancelled() and task.exception() is not None:
                            logger.warning(
                                "Failed to update otel_span_id for execution %s: %s",
                                execution_id,
                                task.exception(),
                            )

                    update_task.add_done_callback(_on_span_id_update_done)
                raise

            except Exception as e:
                workflow_span.set_status(Status(StatusCode.ERROR, str(e)))
                workflow_span.set_attribute(f"{workflow_type}.status", "failed")
                workflow_span.record_exception(e)

                workflow_error = {
                    "message": str(e),
                    "type": type(e).__name__,
                }
                workflow_span.set_attribute(
                    f"{workflow_type}.error", json.dumps(safe_serialize(workflow_error))
                )

                # Still record the state reached before the failure.
                record_final_state("Failed to serialize final state for error case: %s")
                raise

            finally:
                # Restore the previous execution context.
                _execution_context.reset(token)
    finally:
        # Detach only after the span context manager has exited, so OTel contexts
        # are restored in LIFO order. Detaching inside the span would leave the
        # attached root context current once the span resets its own token.
        if context_token is not None:
            from opentelemetry import context as otel_context

            otel_context.detach(context_token)
|
|
634
|
+
|
|
635
|
+
async def invoke(
    self,
    client: PolosClient,
    payload: BaseModel | dict[str, Any] | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
    initial_state: BaseModel | dict[str, Any] | None = None,
    run_timeout_seconds: int | None = None,
) -> ExecutionHandle:
    """Submit this workflow to the orchestrator and return a handle immediately.

    Fire-and-forget: the workflow runs asynchronously and the returned handle is
    used to monitor it. It must not be called while a workflow or agent is
    executing; from inside an execution use ``step.invoke()`` instead.

    Args:
        client: PolosClient instance used for submission.
        payload: Workflow payload (dict or Pydantic model).
        queue: Optional queue name (overrides workflow-level queue).
        concurrency_key: Optional concurrency key for per-tenant queuing.
        session_id: Optional session ID.
        user_id: Optional user ID.
        initial_state: Optional initial workflow state (dict or Pydantic model).
        run_timeout_seconds: Optional run timeout in seconds.

    Returns:
        ExecutionHandle for monitoring and managing the execution.

    Raises:
        RuntimeError: If called while an execution context is active.
        ValueError: If the workflow is event-triggered and the payload has no
            ``events`` (raised by ``_invoke``).
    """
    # An active execution context means we are inside a workflow/agent run.
    if _execution_context.get() is not None:
        raise RuntimeError(
            "workflow.invoke() cannot be called from within a workflow or agent. "
            "Use step.invoke() to call workflows from within workflows."
        )

    return await self._invoke(
        client,
        payload,
        queue=queue,
        concurrency_key=concurrency_key,
        session_id=session_id,
        user_id=user_id,
        initial_state=initial_state,
        run_timeout_seconds=run_timeout_seconds,
    )
|
|
684
|
+
|
|
685
|
+
async def _invoke(
    self,
    client: PolosClient,
    payload: BaseModel | dict[str, Any] | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    batch_id: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
    deployment_id: str | None = None,
    parent_execution_id: str | None = None,
    root_execution_id: str | None = None,
    step_key: str | None = None,
    wait_for_subworkflow: bool = False,
    otel_traceparent: str | None = None,
    initial_state: BaseModel | dict[str, Any] | None = None,
    run_timeout_seconds: int | None = None,
) -> ExecutionHandle:
    """Submit a workflow execution to the orchestrator and return its handle.

    This is a fire-and-forget operation: the workflow is executed
    asynchronously and the handle is returned as soon as the submission call
    completes. Unlike the public entry points, this method does not check
    whether it is running inside an execution context.

    Args:
        client: PolosClient instance used to submit the execution
        payload: Workflow payload (dict or Pydantic BaseModel). It is serialized
            when the workflow declares a payload parameter; otherwise it is
            dropped (with a warning if one was supplied).
        queue: Optional queue name. Falls back to the workflow-level queue
            name, then to the workflow ID.
        concurrency_key: Optional concurrency key for per-tenant queuing
        batch_id: Optional batch ID for batching
        session_id: Optional session ID, forwarded to the orchestrator
        user_id: Optional user ID, forwarded to the orchestrator
        deployment_id: Optional deployment ID, forwarded to the orchestrator
        parent_execution_id: Optional parent execution ID, forwarded as-is
        root_execution_id: Optional root execution ID, forwarded as-is
        step_key: Optional step_key (set when invoked from step.py)
        wait_for_subworkflow: Forwarded to the orchestrator submission call
        otel_traceparent: Optional OpenTelemetry traceparent, forwarded as-is
        initial_state: Optional initial state; serialized before submission
        run_timeout_seconds: Optional run timeout, forwarded as-is

    Returns:
        ExecutionHandle for monitoring and managing the execution

    Raises:
        ValueError: If the workflow is event-triggered and the payload carries
            no ``events``, or if the workflow requires a payload and None was
            provided.
    """
    if self.trigger_on_event:
        # Event payloads may be a plain mapping or a Pydantic model; read
        # "events" from whichever shape was supplied (BaseModel has no .get()).
        if isinstance(payload, BaseModel):
            events = getattr(payload, "events", None)
        elif hasattr(payload, "get"):
            events = payload.get("events")
        else:
            events = None
        if events is None:
            raise ValueError(
                f"Workflow '{self.id}' is event-triggered and should have events in the payload."
            )

    # Only send a payload if the workflow function actually accepts one.
    if self.has_payload_param:
        if payload is None:
            raise ValueError(
                f"Workflow '{self.id}' requires a payload parameter, but None was provided"
            )
        payload = serialize(payload)
    else:
        if payload is not None:
            import warnings

            # Warn but don't fail: the caller probably passed a payload by mistake.
            warnings.warn(
                f"Workflow '{self.id}' does not accept a payload; "
                "the provided payload will be ignored.",
                stacklevel=3,
            )
        payload = None

    # Queue precedence: explicit argument > workflow-level queue > workflow ID.
    if queue:
        queue_name = queue
    elif self.queue_name is not None:
        queue_name = self.queue_name
    else:
        queue_name = self.id

    # The workflow is checkpointed when it executes, not here.
    return await client._submit_workflow(
        self.id,
        payload,
        deployment_id=deployment_id,
        parent_execution_id=parent_execution_id,
        root_execution_id=root_execution_id,
        step_key=step_key,
        queue_name=queue_name,
        queue_concurrency_limit=self.queue_concurrency_limit,
        concurrency_key=concurrency_key,
        wait_for_subworkflow=wait_for_subworkflow,
        batch_id=batch_id,
        session_id=session_id,
        user_id=user_id,
        otel_traceparent=otel_traceparent,
        initial_state=serialize(initial_state),
        run_timeout_seconds=run_timeout_seconds,
    )
|
|
765
|
+
|
|
766
|
+
async def run(
    self,
    client: PolosClient,
    payload: BaseModel | dict[str, Any] | None = None,
    queue: str | None = None,
    concurrency_key: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
    timeout: float | None = 600.0,
    initial_state: BaseModel | dict[str, Any] | None = None,
) -> Any:
    """
    Run workflow and return final result (wait for completion).

    The execution is submitted via ``invoke`` and its handle is polled every
    0.5 seconds until it reports ``"completed"`` or ``"failed"``.

    This method cannot be called from within an execution context
    (e.g., from within a workflow).
    Use step.invoke_and_wait() to call workflows from within workflows.

    Args:
        client: PolosClient instance
        payload: Workflow payload (dict or Pydantic BaseModel)
        queue: Optional queue name (overrides workflow-level queue)
        concurrency_key: Optional concurrency key for per-tenant queuing
        session_id: Optional session ID
        user_id: Optional user ID
        timeout: Optional timeout in seconds (default: 600 seconds / 10 minutes).
            ``None`` disables the client-side deadline and sends no run timeout
            to the orchestrator. Fractional values are rounded up to whole
            seconds for the orchestrator-side run timeout.
        initial_state: Optional initial workflow state (dict or Pydantic BaseModel)

    Returns:
        Result from workflow execution

    Raises:
        RuntimeError: If called from within a workflow or agent, or if the
            execution reports status ``"failed"``.
        WorkflowTimeoutError: If the execution exceeds the timeout

    Example:
        result = await my_workflow.run(client, {"param": "value"})
    """
    import math

    # Check if we're in an execution context - fail if we are
    if _execution_context.get() is not None:
        raise RuntimeError(
            "workflow.run() cannot be called from within a workflow or agent. "
            "Use step.invoke_and_wait() to call workflows from within workflows."
        )

    # Round up so the orchestrator never gives up before the client deadline
    # (int() would turn e.g. 0.5 into 0).
    run_timeout_seconds = None if timeout is None else math.ceil(timeout)

    handle = await self.invoke(
        client=client,
        payload=payload,
        queue=queue,
        concurrency_key=concurrency_key,
        session_id=session_id,
        user_id=user_id,
        initial_state=initial_state,
        run_timeout_seconds=run_timeout_seconds,
    )

    # Monotonic clock: unaffected by wall-clock adjustments while we poll.
    deadline = None if timeout is None else time.monotonic() + timeout
    poll_interval_seconds = 0.5

    while True:
        if deadline is not None and time.monotonic() >= deadline:
            raise WorkflowTimeoutError(
                execution_id=getattr(handle, "id", None),
                timeout_seconds=timeout,
            )

        execution_info = await handle.get(client)
        status = execution_info.get("status")

        if status == "completed":
            return execution_info.get("result")
        if status == "failed":
            # Fall back to a generic message if "error" is missing or empty/None.
            error = execution_info.get("error") or "Workflow execution failed"
            raise RuntimeError(f"Workflow execution failed: {error}")

        await asyncio.sleep(poll_interval_seconds)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
def _is_context_annotation(annotation: Any) -> bool:
    """Return True if a first-parameter annotation is acceptable as a workflow context.

    Accepted: no annotation; string annotations or type names mentioning
    ``WorkflowContext`` or ``AgentContext``; the context classes themselves; and
    generic/union annotations (``Optional[...]``, ``Union[...]``, ``X | Y``)
    whose arguments include one of the context classes.
    """
    import typing

    if annotation is inspect.Parameter.empty:
        return True
    if isinstance(annotation, str):
        # Unevaluated (forward-reference) annotation: match on its text.
        return "WorkflowContext" in annotation or "AgentContext" in annotation

    type_name = getattr(annotation, "__name__", None) or str(annotation)
    if "WorkflowContext" in type_name or "AgentContext" in type_name:
        return True

    try:
        from ..core.context import AgentContext, WorkflowContext
    except ImportError:
        return False

    context_classes = (WorkflowContext, AgentContext)
    if annotation in context_classes:
        return True
    # get_args() covers typing.Union/Optional and PEP 604 ``X | Y`` unions alike.
    return any(arg in context_classes for arg in typing.get_args(annotation))


def _payload_string_looks_valid(annotation: str) -> bool:
    """Heuristic acceptance test for an unevaluated (string) payload annotation."""
    return (
        "dict" in annotation.lower()
        or "BaseModel" in annotation
        or "Union" in annotation
        or "|" in annotation
    )


def _classify_payload_annotation(
    annotation: Any,
) -> tuple[bool, type[BaseModel] | None]:
    """Classify an evaluated payload annotation.

    Returns ``(is_valid, schema_class)``. The annotation is valid when it is
    ``dict`` (bare or parameterized, e.g. ``dict[str, Any]``), a Pydantic
    ``BaseModel`` subclass, or a union (``typing.Union``/``Optional`` or PEP 604
    ``X | Y``) containing either. ``schema_class`` is the BaseModel subclass
    itself, or the first BaseModel subclass found in a union, else None.
    """
    import types
    import typing

    def is_dict_type(tp: Any) -> bool:
        return tp is dict or typing.get_origin(tp) is dict

    def is_model_type(tp: Any) -> bool:
        if not inspect.isclass(tp):
            return False
        try:
            return issubclass(tp, BaseModel)
        except TypeError:
            # Generic aliases (e.g. list[int]) can pass isclass() on older
            # Pythons but are rejected by issubclass().
            return False

    if is_dict_type(annotation):
        return True, None
    if is_model_type(annotation):
        return True, annotation

    union_origins = [typing.Union]
    pep604_union = getattr(types, "UnionType", None)  # Python 3.10+
    if pep604_union is not None:
        union_origins.append(pep604_union)
    if typing.get_origin(annotation) not in union_origins:
        return False, None

    is_valid = False
    schema_class: type[BaseModel] | None = None
    for arg in typing.get_args(annotation):
        if is_dict_type(arg):
            is_valid = True
        elif is_model_type(arg):
            is_valid = True
            # Order-independent: pick the first model regardless of whether
            # a dict alternative appears before it in the union.
            if schema_class is None:
                schema_class = arg
    return is_valid, schema_class


def _resolve_type_hints(func: Callable) -> dict[str, Any]:
    """Best-effort evaluation of *func*'s annotations; returns {} on failure."""
    import typing

    try:
        return typing.get_type_hints(func)
    except (NameError, TypeError, AttributeError, SyntaxError):
        # e.g. names imported only under TYPE_CHECKING cannot be evaluated;
        # callers fall back to the raw annotation text.
        return {}


def _parse_queue_config(
    queue: str | Queue | dict[str, Any] | None, workflow_id: str
) -> tuple[str | None, int | None]:
    """Translate the decorator's ``queue`` argument into (queue_name, concurrency_limit)."""
    if queue is None:
        return None, None
    if isinstance(queue, str):
        return queue, None
    if isinstance(queue, Queue):
        return queue.name, queue.concurrency_limit
    if isinstance(queue, dict):
        # Use workflow_id as the queue name if none is provided.
        return queue.get("name", workflow_id), queue.get("concurrency_limit")
    raise ValueError(f"Invalid queue type: {type(queue)}. Expected str, Queue, or dict.")


def workflow(
    id: str | None = None,
    queue: str | Queue | dict[str, Any] | None = None,
    trigger_on_event: str | None = None,
    batch_size: int = 1,
    batch_timeout_seconds: int | None = None,
    schedule: bool | str | dict[str, Any] | None = None,
    on_start: str | list[str] | Workflow | list[Workflow] | None = None,
    on_end: str | list[str] | Workflow | list[Workflow] | None = None,
    state_schema: type[BaseModel] | None = None,
):
    """Decorator to register a Polos workflow.

    Usable bare (``@workflow``) or called (``@workflow(...)``). The decorated
    function is wrapped in a ``Workflow`` object, stored in the module registry
    under its ID (retrievable with ``get_workflow``), and returned in place of
    the function.

    The function must take one or two parameters, ``(ctx)`` or ``(ctx, payload)``:

    - ``ctx`` may be untyped or annotated as ``WorkflowContext``/``AgentContext``
      (including ``Optional``/union forms).
    - ``payload``, if present, must be annotated as ``dict`` (optionally
      parameterized), a Pydantic ``BaseModel`` subclass, or a union containing
      one of these (``Union[...]``, ``Optional[...]`` or ``X | Y``). A
      ``BaseModel`` subclass found there is recorded as the payload schema.
      String annotations (e.g. under ``from __future__ import annotations``)
      are evaluated when possible; otherwise a textual heuristic is applied.

    Usage:
        @workflow
        def my_workflow(ctx, payload: dict[str, Any]):
            return {"result": payload["value"] * 2}

        @workflow()
        def my_workflow2(ctx, payload: dict[str, Any]):
            return {"result": payload["value"] * 2}

        @workflow(id="custom_workflow_id")
        async def async_workflow(ctx: WorkflowContext):
            await asyncio.sleep(1)
            return {"done": True}

        # With inline queue config
        @workflow(queue={"concurrency_limit": 1})
        def one_at_a_time(ctx, payload: dict[str, Any]):
            return {"result": payload}

        # With queue name
        @workflow(queue="my-queue")
        def my_queued_workflow(ctx, payload: dict[str, Any]):
            return {"result": payload}

        # With Queue object
        from polos import queue
        my_queue = queue("my-queue", concurrency_limit=5)
        @workflow(queue=my_queue)
        def queued_workflow(ctx, payload: dict[str, Any]):
            return {"result": payload}

        # Event-triggered workflow (one event per invocation)
        @workflow(id="on-approval-dept1", trigger_on_event="approval/dept1")
        async def on_approval_dept1(ctx, payload: dict[str, Any]):
            # payload contains event information
            event_data = payload.get("events", [{}])[0]
            return {"processed": event_data}

        # Event-triggered workflow with batching (10 events per batch, 30 second timeout)
        @workflow(
            id="batch-processor",
            trigger_on_event="data/updates",
            batch_size=10,
            batch_timeout_seconds=30
        )
        async def batch_processor(ctx, payload: dict[str, Any]):
            # payload contains batch of events
            events = payload.get("events", [])
            return {"processed_count": len(events)}

        # Scheduled workflow (declarative with cron)
        @workflow(id="daily-cleanup", schedule="0 3 * * *")
        async def daily_cleanup(ctx, payload: SchedulePayload):
            print(f"Scheduled to run at {payload.timestamp}")
            return {"status": "cleaned"}

        # Scheduled workflow with timezone
        @workflow(
            id="morning-report",
            schedule={"cron": "0 8 * * *", "timezone": "America/New_York"}
        )
        async def morning_report(ctx, payload: SchedulePayload):
            return {"report": "generated"}

        # Workflow that can be scheduled later (schedule=True)
        @workflow(id="reminder-workflow", schedule=True)
        async def reminder_workflow(ctx, payload: SchedulePayload):
            # Schedule will be added later using schedules.create()
            return {"reminder": "sent"}

        # Workflow that cannot be scheduled (schedule=False)
        @workflow(id="one-time-workflow", schedule=False)
        async def one_time_workflow(ctx):
            return {"done": True}

    Args:
        id: Optional workflow ID (defaults to function name)
        queue: Optional queue configuration. Can be:
            - str: Queue name
            - Queue: Queue object
            - dict: {"concurrency_limit": int, "name": str} (name defaults to
              the workflow ID)
            - None: Uses workflow_id as queue name with default concurrency
            - Note: Not supported for event-triggered or scheduled workflows
              (scheduled workflows raise ValueError here)
        trigger_on_event: Optional event topic that triggers this workflow. If specified:
            - Workflow will be automatically triggered when events are published to this topic
            - Workflow gets its own queue with concurrency=1 (to ensure ordering)
            - Payload will contain event information:
              {"events": [{"id": "...", "topic": "...", "event_type": "...",
              "data": {...}, "sequence_id": ..., "created_at": "..."}, ...]}
        batch_size: For event-triggered workflows, number of events to batch
            together (default: 1)
        batch_timeout_seconds: For event-triggered workflows, max time to wait
            for batching (None = no timeout)
        schedule: Optional schedule configuration. Can be:
            - True: Workflow can be scheduled later using schedules.create() API
            - False: Workflow cannot be scheduled (explicit opt-out)
            - str: Cron expression (e.g., "0 3 * * *" for 3 AM daily,
              uses UTC timezone) - creates schedule immediately
            - dict: {"cron": str, "timezone": str, "key": str}
              (e.g., {"cron": "0 8 * * *", "timezone": "America/New_York"})
              - creates schedule immediately
            - Note: Cannot be specified for event-triggered workflows
            - Note: Scheduled workflows cannot specify queues
              (they get their own queue with concurrency=1)
        on_start: Workflow(s) or workflow ID(s), passed through to ``Workflow``
        on_end: Workflow(s) or workflow ID(s), passed through to ``Workflow``
        state_schema: Optional Pydantic model for workflow state, passed through
            to ``Workflow``

    Returns:
        The ``Workflow`` object when used bare, otherwise a decorator producing one.

    Raises:
        TypeError: If the function signature does not match the rules above.
        ValueError: If a schedulable workflow specifies a queue, or ``queue``
            has an unsupported type.

    NOTE(review): this decorator does not itself validate the ID format; the
    examples above use hyphenated IDs. Confirm whether ``Workflow`` enforces
    any identifier rules.
    """

    def decorator(func: Callable) -> Workflow:
        params = list(inspect.signature(func).parameters.values())

        if len(params) < 1:
            raise TypeError(
                f"Workflow function '{func.__name__}' must have at least 1 "
                "parameter: (context: WorkflowContext) or "
                "(context: WorkflowContext, payload: "
                "Union[BaseModel, dict[str, Any]])"
            )
        if len(params) > 2:
            raise TypeError(
                f"Workflow function '{func.__name__}' must have at most 2 "
                "parameters: (context: WorkflowContext) or "
                "(context: WorkflowContext, payload: "
                "Union[BaseModel, dict[str, Any]])"
            )

        # First parameter: the execution context.
        first_param = params[0]
        first_annotation = first_param.annotation
        if not _is_context_annotation(first_annotation):
            raise TypeError(
                f"Workflow function '{func.__name__}': first parameter "
                f"'{first_param.name}' must be typed as WorkflowContext or "
                f"AgentContext (or untyped), got {first_annotation}"
            )

        # Optional second parameter: the payload.
        payload_schema_class: type[BaseModel] | None = None
        if len(params) == 2:
            second_param = params[1]
            second_annotation = second_param.annotation
            if second_annotation is inspect.Parameter.empty:
                raise TypeError(
                    f"Workflow function '{func.__name__}': second parameter "
                    f"'{second_param.name}' must be typed as "
                    "Union[BaseModel, dict[str, Any]] or a specific "
                    "Pydantic BaseModel class"
                )

            # Try to evaluate string annotations so that e.g. `payload: MyModel`
            # under `from __future__ import annotations` is recognised.
            resolved = second_annotation
            if isinstance(second_annotation, str):
                resolved = _resolve_type_hints(func).get(second_param.name, second_annotation)

            second_type_valid = False
            if not isinstance(resolved, str):
                second_type_valid, payload_schema_class = _classify_payload_annotation(resolved)
            if not second_type_valid and isinstance(second_annotation, str):
                # Keep accepting everything the text heuristic accepted before.
                second_type_valid = _payload_string_looks_valid(second_annotation)

            if not second_type_valid:
                raise TypeError(
                    f"Workflow function '{func.__name__}': second parameter "
                    f"'{second_param.name}' must be typed as "
                    "Union[BaseModel, dict[str, Any]] or a specific "
                    f"Pydantic BaseModel class, got {second_annotation}"
                )

        workflow_id = id if id is not None else func.__name__

        # True or a cron string/dict makes the workflow schedulable; None/False do not.
        is_schedulable = schedule is not None and schedule is not False
        if is_schedulable and queue is not None:
            raise ValueError("Scheduled workflows cannot specify a queue.")

        queue_name, queue_concurrency_limit = _parse_queue_config(queue, workflow_id)

        workflow_obj = Workflow(
            id=workflow_id,
            func=func,
            queue_name=queue_name,
            queue_concurrency_limit=queue_concurrency_limit,
            trigger_on_event=trigger_on_event,
            batch_size=batch_size,
            batch_timeout_seconds=batch_timeout_seconds,
            schedule=schedule,
            on_start=on_start,
            on_end=on_end,
            payload_schema_class=payload_schema_class,
            state_schema=state_schema,
        )
        _WORKFLOW_REGISTRY[workflow_id] = workflow_obj
        return workflow_obj

    # Bare @workflow: the decorated function arrives as the first argument.
    # Rebinding `id` to None makes the closure fall back to func.__name__.
    if callable(id):
        func = id
        id = None
        return decorator(func)

    # @workflow() or @workflow(id="...", queue=...)
    return decorator
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
def get_workflow(workflow_id: str) -> Workflow | None:
    """Look up a registered workflow by its ID.

    Args:
        workflow_id: ID the workflow was registered under.

    Returns:
        The matching Workflow, or None when nothing is registered under that ID.
    """
    if workflow_id not in _WORKFLOW_REGISTRY:
        return None
    return _WORKFLOW_REGISTRY[workflow_id]
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
def get_all_workflows() -> dict[str, Workflow]:
    """Return a snapshot of every registered workflow, keyed by workflow ID.

    The result is a shallow copy. Adding or removing entries in it leaves the
    registry untouched.
    """
    return dict(_WORKFLOW_REGISTRY)
|