polos-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. polos/__init__.py +105 -0
  2. polos/agents/__init__.py +7 -0
  3. polos/agents/agent.py +746 -0
  4. polos/agents/conversation_history.py +121 -0
  5. polos/agents/stop_conditions.py +280 -0
  6. polos/agents/stream.py +635 -0
  7. polos/core/__init__.py +0 -0
  8. polos/core/context.py +143 -0
  9. polos/core/state.py +26 -0
  10. polos/core/step.py +1380 -0
  11. polos/core/workflow.py +1192 -0
  12. polos/features/__init__.py +0 -0
  13. polos/features/events.py +456 -0
  14. polos/features/schedules.py +110 -0
  15. polos/features/tracing.py +605 -0
  16. polos/features/wait.py +82 -0
  17. polos/llm/__init__.py +9 -0
  18. polos/llm/generate.py +152 -0
  19. polos/llm/providers/__init__.py +5 -0
  20. polos/llm/providers/anthropic.py +615 -0
  21. polos/llm/providers/azure.py +42 -0
  22. polos/llm/providers/base.py +196 -0
  23. polos/llm/providers/fireworks.py +41 -0
  24. polos/llm/providers/gemini.py +40 -0
  25. polos/llm/providers/groq.py +40 -0
  26. polos/llm/providers/openai.py +1021 -0
  27. polos/llm/providers/together.py +40 -0
  28. polos/llm/stream.py +183 -0
  29. polos/middleware/__init__.py +0 -0
  30. polos/middleware/guardrail.py +148 -0
  31. polos/middleware/guardrail_executor.py +253 -0
  32. polos/middleware/hook.py +164 -0
  33. polos/middleware/hook_executor.py +104 -0
  34. polos/runtime/__init__.py +0 -0
  35. polos/runtime/batch.py +87 -0
  36. polos/runtime/client.py +841 -0
  37. polos/runtime/queue.py +42 -0
  38. polos/runtime/worker.py +1365 -0
  39. polos/runtime/worker_server.py +249 -0
  40. polos/tools/__init__.py +0 -0
  41. polos/tools/tool.py +587 -0
  42. polos/types/__init__.py +23 -0
  43. polos/types/types.py +116 -0
  44. polos/utils/__init__.py +27 -0
  45. polos/utils/agent.py +27 -0
  46. polos/utils/client_context.py +41 -0
  47. polos/utils/config.py +12 -0
  48. polos/utils/output_schema.py +311 -0
  49. polos/utils/retry.py +47 -0
  50. polos/utils/serializer.py +167 -0
  51. polos/utils/tracing.py +27 -0
  52. polos/utils/worker_singleton.py +40 -0
  53. polos_sdk-0.1.0.dist-info/METADATA +650 -0
  54. polos_sdk-0.1.0.dist-info/RECORD +55 -0
  55. polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/runtime/worker.py
@@ -0,0 +1,1365 @@
+ """Worker class for executing Polos workflows, agents, and tools."""
+
+ import asyncio
+ import contextlib
+ import json
+ import logging
+ import os
+ import signal
+ import traceback
+ from datetime import datetime
+ from typing import Any
+
+ import httpx
+ from dotenv import load_dotenv
+ from pydantic import BaseModel
+
+ from ..agents.agent import Agent
+ from ..core.workflow import _WORKFLOW_REGISTRY, StepExecutionError, Workflow
+ from ..features.wait import WaitException
+ from ..tools.tool import Tool
+ from ..utils.config import is_localhost_url
+ from ..utils.worker_singleton import set_current_worker
+ from .client import PolosClient
+
+ # FastAPI imports for push mode (optional dependency)
+ try:
+     from .worker_server import WorkerServer
+
+     FASTAPI_AVAILABLE = True
+ except ImportError:
+     FASTAPI_AVAILABLE = False
+     WorkerServer = None
+
+ load_dotenv()
+ logger = logging.getLogger(__name__)
+
+
+ class Worker:
+     """
+     Polos worker that executes workflows and agents.
+
+     The worker:
+     1. Registers with the orchestrator (creates/replaces the deployment)
+     2. Registers agent and tool definitions
+     3. Registers all workflows/agents in the deployment_workflows table
+     4. Receives pushed work (push mode) or polls for it (pull mode) and executes it
+
+     Usage:
+         from polos import Worker, Agent, Tool, PolosClient
+
+         client = PolosClient(api_url="http://localhost:8080")
+
+         # Define your workflows
+         research_agent = Agent(...)
+         analysis_agent = Agent(...)
+
+         # Create the worker
+         worker = Worker(
+             client=client,
+             deployment_id=os.getenv("WORKER_DEPLOYMENT_ID"),
+             agents=[research_agent, analysis_agent],
+             tools=[search_web],
+             workflows=[step_condition_workflow],
+         )
+
+         # Run the worker (blocks until shutdown)
+         await worker.run()
+     """
+
+     def __init__(
+         self,
+         client: PolosClient,
+         deployment_id: str,
+         agents: list[Agent] | None = None,
+         tools: list[Tool] | None = None,
+         workflows: list[Workflow] | None = None,
+         max_concurrent_workflows: int | None = None,
+         mode: str = "push",  # "push" or "pull"
+         worker_server_url: str | None = None,  # Used when mode="push"
+     ):
+         """
+         Initialize the worker.
+
+         Args:
+             client: PolosClient instance (required)
+             deployment_id: Unique identifier for the deployment. Falls back to
+                 the POLOS_DEPLOYMENT_ID env var if not provided.
+             agents: List of Agent instances to register
+             tools: List of Tool instances to register
+             workflows: List of Workflow instances to register
+             max_concurrent_workflows: Maximum number of workflows to execute in
+                 parallel. If not provided, reads from the
+                 POLOS_MAX_CONCURRENT_WORKFLOWS env var (default: 100)
+             mode: Worker mode - "push" (default) or "pull".
+                 Push mode runs a FastAPI server to receive work.
+             worker_server_url: Full URL of the worker server endpoint
+                 (e.g., "https://worker.example.com"). If not provided and
+                 mode="push", falls back to the POLOS_WORKER_SERVER_URL env var,
+                 then to "http://localhost:8000".
+
+         Raises:
+             ValueError: If no deployment_id is provided, or if mode="push" but
+                 FastAPI is unavailable
+         """
+         self.polos_client = client
+         self.deployment_id = deployment_id or os.getenv("POLOS_DEPLOYMENT_ID")
+         if not self.deployment_id:
+             raise ValueError(
+                 "deployment_id is required for Worker initialization. "
+                 "Set it via parameter or POLOS_DEPLOYMENT_ID env var."
+             )
+
+         # Use the client's configuration
+         self.project_id = client.project_id
+         self.api_url = client.api_url
+
+         # Check if local_mode can be enabled (only allowed for localhost addresses)
+         local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
+         is_localhost = is_localhost_url(self.api_url)
+         self.local_mode = local_mode_requested and is_localhost
+
+         if local_mode_requested and not is_localhost:
+             logger.warning(
+                 "POLOS_LOCAL_MODE=True ignored because api_url (%s) is not localhost.",
+                 self.api_url,
+             )
+
+         self.api_key = client.api_key
+         if not self.local_mode and not self.api_key:
+             raise ValueError(
+                 "api_key is required for Worker initialization. "
+                 "Set it via PolosClient(api_key='...') or the POLOS_API_KEY environment variable, "
+                 "or set POLOS_LOCAL_MODE=True for local development "
+                 "(only works with localhost URLs)."
+             )
+
+         # Worker mode configuration
+         self.mode = mode.lower()
+         if self.mode not in ("push", "pull"):
+             raise ValueError(f"mode must be 'push' or 'pull', got '{mode}'")
+
+         if self.mode == "pull":
+             raise ValueError("[Worker] Pull mode is not supported yet. Use push mode instead.")
+
+         if self.mode == "push":
+             if not FASTAPI_AVAILABLE:
+                 raise ValueError(
+                     "FastAPI and uvicorn are required for push mode. "
+                     "Install with: pip install fastapi uvicorn"
+                 )
+
+             # Determine the push endpoint URL: parameter, then env var, then default
+             self.worker_server_url = (
+                 worker_server_url
+                 or os.getenv("POLOS_WORKER_SERVER_URL")
+                 or "http://localhost:8000"
+             )
+             self.worker_server: WorkerServer | None = None
+         else:
+             self.worker_server_url = None
+             self.worker_server = None
+
+         # Get max_concurrent_workflows from the parameter, env var, or default
+         if max_concurrent_workflows is not None:
+             self.max_concurrent_workflows = max_concurrent_workflows
+         else:
+             env_value = os.getenv("POLOS_MAX_CONCURRENT_WORKFLOWS")
+             if env_value:
+                 try:
+                     self.max_concurrent_workflows = int(env_value)
+                 except ValueError:
+                     logger.warning(
+                         "Invalid POLOS_MAX_CONCURRENT_WORKFLOWS value '%s', using default 100",
+                         env_value,
+                     )
+                     self.max_concurrent_workflows = 100
+             else:
+                 self.max_concurrent_workflows = 100
+
+         self.execution_semaphore = asyncio.Semaphore(self.max_concurrent_workflows)
+         self.active_executions: set[str] = set()
+         # Store the task for each execution (for manual cancellation)
+         self.execution_tasks: dict[str, asyncio.Task] = {}
+         self.execution_tasks_lock = asyncio.Lock()
+
+         # Build the workflow registry
+         self.workflows_registry: dict[str, Workflow] = {}
+         self.agents: list[Agent] = [a for a in (agents or []) if isinstance(a, Agent)]
+         self.tools: list[Tool] = [t for t in (tools or []) if isinstance(t, Tool)]
+         self.agent_ids: list[str] = []
+         self.tool_ids: list[str] = []
+         self.workflow_ids: list[str] = []
+
+         # Process the workflows list, keeping only Workflow instances
+         processed_workflows: list[Workflow] = []
+         for workflow in workflows or []:
+             if not isinstance(workflow, Workflow):
+                 logger.warning("Skipping non-Workflow object in workflows list: %s", workflow)
+                 continue
+
+             processed_workflows.append(workflow)
+             self.workflows_registry[workflow.id] = workflow
+             self.workflow_ids.append(workflow.id)
+
+         # Store the processed workflows (all are Workflow instances now)
+         self.workflows: list[Workflow] = processed_workflows
+
+         # Register all agents and tools in the local registry
+         for agent in self.agents:
+             self.workflows_registry[agent.id] = agent
+             self.agent_ids.append(agent.id)
+
+         for tool in self.tools:
+             self.workflows_registry[tool.id] = tool
+             self.tool_ids.append(tool.id)
+
+         # Worker state
+         self.worker_id: str | None = None
+         self.running = False
+         self.poll_task: asyncio.Task | None = None
+         self.heartbeat_task: asyncio.Task | None = None
+         self.worker_server_task: asyncio.Task | None = None
+
+         # Reusable HTTP client for orchestrator API calls
+         self.client: httpx.AsyncClient | None = None
+
+     async def run(self):
+         """Run the worker (blocks until shutdown)."""
+         logger.info("Starting worker...")
+         logger.info("Deployment ID: %s", self.deployment_id)
+         logger.info("Orchestrator: %s", self.api_url)
+
+         self.client = httpx.AsyncClient(timeout=httpx.Timeout(35.0, connect=5.0))
+
+         # Register with the orchestrator
+         await self._register()
+
+         # Register the deployment
+         await self._register_deployment()
+
+         # Register agents, tools, and workflows
+         await self._register_agents()
+         await self._register_tools()
+         await self._register_workflows()
+
+         # Register/update queues used by workflows and agents
+         await self._register_queues()
+
+         # Mark the worker as online after all registrations are complete
+         await self._mark_online()
+
+         self.running = True
+
+         # Register this worker instance so client.py can reuse its HTTP client
+         # and so features can access the client
+         set_current_worker(self)
+
+         # Set up signal handlers
+         def signal_handler(sig):
+             """Handle shutdown signals."""
+             logger.info("Received signal %s, shutting down...", sig)
+             # Create the shutdown task
+             asyncio.create_task(self.shutdown())
+
+         loop = asyncio.get_running_loop()
+         for sig in (signal.SIGINT, signal.SIGTERM):
+             loop.add_signal_handler(sig, lambda s=sig: signal_handler(s))
+
+         # Start tasks based on mode
+         if self.mode == "push":
+             # Initialize the push server
+             self._setup_worker_server()
+
+             # Start the FastAPI server for push mode
+             self.worker_server_task = asyncio.create_task(self.worker_server.run())
+             self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+             try:
+                 await asyncio.gather(
+                     self.worker_server_task, self.heartbeat_task, return_exceptions=True
+                 )
+             except Exception as e:
+                 logger.error("Error in worker tasks: %s", e)
+         else:
+             # Start polling and heartbeat tasks for pull mode
+             self.poll_task = asyncio.create_task(self._poll_loop())
+             self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+             try:
+                 await asyncio.gather(self.poll_task, self.heartbeat_task, return_exceptions=True)
+             except Exception as e:
+                 logger.error("Error in worker tasks: %s", e)
+
+     async def _register(self):
+         """Register the worker with the orchestrator."""
+         try:
+             headers = self._get_headers()
+             registration_data = {
+                 "deployment_id": self.deployment_id,
+                 "project_id": self.project_id,
+                 "mode": self.mode,
+                 "capabilities": {
+                     "runtime": "python",
+                     "agent_ids": self.agent_ids,
+                     "tool_ids": self.tool_ids,
+                     "workflow_ids": self.workflow_ids,
+                 },
+                 "max_concurrent_executions": self.max_concurrent_workflows,
+             }
+
+             # Add the worker server URL if in push mode
+             if self.mode == "push":
+                 registration_data["push_endpoint_url"] = self.worker_server_url
+
+             response = await self.client.post(
+                 f"{self.api_url}/api/v1/workers/register",
+                 json=registration_data,
+                 headers=headers,
+             )
+             response.raise_for_status()
+             data = response.json()
+             self.worker_id = data["worker_id"]
+
+             logger.info("Registered: %s (mode: %s)", self.worker_id, self.mode)
+         except Exception as error:
+             logger.error("Registration failed: %s", error)
+             raise
+
+     async def _mark_online(self):
+         """Mark the worker as online after completing all registrations."""
+         try:
+             headers = self._get_headers()
+             response = await self.client.post(
+                 f"{self.api_url}/api/v1/workers/{self.worker_id}/online",
+                 headers=headers,
+             )
+             response.raise_for_status()
+             logger.info("Marked as online: %s", self.worker_id)
+         except Exception as error:
+             logger.warning("Failed to mark worker as online: %s", error)
+             # Don't raise - allow the worker to continue; the heartbeat will update status
+
+     async def _re_register(self):
+         """Re-register the worker, deployment, agents, tools, workflows, and queues."""
+         try:
+             # Register with the orchestrator (gets a new worker_id)
+             await self._register()
+
+             # Register the deployment
+             await self._register_deployment()
+
+             # Register agents, tools, and workflows
+             await self._register_agents()
+             await self._register_tools()
+             await self._register_workflows()
+
+             # Register/update queues used by workflows and agents
+             await self._register_queues()
+
+             # Mark the worker as online after all registrations are complete
+             await self._mark_online()
+
+             # Update the push server with the new worker_id if in push mode
+             if self.mode == "push" and self.worker_server:
+                 self.worker_server.update_worker_id(self.worker_id)
+                 logger.debug("Updated push server with new worker_id: %s", self.worker_id)
+
+             logger.info("Re-registration complete: %s", self.worker_id)
+         except Exception as error:
+             logger.error("Re-registration failed: %s", error)
+             # Don't raise - allow the heartbeat to continue and retry later
+
+     async def _register_deployment(self):
+         """Create or replace the deployment in the orchestrator."""
+         try:
+             headers = self._get_headers()
+             response = await self.client.post(
+                 f"{self.api_url}/api/v1/workers/deployments",
+                 json={
+                     "deployment_id": self.deployment_id,
+                 },
+                 headers=headers,
+             )
+             response.raise_for_status()
+
+             logger.info("Deployment registered: %s", self.deployment_id)
+         except Exception as error:
+             logger.error("Deployment registration failed: %s", error)
+             raise
+
+     async def _register_agents(self):
+         """Register agent definitions and add them to deployment_workflows."""
+         for agent in self.agents:
+             try:
+                 headers = self._get_headers()
+
+                 # Collect tool definitions for the agent
+                 tools_json = None
+                 if agent.tools:
+                     tools_list = []
+                     for tool in agent.tools:
+                         if isinstance(tool, Tool):
+                             tools_list.append(tool.to_llm_tool_definition())
+                         elif isinstance(tool, dict):
+                             tools_list.append(tool)
+                     if tools_list:
+                         # Round-trip through JSON to ensure plain, serializable structures
+                         tools_json = json.loads(json.dumps(tools_list))
+
+                 # Build metadata with stop condition function names and guardrail info
+                 metadata = {}
+                 if agent.provider_base_url:
+                     metadata["provider_base_url"] = agent.provider_base_url
+
+                 # Add stop condition function names
+                 if agent.stop_conditions:
+                     stop_condition_names = []
+                     for sc in agent.stop_conditions:
+                         # Get the function name from the configured callable.
+                         # Check for the __stop_condition_name__ attribute first (set by the decorator)
+                         if hasattr(sc, "__stop_condition_name__"):
+                             stop_condition_names.append(sc.__stop_condition_name__)
+                         # Fall back to the __name__ attribute
+                         elif hasattr(sc, "__name__"):
+                             stop_condition_names.append(sc.__name__)
+                         else:
+                             # Last resort: use the string representation
+                             stop_condition_names.append(str(sc))
+
+                     if stop_condition_names:
+                         metadata["stop_conditions"] = stop_condition_names
+
+                 # Add guardrail function names and strings
+                 if agent.guardrails:
+                     guardrail_info = []
+                     for gr in agent.guardrails:
+                         if callable(gr):
+                             # Get the function name for callable guardrails
+                             if hasattr(gr, "__name__") and gr.__name__ != "<lambda>":
+                                 guardrail_info.append({"type": "function", "name": gr.__name__})
+                             else:
+                                 guardrail_info.append({"type": "function", "name": str(gr)})
+                         elif isinstance(gr, str):
+                             # Include string guardrails (truncated for readability)
+                             truncated = gr[:200] + "..." if len(gr) > 200 else gr
+                             guardrail_info.append({"type": "string", "content": truncated})
+
+                     if guardrail_info:
+                         metadata["guardrails"] = guardrail_info
+
+                 # Use None instead of an empty dict
+                 if not metadata:
+                     metadata = None
+
+                 response = await self.client.post(
+                     f"{self.api_url}/api/v1/agents/register",
+                     json={
+                         "id": agent.id,
+                         "deployment_id": self.deployment_id,
+                         "provider": agent.provider,
+                         "model": agent.model,
+                         "system_prompt": agent.system_prompt,
+                         "tools": tools_json,
+                         "temperature": agent.temperature,
+                         "max_output_tokens": agent.max_output_tokens,
+                         "metadata": metadata,
+                     },
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+                 # Register in deployment_workflows
+                 await self._register_deployment_workflow(agent.id, "agent")
+
+                 logger.debug("Registered agent: %s", agent.id)
+             except Exception as error:
+                 logger.error("Failed to register agent %s: %s", agent.id, error)
+                 raise
+
+     async def _register_tools(self):
+         """Register tool definitions."""
+         for tool in self.tools:
+             try:
+                 tool_type = tool.get_tool_type()
+                 metadata = tool.get_tool_metadata()
+
+                 # Register the tool definition
+                 headers = self._get_headers()
+                 response = await self.client.post(
+                     f"{self.api_url}/api/v1/tools/register",
+                     json={
+                         "id": tool.id,
+                         "deployment_id": self.deployment_id,
+                         "tool_type": tool_type,
+                         "description": tool._tool_description,
+                         "parameters": tool._tool_parameters,
+                         "metadata": metadata,
+                     },
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+                 # Register in deployment_workflows
+                 await self._register_deployment_workflow(tool.id, "tool")
+
+                 logger.debug("Registered tool: %s (type: %s)", tool.id, tool_type)
+             except Exception as error:
+                 logger.error("Failed to register tool %s: %s", tool.id, error)
+                 raise
+
+     async def _register_workflows(self):
+         """Register workflows in deployment_workflows, event triggers, and schedules."""
+         # self.workflows only contains Workflow instances (filtered in __init__)
+         for workflow in self.workflows:
+             try:
+                 # Check if the workflow is event-triggered (trigger_on_event is the topic string if set)
+                 is_event_triggered = (
+                     hasattr(workflow, "trigger_on_event") and workflow.trigger_on_event is not None
+                 )
+                 event_topic = workflow.trigger_on_event if is_event_triggered else None
+
+                 # Check if the workflow is schedulable (schedule=True or a cron string/dict)
+                 is_schedulable = getattr(workflow, "is_schedulable", False)
+
+                 # Register in deployment_workflows with boolean flags
+                 await self._register_deployment_workflow(
+                     workflow.id, "workflow", is_event_triggered, is_schedulable
+                 )
+
+                 # Register an event trigger if the workflow has trigger_on_event
+                 if is_event_triggered and event_topic:
+                     queue_name = workflow.queue_name or workflow.id
+                     await self._register_event_trigger(
+                         workflow.id,
+                         event_topic,
+                         getattr(workflow, "batch_size", 1),
+                         getattr(workflow, "batch_timeout_seconds", None),
+                         queue_name,
+                     )
+
+                 # Register a schedule if the workflow has a cron schedule (not just schedule=True)
+                 schedule_config = getattr(workflow, "schedule", None)
+                 if schedule_config and schedule_config is not True:
+                     # schedule is a cron string or dict - register it
+                     await self._register_schedule(workflow)
+
+                 logger.debug("Registered workflow: %s", workflow.id)
+             except Exception as error:
+                 logger.error("Failed to register workflow %s: %s", workflow.id, error)
+                 raise
+
+     async def _register_event_trigger(
+         self,
+         workflow_id: str,
+         event_topic: str,
+         batch_size: int,
+         batch_timeout_seconds: int | None,
+         queue_name: str,
+     ):
+         """Register an event trigger for a workflow."""
+         headers = self._get_headers()
+         response = await self.client.post(
+             f"{self.api_url}/api/v1/event-triggers/register",
+             json={
+                 "workflow_id": workflow_id,
+                 "deployment_id": self.deployment_id,
+                 "event_topic": event_topic,
+                 "batch_size": batch_size,
+                 "batch_timeout_seconds": batch_timeout_seconds,
+                 "queue_name": queue_name,
+             },
+             headers=headers,
+         )
+         response.raise_for_status()
+
+     async def _register_queues(self):
+         """Register/update queues used by workflows, agents, and tools."""
+         # Collect all unique queues from workflows, agents, and tools
+         queues: dict[str, int | None] = {}
+
+         def merge_queue(queue_name: str, queue_limit: int | None):
+             """Record a queue, keeping the more restrictive limit if both are set."""
+             if queue_name in queues:
+                 if queue_limit is not None and queues[queue_name] is not None:
+                     queues[queue_name] = min(queues[queue_name], queue_limit)
+                 elif queue_limit is not None:
+                     queues[queue_name] = queue_limit
+             else:
+                 queues[queue_name] = queue_limit
+
+         # Collect from workflows
+         for workflow in self.workflows:
+             # Skip scheduled workflows - they get their own queues registered separately
+             if getattr(workflow, "is_schedulable", False):
+                 continue
+
+             # If queue_name is None, the workflow will use workflow.id as the queue name at runtime
+             merge_queue(
+                 getattr(workflow, "queue_name", None) or workflow.id,
+                 getattr(workflow, "queue_concurrency_limit", None),
+             )
+
+         # Collect from agents
+         for agent in self.agents:
+             # If queue_name is None, the agent will use agent.id as the queue name at runtime
+             merge_queue(
+                 getattr(agent, "queue_name", None) or agent.id,
+                 getattr(agent, "queue_concurrency_limit", None),
+             )
+
+         # Collect from tools
+         for tool in self.tools:
+             # If queue_name is None, the tool will use tool.id as the queue name at runtime
+             merge_queue(
+                 getattr(tool, "queue_name", None) or tool.id,
+                 getattr(tool, "queue_concurrency_limit", None),
+             )
+
+         # Batch register/update all queues
+         if queues:
+             try:
+                 headers = self._get_headers()
+                 # Convert the dict to a list of queue info dicts
+                 queues_list = [
+                     {"name": name, "concurrency_limit": limit} for name, limit in queues.items()
+                 ]
+                 response = await self.client.post(
+                     f"{self.api_url}/api/v1/workers/queues",
+                     json={"deployment_id": self.deployment_id, "queues": queues_list},
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+                 logger.info(
+                     "Registered/updated %d queue(s) for deployment %s",
+                     len(queues),
+                     self.deployment_id,
+                 )
+             except Exception as error:
+                 logger.error("Failed to register queues: %s", error)
+                 # Don't raise - queue registration failure shouldn't stop worker startup
+
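+     # Note: when several workflows, agents, or tools share a queue name, the most
+     # restrictive concurrency limit wins (e.g. limits 5 and 2 register as 2), and
+     # a limit of None leaves the queue uncapped.
+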
+     async def _register_deployment_workflow(
+         self,
+         workflow_id: str,
+         workflow_type: str,
+         trigger_on_event: bool = False,
+         scheduled: bool = False,
+     ):
+         """Register a workflow/agent in the deployment_workflows table."""
+         headers = self._get_headers()
+         request_body = {
+             "workflow_id": workflow_id,
+             "workflow_type": workflow_type,
+             "trigger_on_event": trigger_on_event,
+             "scheduled": scheduled,
+         }
+         response = await self.client.post(
+             f"{self.api_url}/api/v1/workers/deployments/{self.deployment_id}/workflows",
+             json=request_body,
+             headers=headers,
+         )
+         response.raise_for_status()
+
+     async def _register_schedule(self, workflow: Workflow):
+         """Register a schedule for a workflow."""
+         schedule_config = workflow.schedule
+
+         # Parse the schedule configuration
+         if isinstance(schedule_config, str):
+             # Simple cron string - use the UTC timezone and the "global" key
+             cron = schedule_config
+             timezone = "UTC"
+             key = "global"
+         elif isinstance(schedule_config, dict):
+             # Dict with cron and optional timezone and key
+             cron = schedule_config.get("cron")
+             timezone = schedule_config.get("timezone", "UTC")
+             key = schedule_config.get("key", "global")  # Default to "global" if not provided
+             if not cron:
+                 raise ValueError("Schedule dict must contain 'cron' key")
+         else:
+             raise ValueError(f"Invalid schedule type: {type(schedule_config)}")
+
+         headers = self._get_headers()
+         request_body = {
+             "workflow_id": workflow.id,
+             "cron": cron,
+             "timezone": timezone,
+             "key": key,
+         }
+
+         response = await self.client.post(
+             f"{self.api_url}/api/v1/schedules",
+             json=request_body,
+             headers=headers,
+         )
+         response.raise_for_status()
+         logger.info(
+             "Registered schedule for workflow: %s (cron: %s, timezone: %s, key: %s)",
+             workflow.id,
+             cron,
+             timezone,
+             key,
+         )
+
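+     # Accepted schedule shapes (illustrative cron values):
+     #   schedule = "0 9 * * *"    -> timezone "UTC", key "global"
+     #   schedule = {"cron": "0 9 * * *", "timezone": "America/New_York", "key": "us"}
+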
+     async def _poll_loop(self):
+         """Continuously poll for workflows (batched)."""
+         while self.running:
+             try:
+                 if not self.worker_id:
+                     await asyncio.sleep(1)
+                     continue
+
+                 # Calculate available slots
+                 available_slots = self.max_concurrent_workflows - len(self.active_executions)
+                 if available_slots <= 0:
+                     # No available slots, wait a bit before polling again
+                     await asyncio.sleep(0.1)
+                     continue
+
+                 headers = self._get_headers()
+                 # Poll for multiple workflows (up to the available slots)
+                 response = await self.client.get(
+                     f"{self.api_url}/api/v1/workers/{self.worker_id}/poll",
+                     params={"max_workflows": available_slots},
+                     headers=headers,
+                 )
+
+                 if response.status_code != 200:
+                     await asyncio.sleep(1)
+                     continue
+
+                 workflows_data = response.json()
+
+                 if workflows_data:
+                     logger.debug(
+                         "Received %d workflow(s) (requested %d, active: %d)",
+                         len(workflows_data),
+                         available_slots,
+                         len(self.active_executions),
+                     )
+
+                     async def execute_with_error_handling(exec_data):
+                         # Exceptions are already handled in _execute_workflow;
+                         # suppressing here just prevents "Task exception was never
+                         # retrieved" warnings
+                         with contextlib.suppress(Exception):
+                             await self._execute_workflow_with_semaphore(exec_data)
+
+                     # Execute all workflows in the background
+                     for workflow_data in workflows_data:
+                         asyncio.create_task(execute_with_error_handling(workflow_data))
+                 # If no workflows, keep polling (the long-poll timeout is handled by httpx)
+
+             except asyncio.CancelledError:
+                 break
+             except httpx.TimeoutException:
+                 # Expected on long-poll timeout
+                 pass
+             except Exception as error:
+                 logger.error("Poll error: %s", error)
+                 await asyncio.sleep(1)  # Wait before retrying on error
+
+     async def _execute_workflow_with_semaphore(self, workflow_data: dict[str, Any]):
+         """Execute a workflow under the semaphore that limits concurrency."""
+         execution_id = workflow_data["execution_id"]
+
+         # Acquire the semaphore (blocks if at max concurrency)
+         async with self.execution_semaphore:
+             self.active_executions.add(execution_id)
+             try:
+                 await self._execute_workflow(workflow_data)
+             finally:
+                 self.active_executions.discard(execution_id)
+
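+     # Note: active_executions mirrors the semaphore's occupancy so _poll_loop can
+     # compute free slots without reaching into semaphore internals.
+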
+     async def _execute_workflow(self, workflow_data: dict[str, Any]):
+         """Execute a workflow from the registry."""
+         execution_id = workflow_data["execution_id"]
+         run_timeout_seconds = workflow_data.get("run_timeout_seconds")
+
+         try:
+             workflow_id = workflow_data["workflow_id"]
+             # First check the worker's local registry (explicitly registered workflows)
+             workflow = self.workflows_registry.get(workflow_id)
+
+             # If not found, check the global registry (system workflows and others)
+             if not workflow:
+                 workflow = _WORKFLOW_REGISTRY.get(workflow_id)
+
+             if not workflow:
+                 raise ValueError(
+                     f"Workflow {workflow_id} not found in worker registry or global workflow registry"
+                 )
+
+             # Build the execution context
+             context = {
+                 "execution_id": execution_id,
+                 "deployment_id": workflow_data.get("deployment_id"),
+                 "parent_execution_id": workflow_data.get("parent_execution_id"),
+                 "root_execution_id": workflow_data.get("root_execution_id"),
+                 "retry_count": workflow_data.get("retry_count", 0),
+                 "session_id": workflow_data.get("session_id"),
+                 "user_id": workflow_data.get("user_id"),
+                 "otel_traceparent": workflow_data.get("otel_traceparent"),
+                 "otel_span_id": workflow_data.get("otel_span_id"),
+                 "initial_state": workflow_data.get("initial_state"),
+                 "run_timeout_seconds": run_timeout_seconds,
+             }
+
+             payload = workflow_data["payload"]
+             created_at_str = workflow_data.get("created_at")
+
+             # Parse created_at if provided
+             created_at = None
+             if created_at_str:
+                 with contextlib.suppress(ValueError, AttributeError):
+                     created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
+
+             context["created_at"] = created_at
+
+             # Create a task for the workflow execution and store it for cancellation
+             workflow_task = asyncio.create_task(workflow._execute(context, payload))
+             async with self.execution_tasks_lock:
+                 self.execution_tasks[execution_id] = workflow_task
+
+             # Watch for the run timeout
+             timeout_task = None
+             if run_timeout_seconds:
+
+                 async def check_timeout():
+                     """Background task that cancels the execution on timeout."""
+                     try:
+                         await asyncio.sleep(run_timeout_seconds)
+                         # Timeout reached - check if the execution is still running
+                         async with self.execution_tasks_lock:
+                             task = self.execution_tasks.get(execution_id)
+                             if task and not task.done():
+                                 # Still running - cancel the execution
+                                 task.cancel()
+                                 logger.warning(
+                                     "Execution %s timed out after %d seconds",
+                                     execution_id,
+                                     run_timeout_seconds,
+                                 )
+                     except asyncio.CancelledError:
+                         # The execution finished or was cancelled manually; ignore
+                         pass
+
+                 timeout_task = asyncio.create_task(check_timeout())
+
+             try:
+                 # Execute the workflow
+                 result, final_state = await workflow_task
+             except asyncio.CancelledError:
+                 # The execution was cancelled (either manually or due to timeout)
+                 logger.info("Execution %s was cancelled", execution_id)
+                 await self._handle_cancellation(execution_id, workflow_id, context)
+                 raise
+             finally:
+                 # Clean up
+                 if timeout_task:
+                     timeout_task.cancel()
+                     with contextlib.suppress(asyncio.CancelledError):
+                         await timeout_task
+
+                 async with self.execution_tasks_lock:
+                     self.execution_tasks.pop(execution_id, None)
+
+             # Prepare the result for reporting:
+             # - If it's a Pydantic model, convert to a dict via model_dump(mode="json")
+             #   and store the schema name
+             # - Otherwise, ensure it's JSON-serializable via json.dumps()
+             prepared_result = result
+             output_schema_name = None
+             try:
+                 if isinstance(result, BaseModel):
+                     prepared_result = result.model_dump(mode="json")
+                     # Store the full module path for Pydantic model reconstruction
+                     output_schema_name = (
+                         f"{result.__class__.__module__}.{result.__class__.__name__}"
+                     )
+                 else:
+                     # Validate JSON serializability; json.dumps will raise on failure
+                     json.dumps(result)
+             except (TypeError, ValueError) as e:
+                 # Serialization failed; propagate so the outer except block handles it
+                 raise TypeError(
+                     f"Workflow result is not JSON serializable: "
+                     f"{type(result).__name__}. Error: {str(e)}"
+                 ) from e
+
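+             # Illustrative: a Pydantic result MyReport(title="x") would be reported
+             # as {"title": "x"} with output_schema_name "my_pkg.reports.MyReport"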
+             # Report success with the validated result and schema name
+             await self._report_success(
+                 workflow_data["execution_id"], prepared_result, output_schema_name, final_state
+             )
+
+         except WaitException as e:
+             # WaitException is expected when a workflow waits for a sub-workflow.
+             # The orchestrator will resume the execution when the sub-workflow
+             # completes. Do not report this as a failure - it is the normal wait
+             # mechanism.
+             logger.debug("Workflow paused for waiting: %s", e)
+             return
+
+         except Exception as error:
+             # Capture the full stack trace
+             error_message = str(error)
+             stack_trace = traceback.format_exc()
+             logger.error("Execution error: %s\nStack trace:\n%s", error_message, stack_trace)
+
+             # Look the workflow up again; it may be None if the lookup failed above
+             workflow_id = workflow_data.get("workflow_id")
+             workflow = self.workflows_registry.get(workflow_id) or _WORKFLOW_REGISTRY.get(
+                 workflow_id
+             )
+
+             # StepExecutionError is non-retryable, and tools are not retryable by
+             # default: the error is fed back to the LLM to handle.
+             retryable = not isinstance(error, StepExecutionError) and (
+                 workflow is None or workflow.workflow_type != "tool"
+             )
+             await self._report_failure(
+                 workflow_data["execution_id"], error_message, stack_trace, retryable=retryable
+             )
+             raise
+
+     def _setup_worker_server(self):
+         """Set up the FastAPI server for push mode."""
+         if not FASTAPI_AVAILABLE:
+             raise RuntimeError("FastAPI not available")
+
+         # Callback that executes a received workflow
+         async def on_work_received(workflow_data: dict[str, Any]):
+             """Callback to handle received work."""
+             await self._execute_workflow_with_semaphore(workflow_data)
+
+         # Callback that handles cancel requests
+         async def on_cancel_requested(execution_id: str) -> bool:
+             """Callback to handle cancel requests.
+
+             Returns:
+                 True if the execution was found and cancelled, False if not found
+                 or already completed
+             """
+             return await self._handle_cancel_request(execution_id)
+
+         # Initialize the worker server
+         self.worker_server = WorkerServer(
+             worker_id=self.worker_id,
+             max_concurrent_workflows=self.max_concurrent_workflows,
+             on_work_received=on_work_received,
+             on_cancel_requested=on_cancel_requested,
+             local_mode=self.local_mode,
+         )
+
+         logger.info("Worker server initialized")
+
+     async def _report_success(
+         self,
+         execution_id: str,
+         result: Any,
+         output_schema_name: str | None = None,
+         final_state: dict[str, Any] | None = None,
+     ):
+         """Report a successful workflow execution with retries and exponential backoff."""
+         max_retries = 5
+         base_delay = 1.0  # Start with 1 second
+
+         for attempt in range(max_retries):
+             try:
+                 headers = self._get_headers()
+                 payload = {
+                     "result": result,
+                     "worker_id": self.worker_id,  # Include worker_id for validation
+                 }
+                 if output_schema_name:
+                     payload["output_schema_name"] = output_schema_name
+                 if final_state is not None:
+                     payload["final_state"] = final_state
+                 response = await self.client.post(
+                     f"{self.api_url}/internal/executions/{execution_id}/complete",
+                     json=payload,
+                     headers=headers,
+                 )
+
+                 # Handle 409 Conflict (execution reassigned)
+                 if response.status_code == 409:
+                     logger.debug(
+                         "Execution %s was reassigned [old worker %s], ignoring completion",
+                         execution_id,
+                         self.worker_id,
+                     )
+                     return  # Don't retry on 409
+
+                 response.raise_for_status()
+                 # Success - return immediately
+                 return
+             except httpx.HTTPStatusError as e:
+                 if e.response.status_code == 409:
+                     # 409 Conflict - execution reassigned, don't retry
+                     logger.debug("Execution %s was reassigned, ignoring completion", execution_id)
+                     return
+                 if attempt < max_retries - 1:
+                     # Exponential backoff: 1s, 2s, 4s, 8s
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to report success (attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         e,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     # Final attempt failed
+                     error_msg = f"Failed to report success after {max_retries} attempts: {e}"
+                     logger.error("%s", error_msg)
+                     # Don't call _report_failure here to avoid an infinite loop
+             except Exception as error:
+                 if attempt < max_retries - 1:
+                     # Exponential backoff: 1s, 2s, 4s, 8s
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to report success (attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         error,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     # Final attempt failed - report it as an execution failure
+                     error_msg = f"Failed to report success after {max_retries} attempts: {error}"
+                     logger.error("%s", error_msg)
+                     await self._report_failure(execution_id, error_msg, retryable=True)
+
+     async def _report_failure(
+         self,
+         execution_id: str,
+         error_message: str,
+         stack_trace: str | None = None,
+         retryable: bool = True,
+         final_state: dict[str, Any] | None = None,
+     ):
+         """Report a failed workflow execution with retries and exponential backoff.
+
+         Args:
+             execution_id: Execution ID to report the failure for
+             error_message: Error message
+             stack_trace: Optional stack trace
+             retryable: Whether the execution should be retried (default: True)
+             final_state: Optional final workflow state to persist
+         """
+         max_retries = 5
+         base_delay = 1.0  # Start with 1 second
+
+         for attempt in range(max_retries):
+             try:
+                 headers = self._get_headers()
+                 payload = {
+                     "error": error_message,
+                     "worker_id": self.worker_id,  # Include worker_id for validation
+                 }
+                 if stack_trace:
+                     payload["stack"] = stack_trace
+                 if not retryable:
+                     payload["retryable"] = False
+                 if final_state is not None:
+                     payload["final_state"] = final_state
+                 response = await self.client.post(
+                     f"{self.api_url}/internal/executions/{execution_id}/fail",
+                     json=payload,
+                     headers=headers,
+                 )
+
+                 # Handle 409 Conflict (execution reassigned)
+                 if response.status_code == 409:
+                     logger.debug("Execution %s was reassigned, ignoring failure", execution_id)
+                     return  # Don't retry on 409
+
+                 response.raise_for_status()
+                 # Success - return immediately
+                 return
+             except httpx.HTTPStatusError as e:
+                 if e.response.status_code == 409:
+                     # 409 Conflict - execution reassigned, don't retry
+                     logger.debug("Execution %s was reassigned, ignoring failure", execution_id)
+                     return
+                 if attempt < max_retries - 1:
+                     # Exponential backoff: 1s, 2s, 4s, 8s
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to report failure (attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         e,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     # Final attempt failed
+                     logger.error("Failed to report failure after %d attempts: %s", max_retries, e)
+             except Exception as error:
+                 if attempt < max_retries - 1:
+                     # Exponential backoff: 1s, 2s, 4s, 8s
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to report failure (attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         error,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     # Final attempt failed
+                     logger.error(
+                         "Failed to report failure after %d attempts: %s", max_retries, error
+                     )
+
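+     # Cancellation flow: the orchestrator asks the worker server to cancel, which
+     # calls _handle_cancel_request -> task.cancel(); _execute_workflow then catches
+     # the CancelledError and runs _handle_cancellation (confirmation + event).
+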
+     async def _handle_cancel_request(self, execution_id: str) -> bool:
+         """Handle a cancellation request from the orchestrator.
+
+         Returns:
+             True if the execution was found and cancelled, False if not found or
+             already completed
+         """
+         logger.info("Handling cancellation request for execution %s", execution_id)
+         async with self.execution_tasks_lock:
+             task = self.execution_tasks.get(execution_id)
+             if task and not task.done():
+                 logger.debug("Cancelling task %s", task)
+                 # Cancel the task
+                 task.cancel()
+                 logger.info("Cancellation requested for execution %s", execution_id)
+                 return True
+             else:
+                 logger.debug("Execution %s not found or already completed", execution_id)
+                 return False
+
+     async def _handle_cancellation(
+         self, execution_id: str, workflow_id: str, context: dict[str, Any]
+     ):
+         """Handle execution cancellation - send the confirmation and emit an event."""
+         try:
+             logger.info("Sending cancellation confirmation for execution %s", execution_id)
+             # Send the cancel confirmation to the orchestrator
+             await self._send_cancel_confirmation(execution_id)
+
+             # Emit the cancellation event
+             await self._emit_cancellation_event(execution_id, workflow_id, context)
+         except Exception as e:
+             logger.error("Error handling cancellation for %s: %s", execution_id, e)
+
+     async def _send_cancel_confirmation(self, execution_id: str):
+         """Send the cancellation confirmation to the orchestrator."""
+         max_retries = 5
+         base_delay = 1.0
+
+         for attempt in range(max_retries):
+             try:
+                 headers = self._get_headers()
+                 payload = {
+                     "worker_id": self.worker_id,
+                 }
+                 response = await self.client.post(
+                     f"{self.api_url}/internal/executions/{execution_id}/confirm-cancellation",
+                     json=payload,
+                     headers=headers,
+                 )
+
+                 # Handle 409 Conflict (execution reassigned)
+                 if response.status_code == 409:
+                     logger.debug(
+                         "Execution %s was reassigned, ignoring cancellation confirmation",
+                         execution_id,
+                     )
+                     return
+
+                 response.raise_for_status()
+                 logger.info("Sent cancellation confirmation for execution %s", execution_id)
+                 return
+             except httpx.HTTPStatusError as e:
+                 if e.response.status_code == 409:
+                     logger.debug(
+                         "Execution %s was reassigned, ignoring cancellation confirmation",
+                         execution_id,
+                     )
+                     return
+                 if attempt < max_retries - 1:
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to send cancellation confirmation "
+                         "(attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         e,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     logger.error(
+                         "Failed to send cancellation confirmation after %d attempts: %s",
+                         max_retries,
+                         e,
+                     )
+             except Exception as error:
+                 if attempt < max_retries - 1:
+                     delay = base_delay * (2**attempt)
+                     logger.warning(
+                         "Failed to send cancellation confirmation "
+                         "(attempt %d/%d): %s. Retrying in %ds...",
+                         attempt + 1,
+                         max_retries,
+                         error,
+                         delay,
+                     )
+                     await asyncio.sleep(delay)
+                 else:
+                     logger.error(
+                         "Failed to send cancellation confirmation after %d attempts: %s",
+                         max_retries,
+                         error,
+                     )
+
+     async def _emit_cancellation_event(
+         self, execution_id: str, workflow_id: str, context: dict[str, Any]
+     ):
+         """Emit a cancellation event for the workflow."""
+         try:
+             from ..features.events import publish
+
+             # Topic format: workflow:{root_execution_id or execution_id}
+             topic = f"workflow:{context.get('root_execution_id') or execution_id}"
+
+             # Event type: {workflow_type}_cancel (defaults to "workflow_cancel")
+             event_type = f"{context.get('workflow_type', 'workflow')}_cancel"
+
+             # Event data
+             event_data = {
+                 "_metadata": {
+                     "execution_id": execution_id,
+                     "workflow_id": workflow_id,
+                 }
+             }
+
+             # Publish the event
+             await publish(
+                 self.polos_client,
+                 topic=topic,
+                 event_type=event_type,
+                 data=event_data,
+                 execution_id=execution_id,
+                 root_execution_id=context.get("root_execution_id"),
+             )
+
+             logger.debug("Emitted cancellation event for execution %s", execution_id)
+         except Exception as e:
+             logger.error("Failed to emit cancellation event for %s: %s", execution_id, e)
+
+     async def _heartbeat_loop(self):
+         """Send periodic heartbeats."""
+         while self.running:
+             try:
+                 await asyncio.sleep(30)
+
+                 if not self.worker_id:
+                     continue
+
+                 headers = self._get_headers()
+                 response = await self.client.post(
+                     f"{self.api_url}/api/v1/workers/{self.worker_id}/heartbeat",
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+                 # Check whether re-registration is required
+                 data = response.json()
+                 if data.get("re_register", False):
+                     logger.info("Orchestrator requested re-registration, re-registering...")
+                     await self._re_register()
+             except asyncio.CancelledError:
+                 break
+             except Exception as error:
+                 logger.warning("Heartbeat failed: %s", error)
+
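+     # A heartbeat response containing {"re_register": true} asks the worker to
+     # repeat the full registration sequence (see _re_register above).
+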
+     async def shutdown(self):
+         """Shut down gracefully."""
+         if not self.running:
+             return
+
+         logger.info("Shutting down gracefully...")
+         self.running = False
+
+         # Shut down the worker server first (if in push mode)
+         if self.mode == "push" and self.worker_server:
+             try:
+                 await self.worker_server.shutdown()
+             except Exception as e:
+                 logger.error("Error shutting down worker server: %s", e)
+
+         # Cancel tasks
+         if self.mode == "push":
+             if self.worker_server_task:
+                 self.worker_server_task.cancel()
+         else:
+             if self.poll_task:
+                 self.poll_task.cancel()
+
+         if self.heartbeat_task:
+             self.heartbeat_task.cancel()
+
+         # Wait for the cancellations to complete
+         if self.mode == "push":
+             try:
+                 if self.worker_server_task:
+                     await self.worker_server_task
+             except asyncio.CancelledError:
+                 pass
+         else:
+             try:
+                 if self.poll_task:
+                     await self.poll_task
+             except asyncio.CancelledError:
+                 pass
+
+         try:
+             if self.heartbeat_task:
+                 await self.heartbeat_task
+         except asyncio.CancelledError:
+             pass
+
+         # Close the HTTP client
+         if hasattr(self, "client") and self.client:
+             try:
+                 await self.client.aclose()
+             except Exception as e:
+                 logger.error("Error closing HTTP client: %s", e)
+
+         # Unregister this worker instance
+         set_current_worker(None)
+
+         logger.info("Shutdown complete")
+
+     def _get_headers(self) -> dict[str, str]:
+         """Get HTTP headers for API requests, including the API key and project_id."""
+         return self.polos_client._get_headers()