polos-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. polos/__init__.py +105 -0
  2. polos/agents/__init__.py +7 -0
  3. polos/agents/agent.py +746 -0
  4. polos/agents/conversation_history.py +121 -0
  5. polos/agents/stop_conditions.py +280 -0
  6. polos/agents/stream.py +635 -0
  7. polos/core/__init__.py +0 -0
  8. polos/core/context.py +143 -0
  9. polos/core/state.py +26 -0
  10. polos/core/step.py +1380 -0
  11. polos/core/workflow.py +1192 -0
  12. polos/features/__init__.py +0 -0
  13. polos/features/events.py +456 -0
  14. polos/features/schedules.py +110 -0
  15. polos/features/tracing.py +605 -0
  16. polos/features/wait.py +82 -0
  17. polos/llm/__init__.py +9 -0
  18. polos/llm/generate.py +152 -0
  19. polos/llm/providers/__init__.py +5 -0
  20. polos/llm/providers/anthropic.py +615 -0
  21. polos/llm/providers/azure.py +42 -0
  22. polos/llm/providers/base.py +196 -0
  23. polos/llm/providers/fireworks.py +41 -0
  24. polos/llm/providers/gemini.py +40 -0
  25. polos/llm/providers/groq.py +40 -0
  26. polos/llm/providers/openai.py +1021 -0
  27. polos/llm/providers/together.py +40 -0
  28. polos/llm/stream.py +183 -0
  29. polos/middleware/__init__.py +0 -0
  30. polos/middleware/guardrail.py +148 -0
  31. polos/middleware/guardrail_executor.py +253 -0
  32. polos/middleware/hook.py +164 -0
  33. polos/middleware/hook_executor.py +104 -0
  34. polos/runtime/__init__.py +0 -0
  35. polos/runtime/batch.py +87 -0
  36. polos/runtime/client.py +841 -0
  37. polos/runtime/queue.py +42 -0
  38. polos/runtime/worker.py +1365 -0
  39. polos/runtime/worker_server.py +249 -0
  40. polos/tools/__init__.py +0 -0
  41. polos/tools/tool.py +587 -0
  42. polos/types/__init__.py +23 -0
  43. polos/types/types.py +116 -0
  44. polos/utils/__init__.py +27 -0
  45. polos/utils/agent.py +27 -0
  46. polos/utils/client_context.py +41 -0
  47. polos/utils/config.py +12 -0
  48. polos/utils/output_schema.py +311 -0
  49. polos/utils/retry.py +47 -0
  50. polos/utils/serializer.py +167 -0
  51. polos/utils/tracing.py +27 -0
  52. polos/utils/worker_singleton.py +40 -0
  53. polos_sdk-0.1.0.dist-info/METADATA +650 -0
  54. polos_sdk-0.1.0.dist-info/RECORD +55 -0
  55. polos_sdk-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,841 @@
+ import json
+ import logging
+ import os
+ from typing import Any
+
+ import httpx
+ from dotenv import load_dotenv
+ from pydantic import BaseModel
+
+ from ..types.types import AgentResult, BatchWorkflowInput
+ from ..utils.client_context import get_client_or_raise
+ from ..utils.config import is_localhost_url
+ from ..utils.worker_singleton import get_worker_client
+
+ logger = logging.getLogger(__name__)
+
+ load_dotenv()
+
+
+ def _validate_state_size(state: dict[str, Any] | BaseModel, max_size_mb: float = 1.0) -> None:
+     """Validate that state JSON size doesn't exceed limit.
+
+     Args:
+         state: State dictionary to validate
+         max_size_mb: Maximum size in MB (default: 1.0)
+
+     Raises:
+         ValueError: If state size exceeds limit
+     """
+
+     state_json = state.model_dump_json() if isinstance(state, BaseModel) else json.dumps(state)
+     size_bytes = len(state_json.encode("utf-8"))
+     max_bytes = int(max_size_mb * 1024 * 1024)
+
+     if size_bytes > max_bytes:
+         size_mb = size_bytes / (1024 * 1024)
+         raise ValueError(
+             f"Workflow state size ({size_mb:.2f}MB) exceeds maximum allowed size "
+             f"({max_size_mb}MB). Consider reducing state size or using external storage."
+         )
+
+
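For context on the 1 MB ceiling above, a minimal sketch of the check in isolation. The helper is private, so calling it directly is purely illustrative; the values are placeholders.

# Illustrative use of the private size check; not part of the package diff.
from polos.runtime.client import _validate_state_size

_validate_state_size({"counter": 1})  # small state passes silently

try:
    _validate_state_size({"blob": "x" * (2 * 1024 * 1024)})  # ~2 MB once JSON-encoded
except ValueError as exc:
    print(exc)  # reports the offending size and the 1.0 MB limit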
+ class PolosClient:
+     """Client for interacting with the Polos orchestrator API."""
+
+     def __init__(
+         self,
+         api_url: str | None = None,
+         api_key: str | None = None,
+         project_id: str | None = None,
+     ):
+         """Initialize Polos client.
+
+         Args:
+             api_url: Orchestrator API URL (default: from POLOS_API_URL env var or http://localhost:8080)
+             api_key: API key for authentication (default: from POLOS_API_KEY env var)
+             project_id: Project ID for multi-tenancy (default: from POLOS_PROJECT_ID env var)
+         """
+         self.api_url = api_url or os.getenv("POLOS_API_URL", "http://localhost:8080")
+         self.api_key = api_key or os.getenv("POLOS_API_KEY")
+         self.project_id = project_id or os.getenv("POLOS_PROJECT_ID")
+
+         # Validate required fields (with local mode support)
+         local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
+         is_localhost = is_localhost_url(self.api_url)
+         local_mode = local_mode_requested and is_localhost
+
+         if local_mode_requested and not is_localhost:
+             logger.warning(
+                 f"POLOS_LOCAL_MODE=True ignored because api_url ({self.api_url}) "
+                 "is not localhost. Falling back to normal authentication."
+             )
+
+         if not local_mode and not self.api_key:
+             raise ValueError(
+                 "api_key is required. Set it via PolosClient(api_key='...') "
+                 "or POLOS_API_KEY environment variable. Or set "
+                 "POLOS_LOCAL_MODE=True for local development "
+                 "(only works with localhost URLs)."
+             )
+
+         if not self.project_id:
+             raise ValueError(
+                 "project_id is required. Set it via PolosClient(project_id='...') "
+                 "or POLOS_PROJECT_ID environment variable."
+             )
+
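A minimal construction sketch, assuming the class is imported directly from polos.runtime.client; the URL, key, and project ID are placeholders (omit the arguments to fall back to the POLOS_API_URL / POLOS_API_KEY / POLOS_PROJECT_ID environment variables).

# Illustrative only; not part of the package diff.
from polos.runtime.client import PolosClient

client = PolosClient(
    api_url="https://orchestrator.example.com",  # placeholder URL
    api_key="pk-example",    # required unless POLOS_LOCAL_MODE=True on a localhost URL
    project_id="proj_123",   # always required
)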
+     def _get_headers(self) -> dict[str, str]:
+         """Get headers for API requests, including project_id and API key.
+
+         The API key is required for all orchestrator API calls, unless POLOS_LOCAL_MODE=True.
+         Local mode is only enabled when api_url is localhost.
+         """
+         headers = {"Content-Type": "application/json"}
+
+         # Check for local mode (only enabled for localhost URLs)
+         local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
+         is_localhost = is_localhost_url(self.api_url)
+         local_mode = local_mode_requested and is_localhost
+
+         # API key is required for all orchestrator API calls unless in local mode
+         if not local_mode:
+             if not self.api_key:
+                 raise ValueError(
+                     "api_key is required. Set it via PolosClient(api_key='...') "
+                     "or POLOS_API_KEY environment variable. Or set "
+                     "POLOS_LOCAL_MODE=True for local development "
+                     "(only works with localhost URLs)."
+                 )
+             headers["Authorization"] = f"Bearer {self.api_key}"
+
+         # Add project_id header (required for multi-tenancy)
+         if not self.project_id:
+             raise ValueError(
+                 "project_id is required. Set it via PolosClient(project_id='...') "
+                 "or POLOS_PROJECT_ID environment variable."
+             )
+         headers["X-Project-ID"] = self.project_id
+
+         return headers
+
+     async def _get_http_client(self, timeout: httpx.Timeout | None = None) -> httpx.AsyncClient:
+         """Get an HTTP client, reusing the worker's client if available.
+
+         Args:
+             timeout: Optional timeout for the client (only used if creating new client)
+
+         Returns:
+             An httpx.AsyncClient instance (either from worker or newly created)
+
+         Note:
+             If using worker's client, it's not closed after use (worker manages its lifecycle).
+             If creating a new client, it should be used with `async with` context manager.
+         """
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             return worker_client
+         else:
+             if timeout is not None:
+                 return httpx.AsyncClient(timeout=timeout)
+             else:
+                 return httpx.AsyncClient()
+
+     async def invoke(
+         self,
+         workflow_id: str,
+         payload: Any = None,
+         queue_name: str | None = None,
+         queue_concurrency_limit: int | None = None,
+         concurrency_key: str | None = None,
+         session_id: str | None = None,
+         user_id: str | None = None,
+         initial_state: dict[str, Any] | None = None,
+         run_timeout_seconds: int | None = None,
+     ) -> "ExecutionHandle":
+         """Invoke a workflow and return an execution handle.
+
+         Always targets the latest active deployment of the workflow.
+
+         Args:
+             workflow_id: The workflow identifier
+             payload: The workflow payload
+             queue_name: Optional queue name (if not provided, defaults to workflow_id)
+             queue_concurrency_limit: Optional concurrency limit for queue creation
+             concurrency_key: Optional concurrency key for per-tenant queuing
+             session_id: Optional session ID
+             user_id: Optional user ID
+             initial_state: Optional initial state dictionary (must be JSON-serializable, max 1MB)
+             run_timeout_seconds: Optional timeout in seconds
+
+         Returns:
+             ExecutionHandle for the submitted workflow
+         """
+         return await self._submit_workflow(
+             workflow_id=workflow_id,
+             deployment_id=None,  # Use latest deployment
+             payload=payload,
+             queue_name=queue_name,
+             queue_concurrency_limit=queue_concurrency_limit,
+             concurrency_key=concurrency_key,
+             session_id=session_id,
+             user_id=user_id,
+             initial_state=initial_state,
+             run_timeout_seconds=run_timeout_seconds,
+         )
+
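A usage sketch for invoke(); the workflow ID and payload are hypothetical, and the workflow is assumed to already be deployed to the orchestrator.

# Illustrative only; not part of the package diff.
import asyncio

from polos.runtime.client import PolosClient


async def main() -> None:
    client = PolosClient(api_key="pk-example", project_id="proj_123")
    handle = await client.invoke(
        "order_processing",              # hypothetical workflow ID
        payload={"order_id": "ord_42"},
        initial_state={"attempts": 0},   # must stay under the 1 MB state limit
        run_timeout_seconds=600,
    )
    print(handle.id, handle.workflow_id)


asyncio.run(main())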
+     async def batch_invoke(
+         self,
+         workflows: list[BatchWorkflowInput],
+         session_id: str | None = None,
+         user_id: str | None = None,
+     ) -> list["ExecutionHandle"]:
+         """Invoke multiple different workflows in a single batch and return handles immediately.
+
+         Args:
+             workflows: List of BatchWorkflowInput objects with 'id' (workflow_id string)
+                 and 'payload' (dict or Pydantic model)
+             session_id: Optional session ID
+             user_id: Optional user ID
+
+         Returns:
+             List of ExecutionHandle objects for the submitted workflows
+         """
+         from ..core.workflow import get_workflow
+         from ..utils.serializer import serialize
+
+         if not workflows:
+             return []
+
+         # Build workflow requests for batch submission
+         workflow_requests = []
+         for workflow_input in workflows:
+             workflow_id = workflow_input.id
+             payload = serialize(workflow_input.payload)
+
+             workflow_obj = get_workflow(workflow_id)
+             if not workflow_obj:
+                 raise ValueError(f"Workflow '{workflow_id}' not found")
+
+             workflow_req = {
+                 "workflow_id": workflow_id,
+                 "payload": payload,
+                 "initial_state": serialize(workflow_input.initial_state),
+                 "run_timeout_seconds": workflow_input.run_timeout_seconds,
+             }
+
+             # Per-workflow properties (queue_name, concurrency_key, etc.)
+             if workflow_obj.queue_name is not None:
+                 workflow_req["queue_name"] = workflow_obj.queue_name
+
+             if workflow_obj.queue_concurrency_limit is not None:
+                 workflow_req["queue_concurrency_limit"] = workflow_obj.queue_concurrency_limit
+
+             workflow_requests.append(workflow_req)
+
+         # Submit all workflows in a single batch using the batch endpoint
+         handles = await self._submit_workflows(
+             workflows=workflow_requests,
+             deployment_id=None,  # Use latest active deployment
+             parent_execution_id=None,
+             root_execution_id=None,
+             step_key=None,  # Not invoked from a step, so no step_key
+             session_id=session_id,
+             user_id=user_id,
+             wait_for_subworkflow=False,  # Fire-and-forget
+         )
+
+         return handles
+
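A sketch of a batch submission. It assumes the target workflow is registered in this process (batch_invoke resolves each ID via get_workflow and raises otherwise) and that BatchWorkflowInput accepts these fields as keyword arguments.

# Illustrative only; not part of the package diff.
import asyncio

from polos.runtime.client import PolosClient
from polos.types.types import BatchWorkflowInput


async def main() -> None:
    client = PolosClient(api_key="pk-example", project_id="proj_123")
    handles = await client.batch_invoke(
        [
            BatchWorkflowInput(id="send_email", payload={"to": "a@example.com"}),
            BatchWorkflowInput(id="send_email", payload={"to": "b@example.com"}),
        ],
        session_id="sess_1",
    )
    print([h.id for h in handles])


asyncio.run(main())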
+     async def resume(self, suspend_execution_id: str, suspend_step_key: str, data: Any) -> None:
+         """Resume a suspended execution by publishing a resume event.
+
+         Args:
+             suspend_execution_id: The execution ID that is suspended
+             suspend_step_key: The step key that was used in suspend()
+             data: Data to pass in the resume event (can be dict or Pydantic BaseModel)
+         """
+         from ..features.events import EventData, batch_publish
+         from ..utils.serializer import serialize
+
+         # Serialize data
+         serialized_data = serialize(data)
+
+         topic = f"{suspend_step_key}/{suspend_execution_id}"
+
+         # Publish resume event
+         await batch_publish(
+             topic=topic,
+             events=[EventData(data=serialized_data, event_type="resume")],
+             client=self,
+         )
+
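A sketch of resuming a suspended run; the IDs are placeholders, and the execution is assumed to be parked on a suspend() step registered under this step key.

# Illustrative only; not part of the package diff.
import asyncio

from polos.runtime.client import PolosClient


async def main() -> None:
    client = PolosClient(api_key="pk-example", project_id="proj_123")
    await client.resume(
        suspend_execution_id="exec_abc123",  # placeholder execution ID
        suspend_step_key="await_approval",   # placeholder step key used in suspend()
        data={"approved": True, "reviewer": "alice"},
    )


asyncio.run(main())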
+     async def get_execution(self, execution_id: str) -> dict[str, Any]:
+         """Get execution details from the orchestrator.
+
+         Args:
+             execution_id: The execution ID to look up
+
+         Returns:
+             Dictionary with execution details including:
+             - id, workflow_id, status, payload, result, error
+             - created_at, started_at, completed_at
+             - deployment_id, parent_execution_id, root_execution_id
+             - retry_count, step_key, etc.
+         """
+         headers = self._get_headers()
+
+         # Try to reuse worker's HTTP client if available
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             response = await worker_client.get(
+                 f"{self.api_url}/api/v1/executions/{execution_id}",
+                 headers=headers,
+             )
+             response.raise_for_status()
+             return response.json()
+         else:
+             async with httpx.AsyncClient() as client:
+                 response = await client.get(
+                     f"{self.api_url}/api/v1/executions/{execution_id}",
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+                 return response.json()
+
+     async def cancel_execution(self, execution_id: str) -> bool:
+         """Cancel an execution by its ID.
+
+         Args:
+             execution_id: The execution ID to cancel
+
+         Returns:
+             True if cancellation was successful, False otherwise
+         """
+         headers = self._get_headers()
+
+         # Try to reuse worker's HTTP client if available
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             client = worker_client
+             use_context_manager = False
+         else:
+             client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
+             use_context_manager = True
+
+         try:
+             response = await client.post(
+                 f"{self.api_url}/api/v1/executions/{execution_id}/cancel",
+                 headers=headers,
+             )
+
+             if response.status_code == 404:
+                 # Execution not found
+                 return False
+
+             response.raise_for_status()
+             return True
+         except httpx.HTTPStatusError as e:
+             logger.error("Failed to cancel execution %s: %s", execution_id, e)
+             return False
+         except Exception as e:
+             logger.error("Error cancelling execution %s: %s", execution_id, e)
+             return False
+         finally:
+             if use_context_manager:
+                 await client.aclose()
+
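Polling and cancelling by ID might look like the sketch below; the execution ID is a placeholder, and the status strings are assumptions about the orchestrator's vocabulary rather than values taken from this file.

# Illustrative only; not part of the package diff.
import asyncio

from polos.runtime.client import PolosClient


async def main() -> None:
    client = PolosClient(api_key="pk-example", project_id="proj_123")

    execution = await client.get_execution("exec_abc123")  # placeholder ID
    print(execution["status"], execution.get("error"))

    if execution["status"] in ("queued", "running"):  # status names assumed
        cancelled = await client.cancel_execution("exec_abc123")
        print("cancelled:", cancelled)


asyncio.run(main())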
+     async def _submit_workflow(
+         self,
+         workflow_id: str,
+         payload: Any,
+         deployment_id: str | None = None,
+         parent_execution_id: str | None = None,
+         root_execution_id: str | None = None,
+         step_key: str | None = None,
+         queue_name: str | None = None,
+         queue_concurrency_limit: int | None = None,
+         concurrency_key: str | None = None,
+         wait_for_subworkflow: bool = False,
+         batch_id: str | None = None,
+         session_id: str | None = None,
+         user_id: str | None = None,
+         otel_traceparent: str | None = None,
+         initial_state: dict[str, Any] | None = None,
+         run_timeout_seconds: int | None = None,
+     ) -> "ExecutionHandle":
+         """Submit a workflow and return an execution handle.
+
+         Internal method used by invoke() and step.invoke_and_wait().
+         """
+         headers = self._get_headers()
+
+         # Validate initial_state size if provided
+         if initial_state is not None:
+             _validate_state_size(initial_state)
+
+         # Inherit session_id and user_id from parent if not provided
+         if not session_id or not user_id:
+             from ..core.workflow import _execution_context
+
+             exec_context = _execution_context.get()
+             if exec_context:
+                 if not session_id:
+                     session_id = exec_context.get("session_id")
+                 if not user_id:
+                     user_id = exec_context.get("user_id")
+
+         # Try to reuse worker's HTTP client if available
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             client = worker_client
+             use_context_manager = False
+         else:
+             client = httpx.AsyncClient(timeout=httpx.Timeout(300.0))
+             use_context_manager = True
+
+         try:
+             # Submit workflow
+             request_json = {
+                 "payload": payload,
+             }
+             if step_key:
+                 request_json["step_key"] = step_key
+             if deployment_id:
+                 request_json["deployment_id"] = deployment_id
+             if parent_execution_id:
+                 request_json["parent_execution_id"] = parent_execution_id
+             if root_execution_id:
+                 request_json["root_execution_id"] = root_execution_id
+             if queue_name:
+                 request_json["queue_name"] = queue_name
+             if queue_concurrency_limit is not None:
+                 request_json["queue_concurrency_limit"] = queue_concurrency_limit
+             if concurrency_key:
+                 request_json["concurrency_key"] = concurrency_key
+             request_json["wait_for_subworkflow"] = wait_for_subworkflow
+             if batch_id:
+                 request_json["batch_id"] = batch_id
+             if session_id:
+                 request_json["session_id"] = session_id
+             if user_id:
+                 request_json["user_id"] = user_id
+             if otel_traceparent:
+                 request_json["otel_traceparent"] = otel_traceparent
+             if initial_state is not None:
+                 request_json["initial_state"] = initial_state
+             if run_timeout_seconds is not None:
+                 request_json["run_timeout_seconds"] = run_timeout_seconds
+
+             response = await client.post(
+                 f"{self.api_url}/api/v1/workflows/{workflow_id}/run",
+                 json=request_json,
+                 headers=headers,
+             )
+             response.raise_for_status()
+             data = response.json()
+             execution_id_value = data["execution_id"]
+             created_at = data.get("created_at")
+
+             # Return handle immediately (fire and forget)
+             # Note: If called from within a workflow, the orchestrator has already
+             # set the parent to waiting
+             # invoke_and_wait() will raise WaitException to pause the parent if needed
+             return ExecutionHandle(
+                 id=execution_id_value,
+                 workflow_id=workflow_id,
+                 created_at=created_at,
+                 parent_execution_id=parent_execution_id,
+                 root_execution_id=root_execution_id,
+                 session_id=session_id,
+                 user_id=user_id,
+                 step_key=step_key,
+             )
+         finally:
+             if use_context_manager:
+                 await client.aclose()
+
+     async def _submit_workflows(
+         self,
+         workflows: list[dict[str, Any]],
+         deployment_id: str | None = None,
+         parent_execution_id: str | None = None,
+         root_execution_id: str | None = None,
+         step_key: str | None = None,
+         session_id: str | None = None,
+         user_id: str | None = None,
+         wait_for_subworkflow: bool = False,
+         otel_traceparent: str | None = None,
+     ) -> list["ExecutionHandle"]:
+         """Submit multiple workflows in a batch and return execution handles.
+
+         Internal method used by batch_invoke() and step.batch_invoke().
+         """
+         headers = self._get_headers()
+
+         # Inherit session_id and user_id from parent if not provided
+         if not session_id or not user_id:
+             from ..core.workflow import _execution_context
+
+             exec_context = _execution_context.get()
+             if exec_context:
+                 if not session_id:
+                     session_id = exec_context.get("session_id")
+                 if not user_id:
+                     user_id = exec_context.get("user_id")
+
+         # Try to reuse worker's HTTP client if available
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             client = worker_client
+             use_context_manager = False
+         else:
+             client = httpx.AsyncClient(timeout=httpx.Timeout(300.0))
+             use_context_manager = True
+
+         try:
+             # Prepare batch request
+             request_json = {
+                 "workflows": workflows,
+             }
+
+             # Common batch-level properties (shared by all workflows)
+             if step_key:
+                 request_json["step_key"] = step_key
+             if deployment_id:
+                 request_json["deployment_id"] = deployment_id
+             if parent_execution_id:
+                 request_json["parent_execution_id"] = parent_execution_id
+             if root_execution_id:
+                 request_json["root_execution_id"] = root_execution_id
+             if session_id:
+                 request_json["session_id"] = session_id
+             if user_id:
+                 request_json["user_id"] = user_id
+             request_json["wait_for_subworkflow"] = wait_for_subworkflow
+             if otel_traceparent:
+                 request_json["otel_traceparent"] = otel_traceparent
+
+             response = await client.post(
+                 f"{self.api_url}/api/v1/workflows/batch_run",
+                 json=request_json,
+                 headers=headers,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             # Build ExecutionHandle objects from response
+             # The API returns executions in the same order as the request
+             handles = []
+             executions = data.get("executions", [])
+             for i, execution_response in enumerate(executions):
+                 execution_id = execution_response["execution_id"]
+                 created_at = execution_response.get("created_at")
+
+                 # Get workflow_id from the corresponding workflow request (API doesn't return it)
+                 # Executions are returned in the same order as the request
+                 workflow_id = workflows[i]["workflow_id"] if i < len(workflows) else None
+
+                 handles.append(
+                     ExecutionHandle(
+                         id=execution_id,
+                         workflow_id=workflow_id,
+                         created_at=created_at,
+                         parent_execution_id=parent_execution_id,
+                         root_execution_id=root_execution_id,
+                         session_id=session_id,
+                         user_id=user_id,
+                         step_key=step_key,
+                     )
+                 )
+
+             return handles
+         finally:
+             if use_context_manager:
+                 await client.aclose()
+
+
+ class ExecutionHandle(BaseModel):
+     """Handle for a workflow execution that allows monitoring and management."""
+
+     # Primary field name is 'id'
+     id: str
+     workflow_id: str | None = None
+     created_at: str | None = None
+     parent_execution_id: str | None = None
+     root_execution_id: str | None = None
+     session_id: str | None = None
+     user_id: str | None = None
+     step_key: str | None = None
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert the execution handle to a dictionary."""
+         return self.model_dump(mode="json")
+
+     async def get(self, client: PolosClient) -> dict[str, Any]:
+         """Get the current status of the execution."""
+         headers = client._get_headers()
+
+         # Try to reuse worker's HTTP client if available
+         worker_client = get_worker_client()
+         if worker_client is not None:
+             response = await worker_client.get(
+                 f"{client.api_url}/api/v1/executions/{self.id}",
+                 headers=headers,
+             )
+             response.raise_for_status()
+         else:
+             async with httpx.AsyncClient() as http_client:
+                 response = await http_client.get(
+                     f"{client.api_url}/api/v1/executions/{self.id}",
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+         execution = response.json()
+         self._cached_status = execution
+         result = await self._prepare_result(
+             execution.get("result"), execution.get("output_schema_name")
+         )
+
+         return {
+             "status": execution.get("status"),
+             "result": result,
+             "error": execution.get("error"),
+             "created_at": execution.get("created_at"),
+             "completed_at": execution.get("completed_at"),
+             "parent_execution_id": execution.get("parent_execution_id"),
+             "root_execution_id": execution.get("root_execution_id"),
+             "output_schema_name": execution.get("output_schema_name"),
+             "step_key": execution.get("step_key"),
+         }
+
+     async def cancel(self, client: PolosClient) -> bool:
+         """Cancel the execution if it's still queued or running.
+
+         Returns:
+             True if cancellation was successful, False otherwise
+         """
+         return await client.cancel_execution(self.id)
+
+     # Pydantic will generate __repr__ automatically, but we can customize it if needed
+     def __repr__(self) -> str:
+         # Use model_dump to get the dictionary representation.
+         fields = self.model_dump(exclude_none=True, mode="json")
+         field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items())
+         return f"ExecutionHandle({field_str})"
+
+     async def _prepare_result(self, result: Any, output_schema_name: str | None = None) -> Any:
+         # Reconstruct Pydantic model if output_schema_name is present
+         prepared_result = result
+
+         from ..core.workflow import _WORKFLOW_REGISTRY
+
+         workflow = _WORKFLOW_REGISTRY.get(self.workflow_id)
+
+         if output_schema_name and result and isinstance(result, dict):
+             # First, check if the workflow has an output_schema stored (set during execution)
+             if workflow and hasattr(workflow, "output_schema") and workflow.output_schema:
+                 # Use the stored output schema class from the workflow
+                 try:
+                     prepared_result = workflow.output_schema.model_validate(result)
+                 except (ValueError, TypeError) as e:
+                     logger.warning(
+                         f"Failed to reconstruct Pydantic model using workflow.output_schema: {e}. "
+                         f"Falling back to dynamic import."
+                     )
+
+             # Fallback to dynamic import if workflow.output_schema is not available
+             try:
+                 # Dynamically import the Pydantic class
+                 module_path, class_name = output_schema_name.rsplit(".", 1)
+                 module = __import__(module_path, fromlist=[class_name])
+                 model_class = getattr(module, class_name)
+
+                 # Validate that it's a Pydantic BaseModel
+                 if issubclass(model_class, BaseModel):
+                     prepared_result = model_class.model_validate(result)
+             except (ImportError, AttributeError, ValueError, TypeError) as e:
+                 # If reconstruction fails, log warning but return dict
+                 # This allows backward compatibility if the class is not available
+                 logger.warning(
+                     f"Failed to reconstruct Pydantic model '{output_schema_name}': {e}. "
+                     f"Returning dict instead."
+                 )
+
+         # Handle structured output for agents
+         from ..agents.agent import Agent
+
+         if workflow and isinstance(workflow, Agent) and isinstance(prepared_result, AgentResult):
+             # Convert result to structured output schema
+             if workflow.result_output_schema and prepared_result.result is not None:
+                 prepared_result.result = workflow.result_output_schema.model_validate(
+                     prepared_result.result
+                 )
+
+             # Convert tool results to structured output schema
+             for tool_result in prepared_result.tool_results:
+                 if tool_result.result_schema and tool_result.result is not None:
+                     try:
+                         # Dynamically import the Pydantic class
+                         module_path, class_name = tool_result.result_schema.rsplit(".", 1)
+                         module = __import__(module_path, fromlist=[class_name])
+                         model_class = getattr(module, class_name)
+
+                         # Validate that it's a Pydantic BaseModel
+                         if issubclass(model_class, BaseModel):
+                             tool_result.result = model_class.model_validate(tool_result.result)
+                     except (ImportError, AttributeError, ValueError, TypeError) as e:
+                         # If reconstruction fails, log warning but return dict
+                         # This allows backward compatibility if the class is not available
+                         logger.warning(
+                             f"Failed to reconstruct Pydantic model "
+                             f"'{tool_result.result_schema}': {e}. "
+                             f"Returning dict instead."
+                         )
+
+         return prepared_result
+
+
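Putting the handle to work; a sketch assuming the same placeholder workflow as above, and assuming "completed" is the terminal status string (not confirmed by this file).

# Illustrative only; not part of the package diff.
import asyncio

from polos.runtime.client import PolosClient


async def main() -> None:
    client = PolosClient(api_key="pk-example", project_id="proj_123")
    handle = await client.invoke("order_processing", payload={"order_id": "ord_42"})

    status = await handle.get(client)
    if status["status"] == "completed":   # terminal status name assumed
        print("result:", status["result"])
    else:
        await handle.cancel(client)       # returns True on success


asyncio.run(main())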
699
+ # Internal functions that get client from context
700
+ async def store_step_output(
701
+ execution_id: str,
702
+ step_key: str,
703
+ outputs: Any | None = None,
704
+ error: Any | None = None,
705
+ success: bool | None = True,
706
+ source_execution_id: str | None = None,
707
+ output_schema_name: str | None = None,
708
+ ) -> None:
709
+ """Store step output for recovery.
710
+
711
+ Args:
712
+ execution_id: Execution ID
713
+ step_key: Step key identifier (required, must be unique per execution)
714
+ outputs: Step outputs (optional)
715
+ error: Step error (optional)
716
+ success: Whether step succeeded (optional)
717
+ source_execution_id: Source execution ID (optional)
718
+ output_schema_name: Full module path of Pydantic class for deserialization (optional)
719
+ """
720
+ client = get_client_or_raise()
721
+ headers = client._get_headers()
722
+
723
+ # Build request payload, only including fields that are not None
724
+ payload = {
725
+ "step_key": step_key,
726
+ }
727
+
728
+ if outputs is not None:
729
+ payload["outputs"] = outputs
730
+ if error is not None:
731
+ payload["error"] = error
732
+ if success is not None:
733
+ payload["success"] = success
734
+ if source_execution_id is not None:
735
+ payload["source_execution_id"] = source_execution_id
736
+ if output_schema_name is not None:
737
+ payload["output_schema_name"] = output_schema_name
738
+
739
+ # Try to reuse worker's HTTP client if available
740
+ worker_client = get_worker_client()
741
+ if worker_client is not None:
742
+ response = await worker_client.post(
743
+ f"{client.api_url}/internal/executions/{execution_id}/steps",
744
+ json=payload,
745
+ headers=headers,
746
+ )
747
+ response.raise_for_status()
748
+ else:
749
+ async with httpx.AsyncClient() as http_client:
750
+ response = await http_client.post(
751
+ f"{client.api_url}/internal/executions/{execution_id}/steps",
752
+ json=payload,
753
+ headers=headers,
754
+ )
755
+ response.raise_for_status()
756
+
757
+
+ async def get_step_output(execution_id: str, step_key: str) -> dict[str, Any] | None:
+     """Get step output for recovery.
+
+     Args:
+         execution_id: Execution ID
+         step_key: Step key identifier (required)
+
+     Returns:
+         Step output dictionary or None if not found
+     """
+     client = get_client_or_raise()
+     headers = client._get_headers()
+
+     # Try to reuse worker's HTTP client if available
+     worker_client = get_worker_client()
+     if worker_client is not None:
+         response = await worker_client.get(
+             f"{client.api_url}/internal/executions/{execution_id}/steps/{step_key}",
+             headers=headers,
+         )
+         if response.status_code == 404:
+             return None
+         response.raise_for_status()
+         return response.json()
+     else:
+         async with httpx.AsyncClient() as http_client:
+             response = await http_client.get(
+                 f"{client.api_url}/internal/executions/{execution_id}/steps/{step_key}",
+                 headers=headers,
+             )
+             if response.status_code == 404:
+                 return None
+             response.raise_for_status()
+             return response.json()
+
+
+ async def get_all_step_outputs(execution_id: str) -> list:
+     """Get all step outputs for an execution (for recovery)."""
+     client = get_client_or_raise()
+     headers = client._get_headers()
+
+     # Try to reuse worker's HTTP client if available
+     worker_client = get_worker_client()
+     if worker_client is not None:
+         response = await worker_client.get(
+             f"{client.api_url}/internal/executions/{execution_id}/steps",
+             headers=headers,
+         )
+         response.raise_for_status()
+         data = response.json()
+         return data.get("steps", [])
+     else:
+         async with httpx.AsyncClient() as http_client:
+             response = await http_client.get(
+                 f"{client.api_url}/internal/executions/{execution_id}/steps",
+                 headers=headers,
+             )
+             response.raise_for_status()
+             data = response.json()
+             return data.get("steps", [])
+
+
+ async def update_execution_otel_span_id(execution_id: str, otel_span_id: str | None) -> None:
+     """Update execution's otel_span_id (used when workflow is paused via WaitException)."""
+     client = get_client_or_raise()
+     headers = client._get_headers()
+
+     # Try to reuse worker's HTTP client if available
+     worker_client = get_worker_client()
+     if worker_client is not None:
+         response = await worker_client.put(
+             f"{client.api_url}/internal/executions/{execution_id}/otel-span-id",
+             headers=headers,
+             json={"otel_span_id": otel_span_id},
+         )
+         response.raise_for_status()
+     else:
+         async with httpx.AsyncClient() as http_client:
+             response = await http_client.put(
+                 f"{client.api_url}/internal/executions/{execution_id}/otel-span-id",
+                 headers=headers,
+                 json={"otel_span_id": otel_span_id},
+             )
+             response.raise_for_status()