polos-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. polos/__init__.py +105 -0
  2. polos/agents/__init__.py +7 -0
  3. polos/agents/agent.py +746 -0
  4. polos/agents/conversation_history.py +121 -0
  5. polos/agents/stop_conditions.py +280 -0
  6. polos/agents/stream.py +635 -0
  7. polos/core/__init__.py +0 -0
  8. polos/core/context.py +143 -0
  9. polos/core/state.py +26 -0
  10. polos/core/step.py +1380 -0
  11. polos/core/workflow.py +1192 -0
  12. polos/features/__init__.py +0 -0
  13. polos/features/events.py +456 -0
  14. polos/features/schedules.py +110 -0
  15. polos/features/tracing.py +605 -0
  16. polos/features/wait.py +82 -0
  17. polos/llm/__init__.py +9 -0
  18. polos/llm/generate.py +152 -0
  19. polos/llm/providers/__init__.py +5 -0
  20. polos/llm/providers/anthropic.py +615 -0
  21. polos/llm/providers/azure.py +42 -0
  22. polos/llm/providers/base.py +196 -0
  23. polos/llm/providers/fireworks.py +41 -0
  24. polos/llm/providers/gemini.py +40 -0
  25. polos/llm/providers/groq.py +40 -0
  26. polos/llm/providers/openai.py +1021 -0
  27. polos/llm/providers/together.py +40 -0
  28. polos/llm/stream.py +183 -0
  29. polos/middleware/__init__.py +0 -0
  30. polos/middleware/guardrail.py +148 -0
  31. polos/middleware/guardrail_executor.py +253 -0
  32. polos/middleware/hook.py +164 -0
  33. polos/middleware/hook_executor.py +104 -0
  34. polos/runtime/__init__.py +0 -0
  35. polos/runtime/batch.py +87 -0
  36. polos/runtime/client.py +841 -0
  37. polos/runtime/queue.py +42 -0
  38. polos/runtime/worker.py +1365 -0
  39. polos/runtime/worker_server.py +249 -0
  40. polos/tools/__init__.py +0 -0
  41. polos/tools/tool.py +587 -0
  42. polos/types/__init__.py +23 -0
  43. polos/types/types.py +116 -0
  44. polos/utils/__init__.py +27 -0
  45. polos/utils/agent.py +27 -0
  46. polos/utils/client_context.py +41 -0
  47. polos/utils/config.py +12 -0
  48. polos/utils/output_schema.py +311 -0
  49. polos/utils/retry.py +47 -0
  50. polos/utils/serializer.py +167 -0
  51. polos/utils/tracing.py +27 -0
  52. polos/utils/worker_singleton.py +40 -0
  53. polos_sdk-0.1.0.dist-info/METADATA +650 -0
  54. polos_sdk-0.1.0.dist-info/RECORD +55 -0
  55. polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/core/workflow.py ADDED
@@ -0,0 +1,1192 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import inspect
5
+ import json
6
+ import logging
7
+ import time
8
+ from collections.abc import Callable
9
+ from contextvars import ContextVar
10
+ from typing import Any, Union
11
+
12
+ from pydantic import BaseModel
13
+
14
+ from ..features.tracing import (
15
+ create_context_from_traceparent,
16
+ create_context_with_trace_id,
17
+ generate_trace_id_from_execution_id,
18
+ get_tracer,
19
+ )
20
+ from ..features.wait import WaitException
21
+ from ..runtime.client import ExecutionHandle, PolosClient
22
+ from ..runtime.queue import Queue
23
+ from ..utils.serializer import safe_serialize, serialize
24
+ from .context import AgentContext, WorkflowContext
25
+
26
+ # Import OpenTelemetry types
27
+ try:
28
+ from opentelemetry import trace
29
+ from opentelemetry.trace import Status, StatusCode
30
+ except ImportError:
31
+ # Fallback if OpenTelemetry not available
32
+ class Status:
33
+ def __init__(self, code, description=None):
34
+ self.code = code
35
+ self.description = description
36
+
37
+ class StatusCode:
38
+ OK = "OK"
39
+ ERROR = "ERROR"
40
+
41
+ trace = None
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # Global registry of workflows
46
+ _WORKFLOW_REGISTRY: dict[str, Workflow] = {}
47
+
48
+ # Context variables for tracking execution state
49
+ _execution_context: ContextVar[dict[str, Any] | None] = ContextVar(
50
+ "execution_context", default=None
51
+ )
52
+
53
+
54
+ class StepExecutionError(Exception):
55
+ """
56
+ Exception raised when a step fails and the workflow must fail.
57
+ """
58
+
59
+ def __init__(self, reason: str | None = None):
60
+ self.reason = reason
61
+ super().__init__(reason)
62
+
63
+
64
+ class WorkflowTimeoutError(Exception):
65
+ """
66
+ Exception raised when a workflow/agent/tool execution times out.
67
+ """
68
+
69
+ def __init__(self, execution_id: str | None = None, timeout_seconds: float | None = None):
70
+ self.execution_id = execution_id
71
+ self.timeout_seconds = timeout_seconds
72
+ message = f"Execution timed out after {timeout_seconds} seconds"
73
+ if execution_id:
74
+ message += f" (execution_id: {execution_id})"
75
+ super().__init__(message)
76
+
77
+
78
+ class Workflow:
79
+ def __init__(
80
+ self,
81
+ id: str,
82
+ func: Callable,
83
+ workflow_type: str | None = "workflow",
84
+ queue_name: str | None = None,
85
+ queue_concurrency_limit: int | None = None,
86
+ trigger_on_event: str | None = None,
87
+ batch_size: int = 1,
88
+ batch_timeout_seconds: int | None = None,
89
+ schedule: bool | str | dict[str, Any] | None = None,
90
+ on_start: str | list[str] | Any | list[Any] | None = None,
91
+ on_end: str | list[str] | Any | list[Any] | None = None,
92
+ payload_schema_class: type[BaseModel] | None = None,
93
+ output_schema: type[BaseModel] | None = None,
94
+ state_schema: type[BaseModel] | None = None,
95
+ ):
96
+ self.id = id
97
+ self.func = func
98
+ self.is_async = asyncio.iscoroutinefunction(func)
99
+ # Check if function has a payload parameter
100
+ sig = inspect.signature(func)
101
+ params = list(sig.parameters.values())
102
+ self.has_payload_param = len(params) >= 2
103
+
104
+ # Store payload schema class (extracted during decorator validation)
105
+ self._payload_schema_class = payload_schema_class
106
+
107
+ # If not provided by decorator, try to extract it (fallback for manual Workflow creation)
108
+ if self.has_payload_param and self._payload_schema_class is None:
109
+ second_param = params[1]
110
+ second_annotation = second_param.annotation
111
+
112
+ # Handle string annotations (forward references)
113
+ if isinstance(second_annotation, str):
114
+ # Try to resolve the annotation
115
+ try:
116
+ # Import from the function's module
117
+ func_module = inspect.getmodule(func)
118
+ if func_module:
119
+ # Try to evaluate the annotation in the function's module context
120
+ second_annotation = eval(second_annotation, func_module.__dict__)
121
+ except (NameError, AttributeError, SyntaxError):
122
+ pass
123
+
124
+ # Check if it's a Pydantic BaseModel subclass
125
+ if inspect.isclass(second_annotation) and issubclass(second_annotation, BaseModel):
126
+ self._payload_schema_class = second_annotation
127
+ elif hasattr(second_annotation, "__origin__"):
128
+ # Check if it's a Union type containing a BaseModel
129
+ origin = getattr(second_annotation, "__origin__", None)
130
+ if origin is Union or (hasattr(origin, "__name__") and origin.__name__ == "Union"):
131
+ args = getattr(second_annotation, "__args__", ())
132
+ for arg in args:
133
+ if inspect.isclass(arg) and issubclass(arg, BaseModel):
134
+ # Prefer the BaseModel over dict if both are present
135
+ self._payload_schema_class = arg
136
+ break
137
+
138
+ self.workflow_type = workflow_type
139
+ self.state_schema = state_schema
140
+ self.queue_name = queue_name # None means use workflow_id as queue name
141
+ self.queue_concurrency_limit = queue_concurrency_limit
142
+ self.trigger_on_event = trigger_on_event # Event topic that triggers this workflow
143
+ self.batch_size = batch_size # Max number of events to batch together
144
+ self.batch_timeout_seconds = batch_timeout_seconds # Max time to wait for batching
145
+ # Parse schedule configuration
146
+ # schedule can be: True (schedulable), False (not schedulable),
147
+ # cron string, or dict with cron/timezone
148
+ self.schedule = schedule
149
+ self.is_schedulable = (
150
+ False # Whether this workflow can be scheduled (schedule=True or has cron)
151
+ )
152
+
153
+ if schedule is True:
154
+ self.is_schedulable = True
155
+ elif schedule is False:
156
+ self.is_schedulable = False
157
+ elif schedule is not None:
158
+ # schedule is a cron string or dict - workflow is schedulable and has a default schedule
159
+ self.is_schedulable = True
160
+
161
+ # Scheduled workflows cannot be event-triggered
162
+ if schedule and trigger_on_event:
163
+ raise ValueError("Workflows cannot be both scheduled and event-triggered")
164
+
165
+ # Scheduled workflows cannot specify queues - they get their own queue with concurrency=1
166
+ if self.is_schedulable and (queue_name is not None or queue_concurrency_limit is not None):
167
+ raise ValueError("Scheduled workflows cannot specify a queue or concurrency limit")
168
+
169
+ # Lifecycle hooks - normalize to list of callables
170
+ self.on_start = self._normalize_hooks(on_start)
171
+ self.on_end = self._normalize_hooks(on_end)
172
+
173
+ # Output schema class (can be set in constructor or after execution
174
+ # if result is a Pydantic model)
175
+ self.output_schema: type[BaseModel] | None = output_schema
176
+
177
+ def _prepare_payload(self, payload: BaseModel | dict[str, Any] | None) -> dict[str, Any] | None:
178
+ """
179
+ Validate and normalize payload for submission to the orchestrator.
180
+
181
+ Rules:
182
+ - If payload is None, return None
183
+ - If payload is a Pydantic BaseModel instance, convert to dict via model_dump(mode="json")
184
+ - If payload is a dict, ensure it is JSON serializable via json.dumps()
185
+ - Otherwise, raise TypeError
186
+ """
187
+ if payload is None:
188
+ return None
189
+
190
+ # Pydantic model → dict
191
+ if isinstance(payload, BaseModel):
192
+ return payload.model_dump(mode="json")
193
+
194
+ # Require dict[str, Any] for non-Pydantic payloads
195
+ if isinstance(payload, dict):
196
+ try:
197
+ # Validate JSON serializability of the dict
198
+ json.dumps(payload)
199
+ except (TypeError, ValueError) as e:
200
+ raise TypeError(
201
+ f"Workflow '{self.id}' payload dict is not JSON serializable: {e}. "
202
+ f"Consider using a Pydantic BaseModel for structured data."
203
+ ) from e
204
+ return payload
205
+
206
+ raise TypeError(
207
+ f"Workflow '{self.id}' payload must be a dict or Pydantic BaseModel instance, "
208
+ f"got {type(payload).__name__}."
209
+ )
210
+
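An illustrative sketch of the payload rules described above (wf stands for an already constructed Workflow; ReportPayload is a hypothetical model, not part of this module):

    from pydantic import BaseModel

    class ReportPayload(BaseModel):
        name: str

    wf._prepare_payload(ReportPayload(name="weekly"))  # BaseModel -> {"name": "weekly"}
    wf._prepare_payload({"name": "weekly"})            # dict passes the json.dumps() check
    wf._prepare_payload(["not", "allowed"])            # raises TypeError (not dict/BaseModel)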
211
+ def _normalize_hooks(self, hooks: Callable | list[Callable] | None) -> list[Callable]:
212
+ """Normalize hooks to a list of callables.
213
+
214
+ Accepts:
215
+ - None: Returns empty list
216
+ - Callable: Single hook callable
217
+ - List[Callable]: List of hook callables
218
+ """
219
+
220
+ if hooks is None:
221
+ return []
222
+ if callable(hooks):
223
+ return [hooks]
224
+ if isinstance(hooks, list):
225
+ result = []
226
+ for hook in hooks:
227
+ if callable(hook):
228
+ result.append(hook)
229
+ else:
230
+ raise TypeError(
231
+ f"Invalid hook type: {type(hook)}. Expected a callable "
232
+ f"(function decorated with @hook)."
233
+ )
234
+ return result
235
+ raise TypeError(f"Invalid hooks type: {type(hooks)}. Expected callable or List[callable].")
236
+
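For illustration, the hook shapes accepted by _normalize_hooks (a sketch; my_func, audit_hook, and notify_hook are hypothetical callables):

    Workflow(id="w1", func=my_func, on_start=audit_hook)               # single callable -> [audit_hook]
    Workflow(id="w2", func=my_func, on_end=[audit_hook, notify_hook])  # list of callables kept as-is
    Workflow(id="w3", func=my_func, on_start=["not-callable"])         # raises TypeError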
237
+ async def _execute(self, context: dict[str, Any], payload: Any) -> Any:
238
+ """Execute the workflow with the given payload and checkpointing."""
239
+ execution_id = context.get("execution_id")
240
+ deployment_id = context.get("deployment_id")
241
+ parent_execution_id = context.get("parent_execution_id")
242
+ root_execution_id = context.get("root_execution_id")
243
+ retry_count = context.get("retry_count", 0)
244
+ created_at = context.get("created_at")
245
+ session_id = context.get("session_id")
246
+ user_id = context.get("user_id")
247
+ otel_traceparent = context.get("otel_traceparent")
248
+ otel_span_id = context.get("otel_span_id")
249
+
250
+ # Ensure execution_id is a string
251
+ if execution_id:
252
+ execution_id = str(execution_id)
253
+ if deployment_id:
254
+ deployment_id = str(deployment_id)
255
+ if parent_execution_id:
256
+ parent_execution_id = str(parent_execution_id)
257
+ if root_execution_id:
258
+ root_execution_id = str(root_execution_id)
259
+ if retry_count:
260
+ retry_count = int(retry_count)
261
+ if session_id:
262
+ session_id = str(session_id)
263
+ if user_id:
264
+ user_id = str(user_id)
265
+
266
+ # Set execution context for this workflow execution
267
+ # If we have root_execution_id, use it; otherwise, we are the root
268
+ effective_root_execution_id = root_execution_id if root_execution_id else execution_id
269
+
270
+ # Get initial_state from context if provided
271
+ initial_state = context.get("initial_state")
272
+ if initial_state and self.state_schema:
273
+ initial_state = self.state_schema.model_validate(initial_state)
274
+
275
+ # Check if this is an Agent and create appropriate context
276
+ from ..agents.agent import Agent
277
+
278
+ if isinstance(self, Agent):
279
+ # Extract conversation_id from payload if provided
280
+ conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
281
+ # Create AgentContext for agents
282
+ workflow_ctx = AgentContext(
283
+ agent_id=self.id,
284
+ execution_id=execution_id,
285
+ deployment_id=deployment_id,
286
+ parent_execution_id=parent_execution_id,
287
+ root_execution_id=effective_root_execution_id,
288
+ retry_count=retry_count,
289
+ model=self.model,
290
+ provider=self.provider,
291
+ system_prompt=self.system_prompt,
292
+ tools=self.tools,
293
+ temperature=self.temperature,
294
+ max_tokens=self.max_output_tokens,
295
+ session_id=session_id,
296
+ conversation_id=conversation_id,
297
+ user_id=user_id,
298
+ created_at=created_at,
299
+ otel_traceparent=otel_traceparent,
300
+ otel_span_id=otel_span_id,
301
+ state_schema=self.state_schema,
302
+ initial_state=initial_state,
303
+ )
304
+ else:
305
+ # Create WorkflowContext for regular workflows or tools
306
+ from ..tools.tool import Tool
307
+
308
+ workflow_type = "tool" if isinstance(self, Tool) else "workflow"
309
+
310
+ workflow_ctx = WorkflowContext(
311
+ workflow_id=self.id,
312
+ execution_id=execution_id,
313
+ deployment_id=deployment_id,
314
+ parent_execution_id=parent_execution_id,
315
+ root_execution_id=effective_root_execution_id,
316
+ retry_count=retry_count,
317
+ created_at=created_at,
318
+ session_id=session_id,
319
+ user_id=user_id,
320
+ workflow_type=workflow_type,
321
+ otel_traceparent=otel_traceparent,
322
+ otel_span_id=otel_span_id,
323
+ state_schema=self.state_schema,
324
+ initial_state=initial_state,
325
+ )
326
+
327
+ # Convert dict payload to Pydantic model if needed
328
+ prepared_payload = payload
329
+ if self.has_payload_param and self._payload_schema_class is not None:
330
+ if isinstance(payload, dict):
331
+ # Convert dict to Pydantic model
332
+ try:
333
+ prepared_payload = self._payload_schema_class.model_validate(payload)
334
+ except Exception as e:
335
+ raise ValueError(
336
+ f"Invalid payload for workflow '{self.id}': failed to "
337
+ f"validate against {self._payload_schema_class.__name__}: {e}"
338
+ ) from e
339
+ elif payload is not None and not isinstance(payload, self._payload_schema_class):
340
+ # Payload is provided but not the right type
341
+ raise ValueError(
342
+ f"Invalid payload for workflow '{self.id}': "
343
+ f"expected {self._payload_schema_class.__name__} or dict, "
344
+ f"got {type(payload).__name__}"
345
+ )
346
+
347
+ return await self._execute_internal(workflow_ctx, prepared_payload)
348
+
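The context mapping consumed by _execute has roughly the following shape (values are illustrative; in practice the orchestrator supplies them):

    context = {
        "execution_id": "exec-123",
        "deployment_id": "dep-1",
        "parent_execution_id": None,
        "root_execution_id": None,
        "retry_count": 0,
        "created_at": "2024-01-01T00:00:00Z",
        "session_id": None,
        "user_id": None,
        "otel_traceparent": None,
        "otel_span_id": None,
        "initial_state": None,
    }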
349
+ async def _execute_internal(self, ctx: WorkflowContext, payload: Any) -> Any:
350
+ """Internal execution method with shared logic for workflows and agents.
351
+
352
+ This method handles:
353
+ - Checking execution_step_outputs for replay
354
+ - Setting up execution context cache
355
+ - Publishing start event
356
+ - Executing on_start hooks
357
+ - Calling the workflow/agent function
358
+ - Executing on_end hooks
359
+ - Publishing finish event
360
+ - Handling WaitException
361
+ """
362
+ from ..middleware.hook import HookAction, HookContext
363
+ from ..middleware.hook_executor import execute_hooks
364
+
365
+ # Determine span name based on workflow type
366
+ workflow_type = ctx.workflow_type or "workflow"
367
+ span_name = f"{workflow_type}.{ctx.workflow_id}"
368
+
369
+ traceparent = ctx.otel_traceparent
370
+
371
+ # Get parent context for sub-workflows, or create deterministic trace ID for root
372
+ parent_context = None
373
+ if traceparent:
374
+ # Sub-workflow: use parent's trace context
375
+ parent_context = create_context_from_traceparent(traceparent)
376
+ if parent_context is None:
377
+ logger.warning(
378
+ "Failed to extract trace context from traceparent: %s. Creating new trace.",
379
+ traceparent,
380
+ )
381
+ # Fall back to deterministic trace ID if extraction fails
382
+ trace_id = generate_trace_id_from_execution_id(
383
+ ctx.root_execution_id or ctx.execution_id
384
+ )
385
+ parent_context = create_context_with_trace_id(trace_id)
386
+ else:
387
+ # Root workflow: create deterministic trace ID from root_execution_id
388
+ trace_id = generate_trace_id_from_execution_id(
389
+ ctx.root_execution_id or ctx.execution_id
390
+ )
391
+ parent_context = create_context_with_trace_id(trace_id)
392
+
393
+ # Create root span for workflow execution using context manager
394
+ tracer = get_tracer()
395
+ span_attributes = {
396
+ f"{workflow_type}.id": ctx.workflow_id,
397
+ f"{workflow_type}.execution_id": ctx.execution_id,
398
+ f"{workflow_type}.parent_execution_id": ctx.parent_execution_id or "",
399
+ f"{workflow_type}.root_execution_id": ctx.root_execution_id or ctx.execution_id,
400
+ f"{workflow_type}.deployment_id": ctx.deployment_id,
401
+ f"{workflow_type}.type": workflow_type,
402
+ f"{workflow_type}.session_id": ctx.session_id or "",
403
+ f"{workflow_type}.user_id": ctx.user_id or "",
404
+ f"{workflow_type}.retry_count": ctx.retry_count or 0,
405
+ }
406
+
407
+ # If this is a resumed workflow, add previous_span_id attribute
408
+ if ctx.otel_span_id:
409
+ span_attributes[f"{workflow_type}.previous_span_id"] = ctx.otel_span_id
410
+
411
+ # For root workflows with deterministic trace_id, we need to attach the context
412
+ # so the IdGenerator can access it. For sub-workflows, we use context parameter.
413
+ context_token = None
414
+ if parent_context and not traceparent:
415
+ # Root workflow: attach context so IdGenerator can read trace_id
416
+ from opentelemetry import context as otel_context
417
+
418
+ context_token = otel_context.attach(parent_context)
419
+
420
+ with tracer.start_as_current_span(
421
+ name=span_name,
422
+ context=parent_context
423
+ if traceparent
424
+ else None, # For sub-workflows, pass context; for root, rely on attached context
425
+ attributes=span_attributes,
426
+ ) as workflow_span:
427
+ # If this is a resumed workflow, add resumed event
428
+ if ctx.otel_span_id:
429
+ workflow_span.add_event(f"{workflow_type}.resumed")
430
+ # Update execution context with current span for nested spans
431
+ exec_context = ctx.to_dict()
432
+ span_context = workflow_span.get_span_context()
433
+ exec_context["_otel_span_context"] = span_context
434
+ exec_context["_otel_trace_id"] = format(span_context.trace_id, "032x")
435
+ exec_context["_otel_span_id"] = format(span_context.span_id, "016x")
436
+ exec_context["state"] = ctx.state # Store workflow_ctx for worker to access final_state
437
+ token = _execution_context.set(exec_context)
438
+
439
+ # Topic for workflow events
440
+ topic = f"workflow:{ctx.root_execution_id or ctx.execution_id}"
441
+
442
+ serialized_payload = serialize(payload) if payload is not None else None
443
+
444
+ # Store input in span attributes as JSON string
445
+ if payload is not None:
446
+ workflow_span.set_attribute(
447
+ f"{workflow_type}.input", json.dumps(serialized_payload)
448
+ )
449
+
450
+ # Store initial state in span if provided
451
+ if ctx.state is not None and self.state_schema:
452
+ try:
453
+ initial_state_dict = ctx.state.model_dump(mode="json")
454
+ workflow_span.set_attribute(
455
+ f"{workflow_type}.initial_state", json.dumps(initial_state_dict)
456
+ )
457
+ except Exception as e:
458
+ logger.warning(f"Failed to serialize initial state for span: {e}")
459
+
460
+ try:
461
+ # Publish start event
462
+ await ctx.step.publish_event(
463
+ "publish_start",
464
+ topic=topic,
465
+ event_type=f"{workflow_type}_start",
466
+ data={
467
+ "payload": serialized_payload,
468
+ "_metadata": {
469
+ "execution_id": ctx.execution_id,
470
+ "workflow_id": ctx.workflow_id,
471
+ },
472
+ },
473
+ )
474
+
475
+ # Execute on_start hooks
476
+ if self.on_start:
477
+ hook_context = HookContext(
478
+ workflow_id=self.id,
479
+ current_payload=payload,
480
+ )
481
+ hook_result = await execute_hooks("on_start", self.on_start, hook_context, ctx)
482
+
483
+ # Apply modifications from hooks
484
+ if hook_result.modified_payload is not None:
485
+ payload = hook_result.modified_payload
486
+
487
+ # Check hook action
488
+ if hook_result.action == HookAction.FAIL:
489
+ raise StepExecutionError(
490
+ hook_result.error_message or "Hook execution failed"
491
+ )
492
+
493
+ # Call the workflow/agent function with context and payload (if function expects it)
494
+ if self.is_async:
495
+ if self.has_payload_param:
496
+ result = await self.func(ctx, payload)
497
+ else:
498
+ result = await self.func(ctx)
499
+ else:
500
+ # Run sync function in executor to avoid blocking
501
+ loop = asyncio.get_event_loop()
502
+ if self.has_payload_param:
503
+ result = await loop.run_in_executor(None, self.func, ctx, payload)
504
+ else:
505
+ result = await loop.run_in_executor(None, self.func, ctx)
506
+
507
+ # Execute on_end hooks
508
+ if self.on_end:
509
+ hook_context = HookContext(
510
+ workflow_id=self.id,
511
+ current_payload=payload,
512
+ current_output=result,
513
+ )
514
+ hook_result = await execute_hooks("on_end", self.on_end, hook_context, ctx)
515
+
516
+ # Apply modifications from hooks
517
+ if hook_result.modified_output is not None:
518
+ result = hook_result.modified_output
519
+
520
+ # Check hook action
521
+ if hook_result.action == HookAction.FAIL:
522
+ raise StepExecutionError(
523
+ hook_result.error_message or "Hook execution failed"
524
+ )
525
+
526
+ serialized_result = serialize(result) if result is not None else None
527
+ # Publish finish event (only if we didn't hit WaitException)
528
+ await ctx.step.publish_event(
529
+ "publish_finish",
530
+ topic=topic,
531
+ event_type=f"{workflow_type}_finish",
532
+ data={
533
+ "result": serialized_result,
534
+ "_metadata": {
535
+ "execution_id": ctx.execution_id,
536
+ "workflow_id": ctx.workflow_id,
537
+ },
538
+ },
539
+ )
540
+
541
+ # Set span status to success
542
+ workflow_span.set_status(Status(StatusCode.OK))
543
+ workflow_span.set_attribute(f"{workflow_type}.status", "completed")
544
+ workflow_span.set_attribute(
545
+ f"{workflow_type}.result_size", len(str(result)) if result else 0
546
+ )
547
+ # Store output in span attributes as JSON string
548
+ if result is not None:
549
+ workflow_span.set_attribute(
550
+ f"{workflow_type}.output", json.dumps(serialized_result)
551
+ )
552
+
553
+ final_state = None
554
+ # Store final state in span if workflow has state_schema
555
+ if ctx.state is not None and self.state_schema:
556
+ try:
557
+ final_state = ctx.state.model_dump(mode="json")
558
+ workflow_span.set_attribute(
559
+ f"{workflow_type}.final_state", json.dumps(final_state)
560
+ )
561
+ except Exception as e:
562
+ logger.warning(f"Failed to serialize final state for span: {e}")
563
+
564
+ # Span automatically ended and stored by DatabaseSpanExporter
565
+ # when context manager exits
566
+ return result, final_state
567
+ except WaitException:
568
+ # Execution is paused for waiting - this is expected
569
+ # The orchestrator will resume it when the wait expires
570
+ # Do NOT publish finish event when WaitException is raised
571
+ workflow_span.set_status(Status(StatusCode.OK))
572
+ workflow_span.set_attribute(f"{workflow_type}.status", "waiting")
573
+ workflow_span.add_event(f"{workflow_type}.waiting")
574
+
575
+ # Save current span_id to database for resume linkage
576
+ span_context = workflow_span.get_span_context()
577
+ span_id_hex = format(span_context.span_id, "016x")
578
+ from ..runtime.client import update_execution_otel_span_id
579
+
580
+ try:
581
+ # Schedule the update as a background task (don't await to avoid blocking)
582
+ asyncio.create_task(
583
+ update_execution_otel_span_id(ctx.execution_id, span_id_hex)
584
+ )
585
+ except Exception as e:
586
+ # Log error but don't fail on span_id update failure
587
+ logger.warning(
588
+ f"Failed to update otel_span_id for execution {ctx.execution_id}: {e}"
589
+ )
590
+
591
+ # Span automatically ended and stored by DatabaseSpanExporter
592
+ # when context manager exits
593
+ raise
594
+
595
+ except Exception as e:
596
+ # Set span status to error
597
+ workflow_span.set_status(Status(StatusCode.ERROR, str(e)))
598
+ workflow_span.set_attribute(f"{workflow_type}.status", "failed")
599
+ workflow_span.record_exception(e)
600
+
601
+ # Store error in span attributes as JSON string
602
+ error_message = str(e)
603
+ workflow_error = {
604
+ "message": error_message,
605
+ "type": type(e).__name__,
606
+ }
607
+ workflow_span.set_attribute(
608
+ f"{workflow_type}.error", json.dumps(safe_serialize(workflow_error))
609
+ )
610
+
611
+ final_state = None
612
+ if ctx.state is not None and self.state_schema:
613
+ try:
614
+ final_state = ctx.state.model_dump(mode="json")
615
+ workflow_span.set_attribute(
616
+ f"{workflow_type}.final_state", json.dumps(final_state)
617
+ )
618
+ except Exception as e2:
619
+ logger.warning(f"Failed to serialize final state for error case: {e2}")
620
+
621
+ # Span automatically ended and stored by DatabaseSpanExporter
622
+ # when context manager exits
623
+ raise
624
+
625
+ finally:
626
+ # Restore previous context
627
+ _execution_context.reset(token)
628
+ # Detach OTel context if we attached it for root workflows
629
+ if context_token is not None:
630
+ from opentelemetry import context as otel_context
631
+
632
+ otel_context.detach(context_token)
633
+ # Span context automatically cleaned up by context manager
634
+
635
+ async def invoke(
636
+ self,
637
+ client: PolosClient,
638
+ payload: BaseModel | dict[str, Any] | None = None,
639
+ queue: str | None = None,
640
+ concurrency_key: str | None = None,
641
+ session_id: str | None = None,
642
+ user_id: str | None = None,
643
+ initial_state: BaseModel | dict[str, Any] | None = None,
644
+ run_timeout_seconds: int | None = None,
645
+ ) -> ExecutionHandle:
646
+ """Invoke workflow execution via orchestrator and return a handle immediately.
647
+
648
+ This is a fire-and-forget operation.
649
+ The workflow will be executed asynchronously and the handle will be returned immediately.
650
+ This method cannot be called from within a workflow or agent.
651
+ Use step.invoke() to call workflows from within workflows.
652
+
653
+ Args:
654
+ client: PolosClient instance
655
+ payload: Workflow payload
656
+ queue: Optional queue name (overrides workflow-level queue)
657
+ concurrency_key: Optional concurrency key for per-tenant queuing
658
+ session_id: Optional session ID (inherited from parent if not provided)
659
+ user_id: Optional user ID (inherited from parent if not provided)
660
+
661
+ Returns:
662
+ ExecutionHandle for monitoring and managing the execution
663
+
664
+ Raises:
665
+ ValueError: If workflow is event-triggered (cannot be invoked directly)
666
+ """
667
+ # Check if we're in an execution context - fail if we are
668
+ if _execution_context.get() is not None:
669
+ raise RuntimeError(
670
+ "workflow.run() cannot be called from within a workflow or agent. "
671
+ "Use step.invoke() to call workflows from within workflows."
672
+ )
673
+
674
+ return await self._invoke(
675
+ client,
676
+ payload,
677
+ queue=queue,
678
+ concurrency_key=concurrency_key,
679
+ session_id=session_id,
680
+ user_id=user_id,
681
+ initial_state=initial_state,
682
+ run_timeout_seconds=run_timeout_seconds,
683
+ )
684
+
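Outside of any execution context, invoke() might be used like this (a sketch; client and my_workflow are assumed to be defined elsewhere):

    handle = await my_workflow.invoke(
        client,
        payload={"name": "weekly"},
        queue="reports",
        session_id="session-1",
    )
    # handle is an ExecutionHandle; the execution proceeds asynchronously.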
685
+ async def _invoke(
686
+ self,
687
+ client: PolosClient,
688
+ payload: BaseModel | dict[str, Any] | None = None,
689
+ queue: str | None = None,
690
+ concurrency_key: str | None = None,
691
+ batch_id: str | None = None,
692
+ session_id: str | None = None,
693
+ user_id: str | None = None,
694
+ deployment_id: str | None = None,
695
+ parent_execution_id: str | None = None,
696
+ root_execution_id: str | None = None,
697
+ step_key: str | None = None,
698
+ wait_for_subworkflow: bool = False,
699
+ otel_traceparent: str | None = None,
700
+ initial_state: BaseModel | dict[str, Any] | None = None,
701
+ run_timeout_seconds: int | None = None,
702
+ ) -> ExecutionHandle:
703
+ """Invoke workflow execution via orchestrator and return a handle immediately.
704
+
705
+ This is a fire-and-forget operation.
706
+ The workflow will be executed asynchronously and the handle will be returned immediately.
707
+
708
+ Args:
709
+ client: PolosClient instance
710
+ payload: Workflow payload
711
+ queue: Optional queue name (overrides workflow-level queue)
712
+ concurrency_key: Optional concurrency key for per-tenant queuing
713
+ batch_id: Optional batch ID for batching
714
+ session_id: Optional session ID (inherited from parent if not provided)
715
+ user_id: Optional user ID (inherited from parent if not provided)
716
+ step_key: Optional step_key (set when invoked from step.py)
717
+
718
+ Returns:
719
+ ExecutionHandle for monitoring and managing the execution
720
+ """
721
+
722
+ if self.trigger_on_event and (payload is None or payload.get("events") is None):
723
+ raise ValueError(
724
+ f"Workflow '{self.id}' is event-triggered and should have events in the payload."
725
+ )
726
+
727
+ # Validate and normalize payload (dict or Pydantic BaseModel only)
728
+ # Only prepare payload if workflow expects it
729
+ if self.has_payload_param:
730
+ if payload is None:
731
+ raise ValueError(
732
+ f"Workflow '{self.id}' requires a payload parameter, but None was provided"
733
+ )
734
+ # payload = self._prepare_payload(payload)
735
+ payload = serialize(payload)
736
+ else:
737
+ # Workflow doesn't expect payload - ignore it if provided
738
+ if payload is not None:
739
+ # Warn but don't fail - user might be calling with payload by mistake
740
+ logger.warning(f"Workflow '{self.id}' does not accept a payload; ignoring the provided payload.")
741
+ payload = None
742
+
743
+ # Invoke the workflow (it will be checkpointed when it executes)
744
+ # For nested workflows called via step.invoke(), use workflow's own queue configuration
745
+ queue_name = queue if queue else self.queue_name if self.queue_name is not None else self.id
746
+ handle = await client._submit_workflow(
747
+ self.id,
748
+ payload,
749
+ deployment_id=deployment_id,
750
+ parent_execution_id=parent_execution_id,
751
+ root_execution_id=root_execution_id,
752
+ step_key=step_key,
753
+ queue_name=queue_name,
754
+ queue_concurrency_limit=self.queue_concurrency_limit,
755
+ concurrency_key=concurrency_key,
756
+ wait_for_subworkflow=wait_for_subworkflow,
757
+ batch_id=batch_id,
758
+ session_id=session_id,
759
+ user_id=user_id,
760
+ otel_traceparent=otel_traceparent,
761
+ initial_state=serialize(initial_state),
762
+ run_timeout_seconds=run_timeout_seconds,
763
+ )
764
+ return handle
765
+
766
+ async def run(
767
+ self,
768
+ client: PolosClient,
769
+ payload: BaseModel | dict[str, Any] | None = None,
770
+ queue: str | None = None,
771
+ concurrency_key: str | None = None,
772
+ session_id: str | None = None,
773
+ user_id: str | None = None,
774
+ timeout: float | None = 600.0,
775
+ initial_state: BaseModel | dict[str, Any] | None = None,
776
+ ) -> Any:
777
+ """
778
+ Run workflow and return final result (wait for completion).
779
+
780
+ This method cannot be called from within an execution context
781
+ (e.g., from within a workflow).
782
+ Use step.invoke_and_wait() to call workflows from within workflows.
783
+
784
+ Args:
785
+ client: PolosClient instance
786
+ payload: Workflow payload (dict or Pydantic BaseModel)
787
+ queue: Optional queue name (overrides workflow-level queue)
788
+ concurrency_key: Optional concurrency key for per-tenant queuing
789
+ session_id: Optional session ID
790
+ user_id: Optional user ID
791
+ timeout: Optional timeout in seconds (default: 600 seconds / 10 minutes)
792
+
793
+ Returns:
794
+ Result from workflow execution
795
+
796
+ Raises:
797
+ WorkflowTimeoutError: If the execution exceeds the timeout
798
+
799
+ Example:
800
+ result = await my_workflow.run(client, {"param": "value"})
801
+ """
802
+ # Check if we're in an execution context - fail if we are
803
+ if _execution_context.get() is not None:
804
+ raise RuntimeError(
805
+ "workflow.run() cannot be called from within a workflow or agent. "
806
+ "Use step.invoke_and_wait() to call workflows from within workflows."
807
+ )
808
+
809
+ # Invoke workflow and get handle
810
+ handle = await self.invoke(
811
+ client=client,
812
+ payload=payload,
813
+ queue=queue,
814
+ concurrency_key=concurrency_key,
815
+ session_id=session_id,
816
+ user_id=user_id,
817
+ initial_state=initial_state,
818
+ run_timeout_seconds=int(timeout) if timeout is not None else None,
819
+ )
820
+
821
+ # Track start time for timeout
822
+ start_time = time.time()
823
+
824
+ # Poll handle.get() until status is "completed" or "failed"
825
+ while True:
826
+ # Check for timeout
827
+ elapsed_time = time.time() - start_time
828
+ if timeout is not None and elapsed_time >= timeout:
829
+ raise WorkflowTimeoutError(
830
+ execution_id=handle.id if hasattr(handle, "id") else None,
831
+ timeout_seconds=timeout,
832
+ )
833
+
834
+ execution_info = await handle.get(client)
835
+ status = execution_info.get("status")
836
+
837
+ if status == "completed":
838
+ result = execution_info.get("result")
839
+ break
840
+ elif status == "failed":
841
+ error = execution_info.get("error", "Workflow execution failed")
842
+ raise RuntimeError(f"Workflow execution failed: {error}")
843
+
844
+ # Wait before checking again
845
+ await asyncio.sleep(0.5)
846
+
847
+ return result
848
+
849
+
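A sketch of run() with timeout handling (client and my_workflow are assumed; the default timeout is 600 seconds):

    try:
        result = await my_workflow.run(client, {"param": "value"}, timeout=120.0)
    except WorkflowTimeoutError as e:
        print(f"timed out after {e.timeout_seconds}s (execution {e.execution_id})")
    except RuntimeError as e:
        print(f"workflow failed: {e}")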
850
+ def workflow(
851
+ id: str | None = None,
852
+ queue: str | Queue | dict[str, Any] | None = None,
853
+ trigger_on_event: str | None = None,
854
+ batch_size: int = 1,
855
+ batch_timeout_seconds: int | None = None,
856
+ schedule: bool | str | dict[str, Any] | None = None,
857
+ on_start: str | list[str] | Workflow | list[Workflow] | None = None,
858
+ on_end: str | list[str] | Workflow | list[Workflow] | None = None,
859
+ state_schema: type[BaseModel] | None = None,
860
+ ):
861
+ """Decorator to register a Polos workflow.
862
+
863
+ Usage:
864
+ @workflow
865
+ def my_workflow(ctx, payload: dict):
866
+ return {"result": payload * 2}
867
+
868
+ @workflow()
869
+ def my_workflow2(ctx, payload: dict):
870
+ return {"result": payload * 2}
871
+
872
+ @workflow(id="custom_workflow_id")
873
+ async def async_workflow(ctx, payload: dict):
874
+ await asyncio.sleep(1)
875
+ return {"done": True}
876
+
877
+ # With inline queue config
878
+ @workflow(queue={"concurrency_limit": 1})
879
+ def one_at_a_time(ctx, payload: dict):
880
+ return {"result": payload}
881
+
882
+ # With queue name
883
+ @workflow(queue="my-queue")
884
+ def my_queued_workflow(ctx, payload: dict):
885
+ return {"result": payload}
886
+
887
+ # With Queue object
888
+ from polos import queue
889
+ my_queue = queue("my-queue", concurrency_limit=5)
890
+ @workflow(queue=my_queue)
891
+ def queued_workflow(ctx, payload: dict):
892
+ return {"result": payload}
893
+
894
+ # Event-triggered workflow (one event per invocation)
895
+ @workflow(id="on-approval-dept1", trigger_on_event="approval/dept1")
896
+ async def on_approval_dept1(ctx, payload: dict):
897
+ # payload contains event information
898
+ event_data = payload.get("events", [{}])[0]
899
+ return {"processed": event_data}
900
+
901
+ # Event-triggered workflow with batching (10 events per batch, 30 second timeout)
902
+ @workflow(
903
+ id="batch-processor",
904
+ trigger_on_event="data/updates",
905
+ batch_size=10,
906
+ batch_timeout_seconds=30
907
+ )
908
+ async def batch_processor(ctx, payload: dict):
909
+ # payload contains batch of events
910
+ events = payload.get("events", [])
911
+ return {"processed_count": len(events)}
912
+
913
+ # Scheduled workflow (declarative with cron)
914
+ @workflow(id="daily-cleanup", schedule="0 3 * * *")
915
+ async def daily_cleanup(ctx, payload):
916
+ # payload is SchedulePayload
917
+ print(f"Scheduled to run at {payload.timestamp}")
918
+ return {"status": "cleaned"}
919
+
920
+ # Scheduled workflow with timezone
921
+ @workflow(
922
+ id="morning-report",
923
+ schedule={"cron": "0 8 * * *", "timezone": "America/New_York"}
924
+ )
925
+ async def morning_report(ctx, payload):
926
+ return {"report": "generated"}
927
+
928
+ # Workflow that can be scheduled later (schedule=True)
929
+ @workflow(id="reminder-workflow", schedule=True)
930
+ async def reminder_workflow(ctx, payload):
931
+ # Schedule will be added later using schedules.create()
932
+ return {"reminder": "sent"}
933
+
934
+ # Workflow that cannot be scheduled (schedule=False)
935
+ @workflow(id="one-time-workflow", schedule=False)
936
+ async def one_time_workflow(ctx, payload):
937
+ return {"done": True}
938
+
939
+ Args:
940
+ id: Optional workflow ID (defaults to function name)
941
+ queue: Optional queue configuration. Can be:
942
+ - str: Queue name
943
+ - Queue: Queue object
944
+ - dict: {"concurrency_limit": int} (uses workflow_id as queue name)
945
+ - None: Uses workflow_id as queue name with default concurrency
946
+ - Note: Cannot be specified for event-triggered or scheduled workflows
947
+ trigger_on_event: Optional event topic that triggers this workflow. If specified:
948
+ - Workflow will be automatically triggered when events are published to this topic
949
+ - Workflow gets its own queue with concurrency=1 (to ensure ordering)
950
+ - Payload will contain event information:
951
+ {"events": [{"id": "...", "topic": "...", "event_type": "...",
952
+ "data": {...}, "sequence_id": ..., "created_at": "..."}, ...]}
953
+ batch_size: For event-triggered workflows, number of events to batch
954
+ together (default: 1)
955
+ batch_timeout_seconds: For event-triggered workflows, max time to wait
956
+ for batching (None = no timeout)
957
+ schedule: Optional schedule configuration. Can be:
958
+ - True: Workflow can be scheduled later using schedules.create() API
959
+ - False: Workflow cannot be scheduled (explicit opt-out)
960
+ - str: Cron expression (e.g., "0 3 * * *" for 3 AM daily,
961
+ uses UTC timezone) - creates schedule immediately
962
+ - dict: {"cron": str, "timezone": str, "key": str}
963
+ (e.g., {"cron": "0 8 * * *", "timezone": "America/New_York"})
964
+ - creates schedule immediately
965
+ - Note: Cannot be specified for event-triggered workflows
966
+ - Note: Scheduled workflows cannot specify queues
967
+ (they get their own queue with concurrency=1)
968
+
969
+ Workflow IDs must be valid Python identifiers (letters, numbers, underscores;
970
+ cannot start with a number or be a Python keyword).
971
+ """
972
+
973
+ def decorator(func: Callable) -> Workflow:
974
+ # Validate function signature
975
+ sig = inspect.signature(func)
976
+ params = list(sig.parameters.values())
977
+
978
+ if len(params) < 1:
979
+ raise TypeError(
980
+ f"Workflow function '{func.__name__}' must have at least 1 "
981
+ f"parameter: (context: WorkflowContext) or "
982
+ f"(context: WorkflowContext, payload: "
983
+ f"Union[BaseModel, dict[str, Any]])"
984
+ )
985
+
986
+ if len(params) > 2:
987
+ raise TypeError(
988
+ f"Workflow function '{func.__name__}' must have at most 2 "
989
+ f"parameters: (context: WorkflowContext) or "
990
+ f"(context: WorkflowContext, payload: "
991
+ f"Union[BaseModel, dict[str, Any]])"
992
+ )
993
+
994
+ # Check first parameter (context)
995
+ first_param = params[0]
996
+ first_annotation = first_param.annotation
997
+
998
+ # Allow untyped parameters or anything that ends with WorkflowContext/AgentContext
999
+ first_type_valid = False
1000
+ if first_annotation == inspect.Parameter.empty:
1001
+ # Untyped is allowed
1002
+ first_type_valid = True
1003
+ elif isinstance(first_annotation, str):
1004
+ # String annotation - check if it ends with WorkflowContext or AgentContext
1005
+ if (
1006
+ first_annotation.endswith("WorkflowContext")
1007
+ or first_annotation.endswith("AgentContext")
1008
+ or "WorkflowContext" in first_annotation
1009
+ or "AgentContext" in first_annotation
1010
+ ):
1011
+ first_type_valid = True
1012
+ else:
1013
+ # Type annotation - check if class name ends with WorkflowContext or AgentContext
1014
+ try:
1015
+ # Get the class name
1016
+ type_name = getattr(first_annotation, "__name__", None) or str(first_annotation)
1017
+ if (
1018
+ type_name.endswith("WorkflowContext")
1019
+ or type_name.endswith("AgentContext")
1020
+ or "WorkflowContext" in type_name
1021
+ or "AgentContext" in type_name
1022
+ ):
1023
+ first_type_valid = True
1024
+ # Also check if it's the actual WorkflowContext or AgentContext class
1025
+ from ..core.context import AgentContext, WorkflowContext
1026
+
1027
+ if first_annotation in (WorkflowContext, AgentContext):
1028
+ first_type_valid = True
1029
+ elif hasattr(first_annotation, "__origin__"): # Handle Union, Optional, etc.
1030
+ # For Union types, check if WorkflowContext or AgentContext is in the union
1031
+ args = getattr(first_annotation, "__args__", ())
1032
+ if WorkflowContext in args or AgentContext in args:
1033
+ first_type_valid = True
1034
+ except (ImportError, AttributeError):
1035
+ # If we can't check, allow it if the name suggests it's
1036
+ # WorkflowContext or AgentContext
1037
+ type_name = getattr(first_annotation, "__name__", None) or str(first_annotation)
1038
+ if "WorkflowContext" in type_name or "AgentContext" in type_name:
1039
+ first_type_valid = True
1040
+
1041
+ if not first_type_valid:
1042
+ raise TypeError(
1043
+ f"Workflow function '{func.__name__}': first parameter "
1044
+ f"'{first_param.name}' must be typed as WorkflowContext or "
1045
+ f"AgentContext (or untyped), got {first_annotation}"
1046
+ )
1047
+
1048
+ # Check second parameter (payload) if it exists
1049
+ payload_schema_class: type[BaseModel] | None = None
1050
+ if len(params) >= 2:
1051
+ second_param = params[1]
1052
+ second_annotation = second_param.annotation
1053
+ if second_annotation == inspect.Parameter.empty:
1054
+ raise TypeError(
1055
+ f"Workflow function '{func.__name__}': second parameter "
1056
+ f"'{second_param.name}' must be typed as "
1057
+ f"Union[BaseModel, dict[str, Any]] or a specific "
1058
+ f"Pydantic BaseModel class"
1059
+ )
1060
+
1061
+ # Check if second parameter is dict[str, Any], BaseModel, or a Pydantic model
1062
+ second_type_valid = False
1063
+ if isinstance(second_annotation, str):
1064
+ # String annotation - check for dict, BaseModel, or Union
1065
+ if "dict" in second_annotation.lower() or "Dict" in second_annotation:
1066
+ second_type_valid = True
1067
+ if "BaseModel" in second_annotation:
1068
+ second_type_valid = True
1069
+ if "Union" in second_annotation or "|" in second_annotation:
1070
+ second_type_valid = True
1071
+ else:
1072
+ # Actual type - check various cases
1073
+ try:
1074
+ # Check if it's dict type
1075
+ if second_annotation is dict or (
1076
+ hasattr(second_annotation, "__origin__")
1077
+ and getattr(second_annotation, "__origin__", None) is dict
1078
+ ):
1079
+ second_type_valid = True
1080
+
1081
+ # Check if it's BaseModel or a subclass
1082
+ if (
1083
+ issubclass(second_annotation, BaseModel)
1084
+ if inspect.isclass(second_annotation)
1085
+ else False
1086
+ ):
1087
+ second_type_valid = True
1088
+ # Extract payload schema class for later use
1089
+ payload_schema_class = second_annotation
1090
+
1091
+ # Check if it's a Union type containing dict or BaseModel
1092
+ if hasattr(second_annotation, "__origin__"):
1093
+ origin = getattr(second_annotation, "__origin__", None)
1094
+ if origin is Union or (
1095
+ hasattr(origin, "__name__") and origin.__name__ == "Union"
1096
+ ):
1097
+ args = getattr(second_annotation, "__args__", ())
1098
+ for arg in args:
1099
+ if arg is dict or (
1100
+ inspect.isclass(arg) and issubclass(arg, BaseModel)
1101
+ ):
1102
+ second_type_valid = True
1103
+ # Extract BaseModel if present
1104
+ if inspect.isclass(arg) and issubclass(arg, BaseModel):
1105
+ payload_schema_class = arg
1106
+ break
1107
+ except (TypeError, AttributeError):
1108
+ pass
1109
+
1110
+ if not second_type_valid:
1111
+ raise TypeError(
1112
+ f"Workflow function '{func.__name__}': second parameter "
1113
+ f"'{second_param.name}' must be typed as "
1114
+ f"Union[BaseModel, dict[str, Any]] or a specific "
1115
+ f"Pydantic BaseModel class, got {second_annotation}"
1116
+ )
1117
+
1118
+ # Determine workflow ID
1119
+ workflow_id = id if id is not None else func.__name__
1120
+
1121
+ # Parse queue configuration
1122
+ queue_name: str | None = None
1123
+ queue_concurrency_limit: int | None = None
1124
+
1125
+ # Determine if workflow is schedulable
1126
+ is_schedulable = False
1127
+ if schedule is True:
1128
+ is_schedulable = True
1129
+ elif schedule is False:
1130
+ is_schedulable = False
1131
+ elif schedule is not None:
1132
+ # schedule is a cron string or dict - workflow is schedulable
1133
+ is_schedulable = True
1134
+
1135
+ # Scheduled workflows cannot specify queues
1136
+ if is_schedulable and queue is not None:
1137
+ raise ValueError("Scheduled workflows cannot specify a queue.")
1138
+
1139
+ if queue is not None:
1140
+ if isinstance(queue, str):
1141
+ # Queue name string
1142
+ queue_name = queue
1143
+ elif isinstance(queue, Queue):
1144
+ # Queue object
1145
+ queue_name = queue.name
1146
+ queue_concurrency_limit = queue.concurrency_limit
1147
+ elif isinstance(queue, dict):
1148
+ # Dict with concurrency_limit
1149
+ queue_name = queue.get(
1150
+ "name", workflow_id
1151
+ ) # Use workflow_id as queue name if not provided
1152
+ queue_concurrency_limit = queue.get("concurrency_limit")
1153
+ else:
1154
+ raise ValueError(
1155
+ f"Invalid queue type: {type(queue)}. Expected str, Queue, or dict."
1156
+ )
1157
+
1158
+ workflow_obj = Workflow(
1159
+ id=workflow_id,
1160
+ func=func,
1161
+ queue_name=queue_name,
1162
+ queue_concurrency_limit=queue_concurrency_limit,
1163
+ trigger_on_event=trigger_on_event,
1164
+ batch_size=batch_size,
1165
+ batch_timeout_seconds=batch_timeout_seconds,
1166
+ schedule=schedule,
1167
+ on_start=on_start,
1168
+ on_end=on_end,
1169
+ payload_schema_class=payload_schema_class if len(params) >= 2 else None,
1170
+ state_schema=state_schema,
1171
+ )
1172
+ _WORKFLOW_REGISTRY[workflow_id] = workflow_obj
1173
+ return workflow_obj
1174
+
1175
+ # Handle @workflow (without parentheses) - the function is passed as the first argument
1176
+ if callable(id):
1177
+ func = id
1178
+ id = None
1179
+ return decorator(func)
1180
+
1181
+ # Handle @workflow() or @workflow(id="...", queue=...)
1182
+ return decorator
1183
+
1184
+
1185
+ def get_workflow(workflow_id: str) -> Workflow | None:
1186
+ """Get a workflow by ID from the registry."""
1187
+ return _WORKFLOW_REGISTRY.get(workflow_id)
1188
+
1189
+
1190
+ def get_all_workflows() -> dict[str, Workflow]:
1191
+ """Get all registered workflows."""
1192
+ return _WORKFLOW_REGISTRY.copy()
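Taken together, an end-to-end sketch of the decorator and registry helpers defined in this module (PolosClient construction lives in polos/runtime/client.py and is only assumed here; GreetPayload is hypothetical):

    from pydantic import BaseModel

    class GreetPayload(BaseModel):
        name: str

    @workflow(id="greet")
    async def greet(ctx, payload: GreetPayload):
        return {"greeting": f"Hello, {payload.name}!"}

    assert get_workflow("greet") is greet
    assert "greet" in get_all_workflows()

    # client = PolosClient(...)  # construction not shown in this file
    # result = await greet.run(client, GreetPayload(name="Ada"))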