polos-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/runtime/worker.py
ADDED
|
@@ -0,0 +1,1365 @@
|
|
|
1
|
+
"""Worker class for executing Polos workflows, agents, and tools."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import signal
|
|
8
|
+
import traceback
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
from dotenv import load_dotenv
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
|
|
16
|
+
from ..agents.agent import Agent
|
|
17
|
+
from ..core.workflow import _WORKFLOW_REGISTRY, StepExecutionError, Workflow
|
|
18
|
+
from ..features.wait import WaitException
|
|
19
|
+
from ..tools.tool import Tool
|
|
20
|
+
from ..utils.config import is_localhost_url
|
|
21
|
+
from ..utils.worker_singleton import set_current_worker
|
|
22
|
+
from .client import PolosClient
|
|
23
|
+
|
|
24
|
+
# FastAPI imports for push mode
|
|
25
|
+
try:
|
|
26
|
+
from .worker_server import WorkerServer
|
|
27
|
+
|
|
28
|
+
FASTAPI_AVAILABLE = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
FASTAPI_AVAILABLE = False
|
|
31
|
+
WorkerServer = None
|
|
32
|
+
|
|
33
|
+
# Populate os.environ from a local .env file (when one exists) at import time,
# so the POLOS_* settings read by Worker can be supplied that way.
load_dotenv()

# Module-level logger used by every Worker log call below.
logger = logging.getLogger(name=__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Worker:
|
|
38
|
+
"""
|
|
39
|
+
Polos worker that executes workflows and agents.
|
|
40
|
+
|
|
41
|
+
The worker:
|
|
42
|
+
1. Registers with the orchestrator (creates/replaces deployment)
|
|
43
|
+
2. Registers agent and tool definitions
|
|
44
|
+
3. Registers all workflows/agents in deployment_workflows table
|
|
45
|
+
4. Polls orchestrator for workflows and executes them
|
|
46
|
+
|
|
47
|
+
Usage:
|
|
48
|
+
from polos import Worker, Agent, Tool, PolosClient
|
|
49
|
+
|
|
50
|
+
client = PolosClient(api_url="http://localhost:8080")
|
|
51
|
+
|
|
52
|
+
# Define your workflows
|
|
53
|
+
research_agent = Agent(...)
|
|
54
|
+
analysis_agent = Agent(...)
|
|
55
|
+
|
|
56
|
+
# Create worker
|
|
57
|
+
worker = Worker(
|
|
58
|
+
client=client,
|
|
59
|
+
deployment_id=os.getenv("WORKER_DEPLOYMENT_ID"),
|
|
60
|
+
agents=[research_agent, analysis_agent],
|
|
61
|
+
tools=[search_web],
|
|
62
|
+
workflows=[step_condition_workflow],
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Run worker (blocks until shutdown)
|
|
66
|
+
await worker.run()
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def __init__(
    self,
    client: PolosClient,
    deployment_id: str,
    agents: list[Agent] | None = None,
    tools: list[Tool] | None = None,
    workflows: list[Workflow] | None = None,
    max_concurrent_workflows: int | None = None,
    mode: str = "push",  # "push" or "pull"
    worker_server_url: str | None = None,  # Required if mode="push"
):
    """
    Initialize worker.

    Args:
        client: PolosClient instance (required). The worker reuses its
            ``project_id``, ``api_url`` and ``api_key``.
        deployment_id: Deployment ID (unique identifier for the deployment).
            Falls back to the POLOS_DEPLOYMENT_ID env var when empty.
        agents: List of Agent instances to register (non-Agent items are ignored).
        tools: List of Tool instances to register (non-Tool items are ignored).
        workflows: List of Workflow instances to register (non-Workflow items
            are skipped with a warning).
        max_concurrent_workflows: Maximum number of workflows to execute in parallel.
            If not provided, reads from POLOS_MAX_CONCURRENT_WORKFLOWS env var
            (default: 100). Must be at least 1.
        mode: Worker mode - "push" (default) or "pull" (case-insensitive).
            Push mode uses FastAPI server to receive work. Pull mode is not
            supported yet and raises ValueError.
        worker_server_url: Full URL for worker server endpoint
            (e.g., "https://worker.example.com").
            If not provided and mode="push", it is read from the
            POLOS_WORKER_SERVER_URL env var or defaults to "http://localhost:8000"

    Raises:
        ValueError: If deployment_id is not provided, if api_key is missing
            outside local mode, if mode is invalid or "pull", if mode="push"
            but FastAPI is unavailable, or if max_concurrent_workflows < 1.
    """
    self.polos_client = client
    self.deployment_id = deployment_id or os.getenv("POLOS_DEPLOYMENT_ID")
    if not self.deployment_id:
        raise ValueError(
            "deployment_id is required for Worker initialization. "
            "Set it via parameter or POLOS_DEPLOYMENT_ID env var."
        )

    # Use client's configuration
    self.project_id = client.project_id
    self.api_url = client.api_url

    # Local mode (no API key needed) is only honored for localhost orchestrators
    local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
    is_localhost = is_localhost_url(self.api_url)
    self.local_mode = local_mode_requested and is_localhost

    if local_mode_requested and not is_localhost:
        logger.warning(
            "POLOS_LOCAL_MODE=True ignored because api_url (%s) is not localhost.",
            self.api_url,
        )

    self.api_key = client.api_key
    if not self.local_mode and not self.api_key:
        raise ValueError(
            "api_key is required for Worker initialization. "
            "Set it via PolosClient(api_key='...') or POLOS_API_KEY environment variable. "
            "Or set POLOS_LOCAL_MODE=True for local development "
            "(only works with localhost URLs)."
        )

    # Worker mode configuration
    self.mode = mode.lower()
    if self.mode not in ("push", "pull"):
        raise ValueError(f"mode must be 'push' or 'pull', got '{mode}'")

    if self.mode == "pull":
        raise ValueError("[Worker] Pull mode not supported yet. Use push mode instead.")

    if self.mode == "push":
        if not FASTAPI_AVAILABLE:
            raise ValueError(
                "FastAPI and uvicorn are required for push mode. "
                "Install with: pip install fastapi uvicorn"
            )

        # Push endpoint URL precedence: explicit argument > env var > localhost default
        self.worker_server_url = (
            worker_server_url
            or os.getenv("POLOS_WORKER_SERVER_URL")
            or "http://localhost:8000"
        )

        self.worker_server: WorkerServer | None = None
    else:
        self.worker_server_url = None
        self.worker_server = None

    # Concurrency limit precedence: explicit argument > env var > default of 100
    if max_concurrent_workflows is not None:
        # A limit of 0 would block the execution semaphore forever
        if max_concurrent_workflows < 1:
            raise ValueError(
                f"max_concurrent_workflows must be >= 1, got {max_concurrent_workflows}"
            )
        self.max_concurrent_workflows = max_concurrent_workflows
    else:
        self.max_concurrent_workflows = 100
        env_value = os.getenv("POLOS_MAX_CONCURRENT_WORKFLOWS")
        if env_value:
            try:
                parsed = int(env_value)
            except ValueError:
                parsed = None
            if parsed is not None and parsed >= 1:
                self.max_concurrent_workflows = parsed
            else:
                logger.warning(
                    "Invalid POLOS_MAX_CONCURRENT_WORKFLOWS value '%s', using default 100",
                    env_value,
                )

    self.execution_semaphore = asyncio.Semaphore(self.max_concurrent_workflows)
    self.active_executions: set = set()
    # Store tasks for each execution (for manual cancellation)
    self.execution_tasks: dict[str, asyncio.Task] = {}
    self.execution_tasks_lock = asyncio.Lock()

    # Build workflow registry: workflows, agents and tools keyed by their id
    self.workflows_registry: dict[str, Workflow] = {}
    # `agents` and `tools` default to None, so normalize before filtering
    self.agents: list[Agent] = [a for a in (agents or []) if isinstance(a, Agent)]
    self.tools: list[Tool] = [t for t in (tools or []) if isinstance(t, Tool)]
    self.agent_ids: list[str] = []
    self.tool_ids: list[str] = []
    self.workflow_ids: list[str] = []

    # Keep only real Workflow instances
    processed_workflows: list[Workflow] = []
    for workflow in workflows or []:
        if not isinstance(workflow, Workflow):
            logger.warning("Skipping non-Workflow object in workflows list: %s", workflow)
            continue

        processed_workflows.append(workflow)
        self.workflows_registry[workflow.id] = workflow
        self.workflow_ids.append(workflow.id)

    # Store processed workflows (all are Workflow instances now)
    self.workflows: list[Workflow] = processed_workflows

    # Register all agents and tools in local registry (already type-filtered above)
    for agent in self.agents:
        self.workflows_registry[agent.id] = agent
        self.agent_ids.append(agent.id)

    for tool in self.tools:
        self.workflows_registry[tool.id] = tool
        self.tool_ids.append(tool.id)

    # Worker state
    self.worker_id: str | None = None
    self.running = False
    self.poll_task: asyncio.Task | None = None
    self.heartbeat_task: asyncio.Task | None = None
    self.worker_server_task: asyncio.Task | None = None

    # Reusable HTTP client for polling operations
    self.client: httpx.AsyncClient | None = None
|
|
231
|
+
|
|
232
|
+
async def run(self):
    """Run the worker (blocks until shutdown).

    Startup sequence:
    1. Create the shared ``httpx.AsyncClient`` used for orchestrator calls.
    2. Register the worker, deployment, agents, tools, workflows and queues.
       If any of these steps raises, the HTTP client is closed and the
       error is re-raised.
    3. Mark the worker online (best effort), set ``running`` and publish this
       instance through ``set_current_worker``.
    4. Install SIGINT/SIGTERM handlers that schedule ``self.shutdown()``.
       Where the event loop cannot install them (e.g. on Windows or outside
       the main thread), a warning is logged and the worker keeps running.
    5. Start the mode-specific tasks (push: worker server + heartbeat,
       pull: poll loop + heartbeat), wait until all of them finish, and log
       any exception they ended with.
    """
    logger.info("Starting worker...")
    logger.info("Deployment ID: %s", self.deployment_id)
    logger.info("Orchestrator: %s", self.api_url)

    self.client = httpx.AsyncClient(timeout=httpx.Timeout(35.0, connect=5.0))

    try:
        # Register with orchestrator, then the deployment
        await self._register()
        await self._register_deployment()

        # Register agents, tools, and workflows
        await self._register_agents()
        await self._register_tools()
        await self._register_workflows()

        # Register/update queues used by workflows, agents and tools
        await self._register_queues()
    except Exception:
        # Startup failed: release the connection pool instead of leaking it
        await self.client.aclose()
        raise

    # Mark worker as online after all registrations are complete
    await self._mark_online()

    self.running = True

    # Register this worker instance so client.py can reuse its HTTP client
    # and so features can access the client
    set_current_worker(self)

    loop = asyncio.get_running_loop()
    # The event loop keeps only weak references to tasks, so hold the
    # shutdown tasks here to stop them from being garbage-collected mid-run
    shutdown_tasks: set[asyncio.Task] = set()

    def signal_handler(sig):
        """Handle shutdown signals by scheduling ``self.shutdown()``."""
        logger.info("Received signal %s, shutting down...", sig)
        task = loop.create_task(self.shutdown())
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler, sig)
        except (NotImplementedError, RuntimeError) as error:
            # Unsupported by this event loop (e.g. Windows) or not the main thread
            logger.warning("Could not install handler for signal %s: %s", sig, error)

    # Start tasks based on mode
    if self.mode == "push":
        # Initialize push server and start the FastAPI server that receives work
        self._setup_worker_server()
        self.worker_server_task = asyncio.create_task(self.worker_server.run())
        self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        tasks = [self.worker_server_task, self.heartbeat_task]
    else:
        # Start polling and heartbeat tasks for pull mode
        self.poll_task = asyncio.create_task(self._poll_loop())
        self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        tasks = [self.poll_task, self.heartbeat_task]

    # With return_exceptions=True failures come back as results rather than
    # being raised, so inspect them explicitly; cancellation is not an error
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, Exception):
            logger.error("Error in worker tasks: %s", result)
|
|
296
|
+
|
|
297
|
+
async def _register(self):
    """Register this worker with the orchestrator and remember its worker_id.

    Posts the deployment and project ids, the worker mode, its capabilities
    (runtime plus the ids of all agents, tools and workflows) and the
    concurrency limit to ``/api/v1/workers/register``. In push mode the
    worker server URL is sent as ``push_endpoint_url``. The ``worker_id``
    from the JSON response is stored on ``self``.

    Raises:
        Exception: Any request/response error is logged and re-raised.
    """
    try:
        request_headers = self._get_headers()

        capabilities = {
            "runtime": "python",
            "agent_ids": self.agent_ids,
            "tool_ids": self.tool_ids,
            "workflow_ids": self.workflow_ids,
        }
        payload = {
            "deployment_id": self.deployment_id,
            "project_id": self.project_id,
            "mode": self.mode,
            "capabilities": capabilities,
            "max_concurrent_executions": self.max_concurrent_workflows,
        }
        if self.mode == "push":
            # Tell the orchestrator where to deliver pushed work
            payload["push_endpoint_url"] = self.worker_server_url

        resp = await self.client.post(
            f"{self.api_url}/api/v1/workers/register",
            json=payload,
            headers=request_headers,
        )
        resp.raise_for_status()
        self.worker_id = resp.json()["worker_id"]

        logger.info("Registered: %s (mode: %s)", self.worker_id, self.mode)
    except Exception as error:
        logger.error("Registration failed: %s", error)
        raise
|
|
331
|
+
|
|
332
|
+
async def _mark_online(self):
    """Tell the orchestrator this worker is online once registrations are done.

    Best effort: any failure is logged as a warning and swallowed, so the
    worker keeps running (the heartbeat will update the status).
    """
    try:
        request_headers = self._get_headers()
        resp = await self.client.post(
            f"{self.api_url}/api/v1/workers/{self.worker_id}/online",
            headers=request_headers,
        )
        resp.raise_for_status()
    except Exception as error:
        logger.warning("Failed to mark worker as online: %s", error)
        # Don't raise - allow worker to continue, heartbeat will update status
    else:
        logger.info("Marked as online: %s", self.worker_id)
|
|
345
|
+
|
|
346
|
+
async def _re_register(self):
    """Run the full registration sequence again.

    Steps, in order: worker registration (which yields a new ``worker_id``),
    deployment, agents, tools, workflows, queues, then marking the worker
    online. In push mode the worker server is then given the new
    ``worker_id``. Errors are logged and not re-raised, so the heartbeat can
    continue and retry later.
    """
    try:
        registration_steps = (
            self._register,
            self._register_deployment,
            self._register_agents,
            self._register_tools,
            self._register_workflows,
            self._register_queues,
            self._mark_online,
        )
        for step in registration_steps:
            await step()

        # The push server must route work under the freshly assigned id
        if self.mode == "push" and self.worker_server:
            self.worker_server.update_worker_id(self.worker_id)
            logger.debug("Updated push server with new worker_id: %s", self.worker_id)

        logger.info("Re-registration complete: %s", self.worker_id)
    except Exception as error:
        logger.error("Re-registration failed: %s", error)
        # Don't raise - allow heartbeat to continue and retry later
|
|
375
|
+
|
|
376
|
+
async def _register_deployment(self):
    """Create or replace this worker's deployment in the orchestrator.

    Raises:
        Exception: Any request error is logged and re-raised.
    """
    try:
        request_headers = self._get_headers()
        resp = await self.client.post(
            f"{self.api_url}/api/v1/workers/deployments",
            json={"deployment_id": self.deployment_id},
            headers=request_headers,
        )
        resp.raise_for_status()
    except Exception as error:
        logger.error("Deployment registration failed: %s", error)
        raise
    else:
        logger.info("Deployment registered: %s", self.deployment_id)
|
|
393
|
+
|
|
394
|
+
async def _register_agents(self):
    """Register agent definitions and add each agent to deployment_workflows.

    For every agent, posts its provider/model configuration, tool
    definitions and metadata to ``/api/v1/agents/register``. It then records
    the agent in the deployment via ``_register_deployment_workflow``.

    Raises:
        Exception: The first failure is logged with the agent id and re-raised.
    """
    for agent in self.agents:
        try:
            headers = self._get_headers()
            tools_json = self._build_agent_tools_json(agent)
            metadata = self._build_agent_metadata(agent)

            # Register agent definition
            response = await self.client.post(
                f"{self.api_url}/api/v1/agents/register",
                json={
                    "id": agent.id,
                    "deployment_id": self.deployment_id,
                    "provider": agent.provider,
                    "model": agent.model,
                    "system_prompt": agent.system_prompt,
                    "tools": tools_json,
                    "temperature": agent.temperature,
                    "max_output_tokens": agent.max_output_tokens,
                    "metadata": metadata,
                },
                headers=headers,
            )
            response.raise_for_status()

            # Register in deployment_workflows
            await self._register_deployment_workflow(agent.id, "agent")

            logger.debug("Registered agent: %s", agent.id)
        except Exception as error:
            logger.error("Failed to register agent %s: %s", agent.id, error)
            raise

@staticmethod
def _build_agent_tools_json(agent) -> list[Any] | None:
    """Return the agent's tool definitions as JSON data, or None if there are none.

    ``Tool`` instances are converted with ``to_llm_tool_definition()``,
    plain dicts are passed through, and any other entry is ignored.
    """
    if not agent.tools:
        return None

    tools_list = []
    for tool in agent.tools:
        if isinstance(tool, Tool):
            tools_list.append(tool.to_llm_tool_definition())
        elif isinstance(tool, dict):
            tools_list.append(tool)

    if not tools_list:
        return None
    # JSON round-trip: fails fast on non-serializable values and sends a detached copy
    return json.loads(json.dumps(tools_list))

@staticmethod
def _build_agent_metadata(agent) -> dict[str, Any] | None:
    """Collect provider URL, stop-condition names and guardrail info for an agent.

    Returns None instead of an empty dict to avoid sending empty metadata.
    """
    metadata: dict[str, Any] = {}
    if agent.provider_base_url:
        metadata["provider_base_url"] = agent.provider_base_url

    if agent.stop_conditions:
        stop_condition_names = []
        for sc in agent.stop_conditions:
            # Prefer the name set by the stop-condition decorator, then the
            # callable's __name__, and as a last resort its string form
            if hasattr(sc, "__stop_condition_name__"):
                stop_condition_names.append(sc.__stop_condition_name__)
            elif hasattr(sc, "__name__"):
                stop_condition_names.append(sc.__name__)
            else:
                stop_condition_names.append(str(sc))

        if stop_condition_names:
            metadata["stop_conditions"] = stop_condition_names

    if agent.guardrails:
        guardrail_info = []
        for gr in agent.guardrails:
            if callable(gr):
                # Lambdas have no useful name, so fall back to their repr
                if hasattr(gr, "__name__") and gr.__name__ != "<lambda>":
                    guardrail_info.append({"type": "function", "name": gr.__name__})
                else:
                    guardrail_info.append({"type": "function", "name": str(gr)})
            elif isinstance(gr, str):
                # Include string guardrails (truncate if too long for readability)
                truncated = gr[:200] + "..." if len(gr) > 200 else gr
                guardrail_info.append({"type": "string", "content": truncated})

        if guardrail_info:
            metadata["guardrails"] = guardrail_info

    return metadata or None
|
|
483
|
+
|
|
484
|
+
async def _register_tools(self):
    """Register each tool definition and add it to deployment_workflows.

    The tool's type, description, parameters and metadata are posted to
    ``/api/v1/tools/register``. The tool is then recorded in the deployment
    with workflow type "tool".

    Raises:
        Exception: The first failure is logged with the tool id and re-raised.
    """
    for current_tool in self.tools:
        try:
            kind = current_tool.get_tool_type()
            tool_metadata = current_tool.get_tool_metadata()

            request_headers = self._get_headers()
            payload = {
                "id": current_tool.id,
                "deployment_id": self.deployment_id,
                "tool_type": kind,
                "description": current_tool._tool_description,
                "parameters": current_tool._tool_parameters,
                "metadata": tool_metadata,
            }
            resp = await self.client.post(
                f"{self.api_url}/api/v1/tools/register",
                json=payload,
                headers=request_headers,
            )
            resp.raise_for_status()

            await self._register_deployment_workflow(current_tool.id, "tool")

            logger.debug("Registered tool: %s (type: %s)", current_tool.id, kind)
        except Exception as error:
            logger.error("Failed to register tool %s: %s", current_tool.id, error)
            raise
|
|
514
|
+
|
|
515
|
+
async def _register_workflows(self):
    """Register workflows in deployment_workflows, plus event triggers and schedules.

    For each workflow:
    * record it in the deployment with ``trigger_on_event`` and ``scheduled``
      flags;
    * register an event trigger when ``trigger_on_event`` holds a topic;
    * register a schedule when ``schedule`` is a cron string or dict
      (a bare ``True``/``False`` only acts as a flag).

    Raises:
        Exception: The first failure is logged with the workflow id and re-raised.
    """
    for wf in self.workflows:
        try:
            # trigger_on_event is the topic string when set
            topic = getattr(wf, "trigger_on_event", None)
            event_triggered = topic is not None
            schedulable = getattr(wf, "is_schedulable", False)

            await self._register_deployment_workflow(
                wf.id, "workflow", event_triggered, schedulable
            )

            if event_triggered and topic:
                target_queue = wf.queue_name or wf.id
                await self._register_event_trigger(
                    wf.id,
                    topic,
                    getattr(wf, "batch_size", 1),
                    getattr(wf, "batch_timeout_seconds", None),
                    target_queue,
                )

            schedule_config = getattr(wf, "schedule", None)
            # Only a concrete cron string/dict gets a schedule; booleans are flags
            if schedule_config and not isinstance(schedule_config, bool):
                await self._register_schedule(wf)

            logger.debug("Registered workflow: %s", wf.id)
        except Exception as error:
            logger.error("Failed to register workflow %s: %s", wf.id, error)
            raise
|
|
555
|
+
|
|
556
|
+
async def _register_event_trigger(
    self,
    workflow_id: str,
    event_topic: str,
    batch_size: int,
    batch_timeout_seconds: int | None,
    queue_name: str,
):
    """Register an event trigger linking ``event_topic`` to ``workflow_id``.

    Args:
        workflow_id: Workflow the trigger belongs to.
        event_topic: Topic string taken from the workflow's ``trigger_on_event``.
        batch_size: Batch size forwarded to the orchestrator.
        batch_timeout_seconds: Batch timeout forwarded to the orchestrator, or None.
        queue_name: Queue name forwarded to the orchestrator.

    Raises:
        httpx.HTTPStatusError: Raised by ``raise_for_status`` for an error response.
    """
    payload = {
        "workflow_id": workflow_id,
        "deployment_id": self.deployment_id,
        "event_topic": event_topic,
        "batch_size": batch_size,
        "batch_timeout_seconds": batch_timeout_seconds,
        "queue_name": queue_name,
    }
    resp = await self.client.post(
        f"{self.api_url}/api/v1/event-triggers/register",
        json=payload,
        headers=self._get_headers(),
    )
    resp.raise_for_status()
|
|
579
|
+
|
|
580
|
+
async def _register_queues(self):
    """Register/update queues used by workflows, agents, and tools.

    Each entity runs on its ``queue_name`` when set, otherwise on a queue
    named after its id. When several entities share a queue, the more
    restrictive (smaller) of the configured ``queue_concurrency_limit``
    values wins. A None limit (not configured) never overrides a set one.
    Schedulable workflows are skipped here.

    All queues are sent in a single request. Failures are logged and not
    re-raised, so queue registration problems don't stop worker startup.
    """
    queues: dict[str, int | None] = {}

    def merge(queue_name: str, queue_limit: int | None) -> None:
        """Record a queue, keeping the smaller limit when both are set."""
        if queue_name not in queues:
            queues[queue_name] = queue_limit
        elif queue_limit is not None:
            current = queues[queue_name]
            queues[queue_name] = queue_limit if current is None else min(current, queue_limit)

    # Collect from workflows
    for workflow in self.workflows:
        # Skip scheduled workflows - they get their own queues registered separately
        if getattr(workflow, "is_schedulable", False):
            continue
        merge(
            getattr(workflow, "queue_name", None) or workflow.id,
            getattr(workflow, "queue_concurrency_limit", None),
        )

    # Collect from agents, then tools (same fallback: id as queue name)
    for entity in (*self.agents, *self.tools):
        merge(
            getattr(entity, "queue_name", None) or entity.id,
            getattr(entity, "queue_concurrency_limit", None),
        )

    if not queues:
        return

    # Batch register/update all queues
    try:
        headers = self._get_headers()
        queues_list = [
            {"name": name, "concurrency_limit": limit} for name, limit in queues.items()
        ]
        response = await self.client.post(
            f"{self.api_url}/api/v1/workers/queues",
            json={"deployment_id": self.deployment_id, "queues": queues_list},
            headers=headers,
        )
        response.raise_for_status()
        logger.info(
            "Registered/updated %d queue(s) for deployment %s",
            len(queues),
            self.deployment_id,
        )
    except Exception as error:
        logger.error("Failed to register queues: %s", error)
        # Don't raise - queue registration failure shouldn't stop worker startup
|
|
658
|
+
|
|
659
|
+
async def _register_deployment_workflow(
    self,
    workflow_id: str,
    workflow_type: str,
    trigger_on_event: bool = False,
    scheduled: bool = False,
):
    """Add a workflow, agent or tool to this deployment's deployment_workflows table.

    Args:
        workflow_id: Id of the workflow/agent/tool.
        workflow_type: Entry kind; callers in this module pass "workflow",
            "agent" or "tool".
        trigger_on_event: Whether the entry is event-triggered.
        scheduled: Whether the entry is schedulable.

    Raises:
        httpx.HTTPStatusError: Raised by ``raise_for_status`` for an error response.
    """
    endpoint = f"{self.api_url}/api/v1/workers/deployments/{self.deployment_id}/workflows"
    resp = await self.client.post(
        endpoint,
        json={
            "workflow_id": workflow_id,
            "workflow_type": workflow_type,
            "trigger_on_event": trigger_on_event,
            "scheduled": scheduled,
        },
        headers=self._get_headers(),
    )
    resp.raise_for_status()
|
|
680
|
+
|
|
681
|
+
async def _register_schedule(self, workflow: Workflow):
|
|
682
|
+
"""Register a schedule for a workflow."""
|
|
683
|
+
schedule_config = workflow.schedule
|
|
684
|
+
|
|
685
|
+
# Parse schedule configuration
|
|
686
|
+
if isinstance(schedule_config, str):
|
|
687
|
+
# Simple cron string - use UTC timezone and "global" key
|
|
688
|
+
cron = schedule_config
|
|
689
|
+
timezone = "UTC"
|
|
690
|
+
key = "global"
|
|
691
|
+
elif isinstance(schedule_config, dict):
|
|
692
|
+
# Dict with cron and optional timezone and key
|
|
693
|
+
cron = schedule_config.get("cron")
|
|
694
|
+
timezone = schedule_config.get("timezone", "UTC")
|
|
695
|
+
key = schedule_config.get("key", "global") # Default to "global" if not provided
|
|
696
|
+
if not cron:
|
|
697
|
+
raise ValueError("Schedule dict must contain 'cron' key")
|
|
698
|
+
else:
|
|
699
|
+
raise ValueError(f"Invalid schedule type: {type(schedule_config)}")
|
|
700
|
+
|
|
701
|
+
headers = self._get_headers()
|
|
702
|
+
request_body = {
|
|
703
|
+
"workflow_id": workflow.id,
|
|
704
|
+
"cron": cron,
|
|
705
|
+
"timezone": timezone,
|
|
706
|
+
"key": key,
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
response = await self.client.post(
|
|
710
|
+
f"{self.api_url}/api/v1/schedules",
|
|
711
|
+
json=request_body,
|
|
712
|
+
headers=headers,
|
|
713
|
+
)
|
|
714
|
+
response.raise_for_status()
|
|
715
|
+
logger.info(
|
|
716
|
+
"Registered schedule for workflow: %s (cron: %s, timezone: %s, key: %s)",
|
|
717
|
+
workflow.id,
|
|
718
|
+
cron,
|
|
719
|
+
timezone,
|
|
720
|
+
key,
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
async def _poll_loop(self):
    """Continuously poll the orchestrator for work and run it in the background.

    While ``self.running`` is true, this asks for up to the number of free
    execution slots and starts one background task per returned work item.
    Each task runs the item through ``_execute_workflow_with_semaphore``.

    Loop behavior:
      * Sleeps 1s while no ``worker_id`` is assigned.
      * Sleeps 0.1s while every slot is in use.
      * Sleeps 1s after a non-200 response or an unexpected error.
      * An ``httpx.TimeoutException`` (long-poll timeout) re-polls immediately.
      * Cancellation ends the loop.
    """
    import contextlib

    # Strong references to in-flight execution tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced create_task() result can be
    # garbage-collected before it finishes.
    in_flight = getattr(self, "_poll_in_flight_tasks", None)
    if in_flight is None:
        in_flight = set()
        self._poll_in_flight_tasks = in_flight

    async def execute_with_error_handling(exec_data):
        # Exceptions are already handled (and reported) in _execute_workflow.
        # Suppressing here only prevents "Task exception was never retrieved".
        with contextlib.suppress(Exception):
            await self._execute_workflow_with_semaphore(exec_data)

    while self.running:
        try:
            if not self.worker_id:
                await asyncio.sleep(1)
                continue

            # Tasks created by the previous poll may not have run yet, so they
            # are not in active_executions. Count them too; otherwise the next
            # poll would claim work for slots that are already spoken for.
            in_use = max(len(self.active_executions), len(in_flight))
            available_slots = self.max_concurrent_workflows - in_use
            if available_slots <= 0:
                # No available slots, wait a bit before polling again
                await asyncio.sleep(0.1)
                continue

            headers = self._get_headers()
            # Poll for multiple workflows (up to available slots)
            response = await self.client.get(
                f"{self.api_url}/api/v1/workers/{self.worker_id}/poll",
                params={"max_workflows": available_slots},
                headers=headers,
            )

            if response.status_code != 200:
                logger.debug(
                    "Poll returned HTTP %s, retrying shortly", response.status_code
                )
                await asyncio.sleep(1)
                continue

            workflows_data = response.json()

            if workflows_data:
                logger.debug(
                    "Received %d workflow(s) (requested %d, active: %d)",
                    len(workflows_data),
                    available_slots,
                    len(self.active_executions),
                )
                # Execute all workflows in background, keeping a reference to
                # each task until it completes.
                for workflow_data in workflows_data:
                    task = asyncio.create_task(execute_with_error_handling(workflow_data))
                    in_flight.add(task)
                    task.add_done_callback(in_flight.discard)
            # If no workflows, keep polling (long-poll timeout handled by httpx)

        except asyncio.CancelledError:
            break
        except httpx.TimeoutException:
            # Expected on long poll timeout
            pass
        except Exception as error:
            logger.error("Poll error: %s", error)
            await asyncio.sleep(1)  # Wait before retrying on error
|
|
781
|
+
|
|
782
|
+
async def _execute_workflow_with_semaphore(self, workflow_data: dict[str, Any]):
    """Run one execution while holding a slot of ``self.execution_semaphore``.

    Waits for a free slot, and for the duration of the run records the
    execution id in ``self.active_executions``. The id is always removed
    afterwards, even when the run raises or is cancelled.
    """
    exec_id = workflow_data["execution_id"]

    async with self.execution_semaphore:  # blocks while at max concurrency
        self.active_executions.add(exec_id)
        try:
            await self._execute_workflow(workflow_data)
        finally:
            self.active_executions.discard(exec_id)
|
|
793
|
+
|
|
794
|
+
async def _execute_workflow(self, workflow_data: dict[str, Any]):
    """Run one execution and report its outcome to the orchestrator.

    The workflow is looked up in ``self.workflows_registry`` first and then in
    the global ``_WORKFLOW_REGISTRY``. It runs as its own task, registered in
    ``self.execution_tasks`` so it can be cancelled, with an optional watchdog
    enforcing ``run_timeout_seconds``. On success the result is reported via
    ``_report_success``.

    Args:
        workflow_data: Work item from the orchestrator. "execution_id",
            "workflow_id" and "payload" are required; the other keys read here
            are optional.

    Raises:
        asyncio.CancelledError: if the run was cancelled (manually or by the
            timeout watchdog), after ``_handle_cancellation`` has run.
        Exception: any other error (including "workflow not found" and a
            non-JSON-serializable result) is re-raised after being reported
            via ``_report_failure``.
        A ``WaitException`` is not an error: it is logged and swallowed.
    """
    import contextlib

    execution_id = workflow_data["execution_id"]
    run_timeout_seconds = workflow_data.get("run_timeout_seconds")

    try:
        workflow_id = workflow_data["workflow_id"]
        # Prefer workflows explicitly registered on this worker, then fall back
        # to the global registry (system workflows and other workflows).
        workflow = self.workflows_registry.get(workflow_id)
        if not workflow:
            workflow = _WORKFLOW_REGISTRY.get(workflow_id)
        if not workflow:
            raise ValueError(
                f"Workflow {workflow_id} not found in registry or global workflow registry"
            )

        context = {
            "execution_id": execution_id,
            "deployment_id": workflow_data.get("deployment_id"),
            "parent_execution_id": workflow_data.get("parent_execution_id"),
            "root_execution_id": workflow_data.get("root_execution_id"),
            "retry_count": workflow_data.get("retry_count", 0),
            "session_id": workflow_data.get("session_id"),
            "user_id": workflow_data.get("user_id"),
            "otel_traceparent": workflow_data.get("otel_traceparent"),
            "otel_span_id": workflow_data.get("otel_span_id"),
            "initial_state": workflow_data.get("initial_state"),
            "run_timeout_seconds": run_timeout_seconds,
        }

        payload = workflow_data["payload"]

        # created_at is optional; an unparsable value is ignored (stays None).
        created_at = None
        created_at_str = workflow_data.get("created_at")
        if created_at_str:
            with contextlib.suppress(ValueError, AttributeError):
                # datetime.fromisoformat() before Python 3.11 rejects a trailing "Z".
                created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
        context["created_at"] = created_at

        # Run the workflow as its own task and register it so that cancel
        # requests and the timeout watchdog can find and cancel it.
        workflow_task = asyncio.create_task(workflow._execute(context, payload))
        async with self.execution_tasks_lock:
            self.execution_tasks[execution_id] = workflow_task

        timeout_task = None
        if run_timeout_seconds:

            async def check_timeout():
                """Cancel the execution if it is still running after the timeout."""
                try:
                    await asyncio.sleep(run_timeout_seconds)
                    async with self.execution_tasks_lock:
                        task = self.execution_tasks.get(execution_id)
                        if task and not task.done():
                            task.cancel()
                            logger.warning(
                                "Execution %s timed out after %d seconds",
                                execution_id,
                                run_timeout_seconds,
                            )
                except asyncio.CancelledError:
                    # The watchdog itself was cancelled (the run ended first).
                    pass

            timeout_task = asyncio.create_task(check_timeout())

        try:
            result, final_state = await workflow_task
        except asyncio.CancelledError:
            # Cancelled either manually or by the timeout watchdog.
            logger.info("Execution %s was cancelled", execution_id)
            await self._handle_cancellation(execution_id, workflow_id, context)
            raise
        finally:
            if timeout_task:
                timeout_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await timeout_task

            async with self.execution_tasks_lock:
                self.execution_tasks.pop(execution_id, None)

        # Prepare the result for reporting:
        # - Pydantic models are dumped with model_dump(mode="json"), and their
        #   fully qualified class name is sent so the model can be rebuilt.
        # - Anything else must already be JSON-serializable (checked below).
        prepared_result = result
        output_schema_name = None
        try:
            if isinstance(result, BaseModel):
                prepared_result = result.model_dump(mode="json")
                output_schema_name = (
                    f"{result.__class__.__module__}.{result.__class__.__name__}"
                )
            else:
                # json.dumps raises on non-serializable values
                json.dumps(result)
        except (TypeError, ValueError) as e:
            # Handled (and reported as a failure) by the outer except block
            raise TypeError(
                f"Workflow result is not JSON serializable: "
                f"{type(result).__name__}. Error: {str(e)}"
            ) from e

        await self._report_success(
            execution_id, prepared_result, output_schema_name, final_state
        )

    except WaitException as e:
        # WaitException is expected when a workflow waits for a sub-workflow.
        # The orchestrator resumes the execution later, so this is not a failure.
        logger.debug("Workflow paused for waiting: %s", e)
        return

    except Exception as error:
        error_message = str(error)
        stack_trace = traceback.format_exc()
        logger.error("Execution error: %s\nStack trace:\n%s", error_message, stack_trace)

        # Re-resolve the workflow: the error may have occurred before or during
        # the lookup above, so it can be unbound or None here. A missing
        # workflow must still be reported instead of crashing this handler.
        workflow_id = workflow_data.get("workflow_id")
        workflow = self.workflows_registry.get(workflow_id) or _WORKFLOW_REGISTRY.get(
            workflow_id
        )
        workflow_type = getattr(workflow, "workflow_type", None)

        # StepExecutionError is non-retryable. Tools are not retried either;
        # their error is fed back to the LLM to handle.
        retryable = not isinstance(error, StepExecutionError) and workflow_type != "tool"
        await self._report_failure(
            execution_id, error_message, stack_trace, retryable=retryable
        )
        raise
|
|
947
|
+
|
|
948
|
+
def _setup_worker_server(self):
    """Create the push-mode ``WorkerServer`` and store it on ``self.worker_server``.

    The server is wired to two callbacks. Pushed work runs through
    ``_execute_workflow_with_semaphore``, and cancel requests are answered by
    ``_handle_cancel_request``.

    Raises:
        RuntimeError: if FastAPI is unavailable (``FASTAPI_AVAILABLE`` is false).
    """
    if not FASTAPI_AVAILABLE:
        raise RuntimeError("FastAPI not available")

    worker = self

    async def on_work_received(workflow_data: dict[str, Any]):
        """Run a work item pushed by the orchestrator under the concurrency limit."""
        await worker._execute_workflow_with_semaphore(workflow_data)

    async def on_cancel_requested(execution_id: str) -> bool:
        """Forward a cancel request to the worker.

        Returns:
            True if the execution was found and cancelled, False if it was not
            found or had already completed.
        """
        return await worker._handle_cancel_request(execution_id)

    server_options = dict(
        worker_id=self.worker_id,
        max_concurrent_workflows=self.max_concurrent_workflows,
        on_work_received=on_work_received,
        on_cancel_requested=on_cancel_requested,
        local_mode=self.local_mode,
    )
    self.worker_server = WorkerServer(**server_options)

    logger.info("Worker server initialized")
|
|
977
|
+
|
|
978
|
+
async def _report_success(
|
|
979
|
+
self,
|
|
980
|
+
execution_id: str,
|
|
981
|
+
result: Any,
|
|
982
|
+
output_schema_name: str | None = None,
|
|
983
|
+
final_state: dict[str, Any] | None = None,
|
|
984
|
+
):
|
|
985
|
+
"""Report successful workflow execution with retries and exponential backoff."""
|
|
986
|
+
max_retries = 5
|
|
987
|
+
base_delay = 1.0 # Start with 1 second
|
|
988
|
+
|
|
989
|
+
for attempt in range(max_retries):
|
|
990
|
+
try:
|
|
991
|
+
headers = self._get_headers()
|
|
992
|
+
payload = {
|
|
993
|
+
"result": result,
|
|
994
|
+
"worker_id": self.worker_id, # Include worker_id for validation
|
|
995
|
+
}
|
|
996
|
+
if output_schema_name:
|
|
997
|
+
payload["output_schema_name"] = output_schema_name
|
|
998
|
+
if final_state is not None:
|
|
999
|
+
payload["final_state"] = final_state
|
|
1000
|
+
response = await self.client.post(
|
|
1001
|
+
f"{self.api_url}/internal/executions/{execution_id}/complete",
|
|
1002
|
+
json=payload,
|
|
1003
|
+
headers=headers,
|
|
1004
|
+
)
|
|
1005
|
+
|
|
1006
|
+
# Handle 409 Conflict (execution reassigned)
|
|
1007
|
+
if response.status_code == 409:
|
|
1008
|
+
logger.debug(
|
|
1009
|
+
"Execution %s was reassigned [old worker %s], ignoring completion",
|
|
1010
|
+
execution_id,
|
|
1011
|
+
self.worker_id,
|
|
1012
|
+
)
|
|
1013
|
+
return # Don't retry on 409
|
|
1014
|
+
|
|
1015
|
+
response.raise_for_status()
|
|
1016
|
+
# Success - return immediately
|
|
1017
|
+
return
|
|
1018
|
+
except httpx.HTTPStatusError as e:
|
|
1019
|
+
if e.response.status_code == 409:
|
|
1020
|
+
# 409 Conflict - execution reassigned, don't retry
|
|
1021
|
+
logger.debug("Execution %s was reassigned, ignoring completion", execution_id)
|
|
1022
|
+
return
|
|
1023
|
+
if attempt < max_retries - 1:
|
|
1024
|
+
# Calculate exponential backoff: 1s, 2s, 4s, 8s, 16s
|
|
1025
|
+
delay = base_delay * (2**attempt)
|
|
1026
|
+
logger.warning(
|
|
1027
|
+
"Failed to report success (attempt %d/%d): %s. Retrying in %ds...",
|
|
1028
|
+
attempt + 1,
|
|
1029
|
+
max_retries,
|
|
1030
|
+
e,
|
|
1031
|
+
delay,
|
|
1032
|
+
)
|
|
1033
|
+
await asyncio.sleep(delay)
|
|
1034
|
+
else:
|
|
1035
|
+
# Final attempt failed - report as failure with error message
|
|
1036
|
+
error_msg = f"Failed to report success after {max_retries} attempts: {e}"
|
|
1037
|
+
logger.error("%s", error_msg)
|
|
1038
|
+
# Don't call _report_failure here to avoid infinite loop
|
|
1039
|
+
except Exception as error:
|
|
1040
|
+
if attempt < max_retries - 1:
|
|
1041
|
+
# Calculate exponential backoff: 1s, 2s, 4s, 8s, 16s
|
|
1042
|
+
delay = base_delay * (2**attempt)
|
|
1043
|
+
logger.warning(
|
|
1044
|
+
"Failed to report success (attempt %d/%d): %s. Retrying in %ds...",
|
|
1045
|
+
attempt + 1,
|
|
1046
|
+
max_retries,
|
|
1047
|
+
error,
|
|
1048
|
+
delay,
|
|
1049
|
+
)
|
|
1050
|
+
await asyncio.sleep(delay)
|
|
1051
|
+
else:
|
|
1052
|
+
# Final attempt failed - report as failure with error message
|
|
1053
|
+
error_msg = f"Failed to report success after {max_retries} attempts: {error}"
|
|
1054
|
+
logger.error("%s", error_msg)
|
|
1055
|
+
await self._report_failure(execution_id, error_msg, retryable=True)
|
|
1056
|
+
|
|
1057
|
+
async def _report_failure(
|
|
1058
|
+
self,
|
|
1059
|
+
execution_id: str,
|
|
1060
|
+
error_message: str,
|
|
1061
|
+
stack_trace: str | None = None,
|
|
1062
|
+
retryable: bool = True,
|
|
1063
|
+
final_state: dict[str, Any] | None = None,
|
|
1064
|
+
):
|
|
1065
|
+
"""Report failed workflow execution with retries and exponential backoff.
|
|
1066
|
+
|
|
1067
|
+
Args:
|
|
1068
|
+
execution_id: Execution ID to report failure for
|
|
1069
|
+
error_message: Error message
|
|
1070
|
+
stack_trace: Optional stack trace
|
|
1071
|
+
retryable: Whether the execution should be retried (default: True)
|
|
1072
|
+
"""
|
|
1073
|
+
max_retries = 5
|
|
1074
|
+
base_delay = 1.0 # Start with 1 second
|
|
1075
|
+
|
|
1076
|
+
for attempt in range(max_retries):
|
|
1077
|
+
try:
|
|
1078
|
+
headers = self._get_headers()
|
|
1079
|
+
payload = {
|
|
1080
|
+
"error": error_message,
|
|
1081
|
+
"worker_id": self.worker_id, # Include worker_id for validation
|
|
1082
|
+
}
|
|
1083
|
+
if stack_trace:
|
|
1084
|
+
payload["stack"] = stack_trace
|
|
1085
|
+
if not retryable:
|
|
1086
|
+
payload["retryable"] = False
|
|
1087
|
+
if final_state is not None:
|
|
1088
|
+
payload["final_state"] = final_state
|
|
1089
|
+
response = await self.client.post(
|
|
1090
|
+
f"{self.api_url}/internal/executions/{execution_id}/fail",
|
|
1091
|
+
json=payload,
|
|
1092
|
+
headers=headers,
|
|
1093
|
+
)
|
|
1094
|
+
|
|
1095
|
+
# Handle 409 Conflict (execution reassigned)
|
|
1096
|
+
if response.status_code == 409:
|
|
1097
|
+
logger.debug("Execution %s was reassigned, ignoring failure", execution_id)
|
|
1098
|
+
return # Don't retry on 409
|
|
1099
|
+
|
|
1100
|
+
response.raise_for_status()
|
|
1101
|
+
# Success - return immediately
|
|
1102
|
+
return
|
|
1103
|
+
except httpx.HTTPStatusError as e:
|
|
1104
|
+
if e.response.status_code == 409:
|
|
1105
|
+
# 409 Conflict - execution reassigned, don't retry
|
|
1106
|
+
logger.debug("Execution %s was reassigned, ignoring failure", execution_id)
|
|
1107
|
+
return
|
|
1108
|
+
if attempt < max_retries - 1:
|
|
1109
|
+
# Calculate exponential backoff: 1s, 2s, 4s, 8s, 16s
|
|
1110
|
+
delay = base_delay * (2**attempt)
|
|
1111
|
+
logger.warning(
|
|
1112
|
+
"Failed to report failure (attempt %d/%d): %s. Retrying in %ds...",
|
|
1113
|
+
attempt + 1,
|
|
1114
|
+
max_retries,
|
|
1115
|
+
e,
|
|
1116
|
+
delay,
|
|
1117
|
+
)
|
|
1118
|
+
await asyncio.sleep(delay)
|
|
1119
|
+
else:
|
|
1120
|
+
# Final attempt failed
|
|
1121
|
+
logger.error("Failed to report failure after %d attempts: %s", max_retries, e)
|
|
1122
|
+
except Exception as error:
|
|
1123
|
+
if attempt < max_retries - 1:
|
|
1124
|
+
# Calculate exponential backoff: 1s, 2s, 4s, 8s, 16s
|
|
1125
|
+
delay = base_delay * (2**attempt)
|
|
1126
|
+
logger.warning(
|
|
1127
|
+
"Failed to report failure (attempt %d/%d): %s. Retrying in %ds...",
|
|
1128
|
+
attempt + 1,
|
|
1129
|
+
max_retries,
|
|
1130
|
+
error,
|
|
1131
|
+
delay,
|
|
1132
|
+
)
|
|
1133
|
+
await asyncio.sleep(delay)
|
|
1134
|
+
else:
|
|
1135
|
+
# Final attempt failed
|
|
1136
|
+
logger.error(
|
|
1137
|
+
"Failed to report failure after %d attempts: %s", max_retries, error
|
|
1138
|
+
)
|
|
1139
|
+
|
|
1140
|
+
async def _handle_cancel_request(self, execution_id: str) -> bool:
    """Cancel the running task for ``execution_id``, if there is one.

    The lookup and ``cancel()`` happen while holding ``self.execution_tasks_lock``.
    Cancelling only requests cancellation; the task handles it asynchronously.

    Returns:
        True if a still-running task was found and cancelled, False if no task
        is registered for the id or it has already completed.
    """
    logger.info("Handling cancellation request for execution %s", execution_id)

    async with self.execution_tasks_lock:
        running_task = self.execution_tasks.get(execution_id)
        if not running_task or running_task.done():
            logger.debug("Execution %s not found or already completed", execution_id)
            return False

        logger.debug("Cancelling task %s", running_task)
        running_task.cancel()
        logger.info("Cancellation requested for execution %s", execution_id)
        return True
|
|
1158
|
+
|
|
1159
|
+
async def _handle_cancellation(
    self, execution_id: str, workflow_id: str, context: dict[str, Any]
):
    """Confirm a cancellation to the orchestrator, then emit the cancel event.

    This is best effort. Any exception from either step is logged and not
    re-raised. If the confirmation raises, the event is not emitted.
    """
    try:
        logger.info("Sending cancellation confirmation for execution %s", execution_id)
        await self._send_cancel_confirmation(execution_id)
        await self._emit_cancellation_event(execution_id, workflow_id, context)
    except Exception as exc:
        logger.error("Error handling cancellation for %s: %s", execution_id, exc)
|
|
1172
|
+
|
|
1173
|
+
async def _send_cancel_confirmation(self, execution_id: str):
    """Tell the orchestrator that ``execution_id`` has been cancelled on this worker.

    POSTs ``{"worker_id": ...}`` to
    ``/internal/executions/{execution_id}/confirm-cancellation``. Up to 5 attempts
    are made, sleeping 1s, 2s, 4s and 8s between them. A 409 (execution
    reassigned) stops immediately. If every attempt fails, the last error is
    only logged.
    """
    attempts = 5
    first_delay = 1.0  # seconds; doubled after each failed attempt

    for attempt in range(attempts):
        try:
            headers = self._get_headers()
            body = {"worker_id": self.worker_id}
            response = await self.client.post(
                f"{self.api_url}/internal/executions/{execution_id}/confirm-cancellation",
                json=body,
                headers=headers,
            )
            if response.status_code == 409:
                logger.debug(
                    "Execution %s was reassigned, ignoring cancellation confirmation",
                    execution_id,
                )
                return

            response.raise_for_status()
            logger.info("Sent cancellation confirmation for execution %s", execution_id)
            return
        except httpx.HTTPStatusError as http_error:
            if http_error.response.status_code == 409:
                logger.debug(
                    "Execution %s was reassigned, ignoring cancellation confirmation",
                    execution_id,
                )
                return
            last_error = http_error
        except Exception as other_error:
            last_error = other_error

        if attempt == attempts - 1:
            logger.error(
                "Failed to send cancellation confirmation after %d attempts: %s",
                attempts,
                last_error,
            )
            break

        backoff = first_delay * 2**attempt
        logger.warning(
            "Failed to send cancellation confirmation (attempt %d/%d): %s. Retrying in %ds...",
            attempt + 1,
            attempts,
            last_error,
            backoff,
        )
        await asyncio.sleep(backoff)
|
|
1243
|
+
|
|
1244
|
+
async def _emit_cancellation_event(
    self, execution_id: str, workflow_id: str, context: dict[str, Any]
):
    """Publish a cancellation event for ``execution_id`` (best effort).

    The event is published on topic ``workflow:<root_execution_id>``, falling
    back to ``execution_id`` when no root id is in ``context``. The event type is
    ``<workflow_type>_cancel``, where ``workflow_type`` is read from ``context``
    and defaults to "workflow".

    NOTE(review): the context dict built by ``_execute_workflow`` in this file
    does not visibly set "workflow_type", so this presumably resolves to
    "workflow_cancel" there — confirm this is intended.

    Any error, including failure to import ``publish``, is logged and swallowed.
    """
    try:
        from ..features.events import publish

        root_execution_id = context.get("root_execution_id")
        topic = f"workflow:{root_execution_id or execution_id}"
        event_type = f"{context.get('workflow_type', 'workflow')}_cancel"
        event_data = {
            "_metadata": {
                "execution_id": execution_id,
                "workflow_id": workflow_id,
            }
        }

        await publish(
            self.polos_client,
            topic=topic,
            event_type=event_type,
            data=event_data,
            execution_id=execution_id,
            root_execution_id=root_execution_id,
        )
        logger.debug("Emitted cancellation event for execution %s", execution_id)
    except Exception as exc:
        logger.error("Failed to emit cancellation event for %s: %s", execution_id, exc)
|
|
1278
|
+
|
|
1279
|
+
async def _heartbeat_loop(self):
    """Send a heartbeat every 30 seconds while ``self.running`` is true.

    No heartbeat is sent while no ``worker_id`` is assigned. If the response
    JSON has a truthy "re_register" entry, ``_re_register()`` is awaited.
    Cancellation ends the loop; any other error is logged as a warning and the
    loop continues.
    """
    heartbeat_interval = 30  # seconds

    while self.running:
        try:
            await asyncio.sleep(heartbeat_interval)
            if not self.worker_id:
                continue

            headers = self._get_headers()
            response = await self.client.post(
                f"{self.api_url}/api/v1/workers/{self.worker_id}/heartbeat",
                headers=headers,
            )
            response.raise_for_status()

            reply = response.json()
            if not reply.get("re_register", False):
                continue
            logger.info("Orchestrator requested re-registration, re-registering...")
            await self._re_register()
        except asyncio.CancelledError:
            break
        except Exception as exc:
            logger.warning("Heartbeat failed: %s", exc)
|
|
1304
|
+
|
|
1305
|
+
async def shutdown(self):
    """Stop the worker gracefully.

    Does nothing if the worker is not running. Otherwise it:
      1. sets ``self.running`` to False;
      2. shuts down the worker server (push mode only);
      3. cancels, then awaits, the mode's main task (worker server task in
         push mode, poll task otherwise) and the heartbeat task;
      4. closes the HTTP client;
      5. unregisters the worker via ``set_current_worker(None)``.

    Errors in steps 2-4 are logged and do not stop the remaining steps.
    Running executions in ``self.execution_tasks`` are not touched here.
    """
    if not self.running:
        return

    logger.info("Shutting down gracefully...")
    self.running = False

    # Shutdown worker server first (if in push mode)
    if self.mode == "push" and self.worker_server:
        try:
            await self.worker_server.shutdown()
        except Exception as e:
            logger.error("Error shutting down worker server: %s", e)

    main_task = self.worker_server_task if self.mode == "push" else self.poll_task
    background_tasks = [task for task in (main_task, self.heartbeat_task) if task]

    for task in background_tasks:
        task.cancel()

    # Wait for cancellation. A task that had already died with an error
    # re-raises it when awaited. Log it instead of letting it abort shutdown
    # before the HTTP client is closed and the worker is unregistered.
    for task in background_tasks:
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error("Background task ended with an error during shutdown: %s", e)

    # Close HTTP client
    if hasattr(self, "client") and self.client:
        try:
            await self.client.aclose()
        except Exception as e:
            logger.error("Error closing HTTP client: %s", e)

    # Unregister this worker instance
    set_current_worker(None)

    logger.info("Shutdown complete")
|
|
1362
|
+
|
|
1363
|
+
def _get_headers(self) -> dict[str, str]:
    """Return the HTTP headers (API key, project id) for orchestrator requests.

    Delegates to ``self.polos_client._get_headers()``.
    """
    client = self.polos_client
    return client._get_headers()
|