polos-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/runtime/client.py
ADDED
@@ -0,0 +1,841 @@
import json
import logging
import os
from typing import Any

import httpx
from dotenv import load_dotenv
from pydantic import BaseModel

from ..types.types import AgentResult, BatchWorkflowInput
from ..utils.client_context import get_client_or_raise
from ..utils.config import is_localhost_url
from ..utils.worker_singleton import get_worker_client

logger = logging.getLogger(__name__)

load_dotenv()


def _validate_state_size(state: dict[str, Any] | BaseModel, max_size_mb: float = 1.0) -> None:
    """Validate that state JSON size doesn't exceed limit.

    Args:
        state: State dictionary to validate
        max_size_mb: Maximum size in MB (default: 1.0)

    Raises:
        ValueError: If state size exceeds limit
    """

    state_json = state.model_dump_json() if isinstance(state, BaseModel) else json.dumps(state)
    size_bytes = len(state_json.encode("utf-8"))
    max_bytes = int(max_size_mb * 1024 * 1024)

    if size_bytes > max_bytes:
        size_mb = size_bytes / (1024 * 1024)
        raise ValueError(
            f"Workflow state size ({size_mb:.2f}MB) exceeds maximum allowed size "
            f"({max_size_mb}MB). Consider reducing state size or using external storage."
        )


class PolosClient:
    """Client for interacting with the Polos orchestrator API."""

    def __init__(
        self,
        api_url: str | None = None,
        api_key: str | None = None,
        project_id: str | None = None,
    ):
        """Initialize Polos client.

        Args:
            api_url: Orchestrator API URL (default: from POLOS_API_URL env var or http://localhost:8080)
            api_key: API key for authentication (default: from POLOS_API_KEY env var)
            project_id: Project ID for multi-tenancy (default: from POLOS_PROJECT_ID env var)
        """
        self.api_url = api_url or os.getenv("POLOS_API_URL", "http://localhost:8080")
        self.api_key = api_key or os.getenv("POLOS_API_KEY")
        self.project_id = project_id or os.getenv("POLOS_PROJECT_ID")

        # Validate required fields (with local mode support)
        local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
        is_localhost = is_localhost_url(self.api_url)
        local_mode = local_mode_requested and is_localhost

        if local_mode_requested and not is_localhost:
            logger.warning(
                f"POLOS_LOCAL_MODE=True ignored because api_url ({self.api_url}) "
                "is not localhost. Falling back to normal authentication."
            )

        if not local_mode and not self.api_key:
            raise ValueError(
                "api_key is required. Set it via PolosClient(api_key='...') "
                "or POLOS_API_KEY environment variable. Or set "
                "POLOS_LOCAL_MODE=True for local development "
                "(only works with localhost URLs)."
            )

        if not self.project_id:
            raise ValueError(
                "project_id is required. Set it via PolosClient(project_id='...') "
                "or POLOS_PROJECT_ID environment variable."
            )

    def _get_headers(self) -> dict[str, str]:
        """Get headers for API requests, including project_id and API key.

        The API key is required for all orchestrator API calls, unless POLOS_LOCAL_MODE=True.
        Local mode is only enabled when api_url is localhost.
        """
        headers = {"Content-Type": "application/json"}

        # Check for local mode (only enabled for localhost URLs)
        local_mode_requested = os.getenv("POLOS_LOCAL_MODE", "False").lower() == "true"
        is_localhost = is_localhost_url(self.api_url)
        local_mode = local_mode_requested and is_localhost

        # API key is required for all orchestrator API calls unless in local mode
        if not local_mode:
            if not self.api_key:
                raise ValueError(
                    "api_key is required. Set it via PolosClient(api_key='...') "
                    "or POLOS_API_KEY environment variable. Or set "
                    "POLOS_LOCAL_MODE=True for local development "
                    "(only works with localhost URLs)."
                )
            headers["Authorization"] = f"Bearer {self.api_key}"

        # Add project_id header (required for multi-tenancy)
        if not self.project_id:
            raise ValueError(
                "project_id is required. Set it via PolosClient(project_id='...') "
                "or POLOS_PROJECT_ID environment variable."
            )
        headers["X-Project-ID"] = self.project_id

        return headers

    async def _get_http_client(self, timeout: httpx.Timeout | None = None) -> httpx.AsyncClient:
        """Get an HTTP client, reusing the worker's client if available.

        Args:
            timeout: Optional timeout for the client (only used if creating new client)

        Returns:
            An httpx.AsyncClient instance (either from worker or newly created)

        Note:
            If using worker's client, it's not closed after use (worker manages its lifecycle).
            If creating a new client, it should be used with `async with` context manager.
        """
        worker_client = get_worker_client()
        if worker_client is not None:
            return worker_client
        else:
            if timeout is not None:
                return httpx.AsyncClient(timeout=timeout)
            else:
                return httpx.AsyncClient()

    async def invoke(
        self,
        workflow_id: str,
        payload: Any = None,
        queue_name: str | None = None,
        queue_concurrency_limit: int | None = None,
        concurrency_key: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        initial_state: dict[str, Any] | None = None,
        run_timeout_seconds: int | None = None,
    ) -> "ExecutionHandle":
        """Invoke a workflow and return an execution handle.

        Args:
            workflow_id: The workflow identifier
            payload: The workflow payload
            deployment_id: Optional deployment ID (if not provided, uses latest active)
            queue_name: Optional queue name (if not provided, defaults to workflow_id)
            queue_concurrency_limit: Optional concurrency limit for queue creation
            concurrency_key: Optional concurrency key for per-tenant queuing
            session_id: Optional session ID
            user_id: Optional user ID
            initial_state: Optional initial state dictionary (must be JSON-serializable, max 1MB)
            run_timeout_seconds: Optional timeout in seconds

        Returns:
            ExecutionHandle for the submitted workflow
        """
        return await self._submit_workflow(
            workflow_id=workflow_id,
            deployment_id=None,  # Use latest deployment
            payload=payload,
            queue_name=queue_name,
            queue_concurrency_limit=queue_concurrency_limit,
            concurrency_key=concurrency_key,
            session_id=session_id,
            user_id=user_id,
            initial_state=initial_state,
            run_timeout_seconds=run_timeout_seconds,
        )

    async def batch_invoke(
        self,
        workflows: list[BatchWorkflowInput],
        session_id: str | None = None,
        user_id: str | None = None,
    ) -> list["ExecutionHandle"]:
        """Invoke multiple different workflows in a single batch and return handles immediately.

        Args:
            workflows: List of BatchWorkflowInput objects with 'id' (workflow_id string)
                and 'payload' (dict or Pydantic model)
            session_id: Optional session ID
            user_id: Optional user ID

        Returns:
            List of ExecutionHandle objects for the submitted workflows
        """
        from ..core.workflow import get_workflow
        from ..utils.serializer import serialize

        if not workflows:
            return []

        # Build workflow requests for batch submission
        workflow_requests = []
        for workflow_input in workflows:
            workflow_id = workflow_input.id
            payload = serialize(workflow_input.payload)

            workflow_obj = get_workflow(workflow_id)
            if not workflow_obj:
                raise ValueError(f"Workflow '{workflow_id}' not found")

            workflow_req = {
                "workflow_id": workflow_id,
                "payload": payload,
                "initial_state": serialize(workflow_input.initial_state),
                "run_timeout_seconds": workflow_input.run_timeout_seconds,
            }

            # Per-workflow properties (queue_name, concurrency_key, etc.)
            if workflow_obj.queue_name is not None:
                workflow_req["queue_name"] = workflow_obj.queue_name

            if workflow_obj.queue_concurrency_limit is not None:
                workflow_req["queue_concurrency_limit"] = workflow_obj.queue_concurrency_limit

            workflow_requests.append(workflow_req)

        # Submit all workflows in a single batch using the batch endpoint
        handles = await self._submit_workflows(
            workflows=workflow_requests,
            deployment_id=None,  # Use latest active deployment
            parent_execution_id=None,
            root_execution_id=None,
            step_key=None,  # Not invoked from a step, so no step_key
            session_id=session_id,
            user_id=user_id,
            wait_for_subworkflow=False,  # Fire-and-forget
        )

        return handles

    async def resume(self, suspend_execution_id: str, suspend_step_key: str, data: Any) -> None:
        """Resume a suspended execution by publishing a resume event.

        Args:
            suspend_execution_id: The execution ID that is suspended
            suspend_step_key: The step key that was used in suspend()
            data: Data to pass in the resume event (can be dict or Pydantic BaseModel)
        """
        from ..features.events import EventData, batch_publish
        from ..utils.serializer import serialize

        # Serialize data
        serialized_data = serialize(data)

        topic = f"{suspend_step_key}/{suspend_execution_id}"

        # Publish resume event
        await batch_publish(
            topic=topic,
            events=[EventData(data=serialized_data, event_type="resume")],
            client=self,
        )

    async def get_execution(self, execution_id: str) -> dict[str, Any]:
        """Get execution details from the orchestrator.

        Args:
            execution_id: The execution ID to look up

        Returns:
            Dictionary with execution details including:
            - id, workflow_id, status, payload, result, error
            - created_at, started_at, completed_at
            - deployment_id, parent_execution_id, root_execution_id
            - retry_count, step_key etc.
        """
        headers = self._get_headers()

        # Try to reuse worker's HTTP client if available
        worker_client = get_worker_client()
        if worker_client is not None:
            response = await worker_client.get(
                f"{self.api_url}/api/v1/executions/{execution_id}",
                headers=headers,
            )
            response.raise_for_status()
            return response.json()
        else:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{self.api_url}/api/v1/executions/{execution_id}",
                    headers=headers,
                )
                response.raise_for_status()
                return response.json()

    async def cancel_execution(self, execution_id: str) -> bool:
        """Cancel an execution by its ID.

        Args:
            execution_id: The execution ID to cancel

        Returns:
            True if cancellation was successful, False otherwise
        """
        headers = self._get_headers()

        # Try to reuse worker's HTTP client if available
        worker_client = get_worker_client()
        if worker_client is not None:
            client = worker_client
            use_context_manager = False
        else:
            client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
            use_context_manager = True

        try:
            response = await client.post(
                f"{self.api_url}/api/v1/executions/{execution_id}/cancel",
                headers=headers,
            )

            if response.status_code == 404:
                # Execution not found
                return False

            response.raise_for_status()
            return True
        except httpx.HTTPStatusError as e:
            logger.error("Failed to cancel execution %s: %s", execution_id, e)
            return False
        except Exception as e:
            logger.error("Error cancelling execution %s: %s", execution_id, e)
            return False
        finally:
            if use_context_manager:
                await client.aclose()

    async def _submit_workflow(
        self,
        workflow_id: str,
        payload: Any,
        deployment_id: str | None = None,
        parent_execution_id: str | None = None,
        root_execution_id: str | None = None,
        step_key: str | None = None,
        queue_name: str | None = None,
        queue_concurrency_limit: int | None = None,
        concurrency_key: str | None = None,
        wait_for_subworkflow: bool = False,
        batch_id: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        otel_traceparent: str | None = None,
        initial_state: dict[str, Any] | None = None,
        run_timeout_seconds: int | None = None,
    ) -> "ExecutionHandle":
        """Submit a workflow and return an execution handle.

        Internal method used by invoke() and step.invoke_and_wait().
        """
        headers = self._get_headers()

        # Validate initial_state size if provided
        if initial_state is not None:
            _validate_state_size(initial_state)

        # Inherit session_id and user_id from parent if not provided
        if not session_id or not user_id:
            from ..core.workflow import _execution_context

            exec_context = _execution_context.get()
            if exec_context:
                if not session_id:
                    session_id = exec_context.get("session_id")
                if not user_id:
                    user_id = exec_context.get("user_id")

        # Try to reuse worker's HTTP client if available
        worker_client = get_worker_client()
        if worker_client is not None:
            client = worker_client
            use_context_manager = False
        else:
            client = httpx.AsyncClient(timeout=httpx.Timeout(300.0))
            use_context_manager = True

        try:
            # Submit workflow
            request_json = {
                "payload": payload,
            }
            if step_key:
                request_json["step_key"] = step_key
            if deployment_id:
                request_json["deployment_id"] = deployment_id
            if parent_execution_id:
                request_json["parent_execution_id"] = parent_execution_id
            if root_execution_id:
                request_json["root_execution_id"] = root_execution_id
            if queue_name:
                request_json["queue_name"] = queue_name
            if queue_concurrency_limit is not None:
                request_json["queue_concurrency_limit"] = queue_concurrency_limit
            if concurrency_key:
                request_json["concurrency_key"] = concurrency_key
            request_json["wait_for_subworkflow"] = wait_for_subworkflow
            if batch_id:
                request_json["batch_id"] = batch_id
            if session_id:
                request_json["session_id"] = session_id
            if user_id:
                request_json["user_id"] = user_id
            if otel_traceparent:
                request_json["otel_traceparent"] = otel_traceparent
            if initial_state is not None:
                request_json["initial_state"] = initial_state
            if run_timeout_seconds is not None:
                request_json["run_timeout_seconds"] = run_timeout_seconds

            response = await client.post(
                f"{self.api_url}/api/v1/workflows/{workflow_id}/run",
                json=request_json,
                headers=headers,
            )
            response.raise_for_status()
            data = response.json()
            execution_id_value = data["execution_id"]
            created_at = data.get("created_at")

            # Return handle immediately (fire and forget)
            # Note: If called from within a workflow, the orchestrator has already
            # set the parent to waiting
            # invoke_and_wait() will raise WaitException to pause the parent if needed
            return ExecutionHandle(
                id=execution_id_value,
                workflow_id=workflow_id,
                created_at=created_at,
                parent_execution_id=parent_execution_id,
                root_execution_id=root_execution_id,
                session_id=session_id,
                user_id=user_id,
                step_key=step_key,
            )
        finally:
            if use_context_manager:
                await client.aclose()

    async def _submit_workflows(
        self,
        workflows: list[dict[str, Any]],
        deployment_id: str | None = None,
        parent_execution_id: str | None = None,
        root_execution_id: str | None = None,
        step_key: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        wait_for_subworkflow: bool = False,
        otel_traceparent: str | None = None,
    ) -> list["ExecutionHandle"]:
        """Submit multiple workflows in a batch and return execution handles.

        Internal method used by batch_invoke() and step.batch_invoke().
        """
        headers = self._get_headers()

        # Inherit session_id and user_id from parent if not provided
        if not session_id or not user_id:
            from ..core.workflow import _execution_context

            exec_context = _execution_context.get()
            if exec_context:
                if not session_id:
                    session_id = exec_context.get("session_id")
                if not user_id:
                    user_id = exec_context.get("user_id")

        # Try to reuse worker's HTTP client if available
        worker_client = get_worker_client()
        if worker_client is not None:
            client = worker_client
            use_context_manager = False
        else:
            client = httpx.AsyncClient(timeout=httpx.Timeout(300.0))
            use_context_manager = True

        try:
            # Prepare batch request
            request_json = {
                "workflows": workflows,
            }

            # Common batch-level properties (shared by all workflows)
            if step_key:
                request_json["step_key"] = step_key
            if deployment_id:
                request_json["deployment_id"] = deployment_id
            if parent_execution_id:
                request_json["parent_execution_id"] = parent_execution_id
            if root_execution_id:
                request_json["root_execution_id"] = root_execution_id
            if session_id:
                request_json["session_id"] = session_id
            if user_id:
                request_json["user_id"] = user_id
            request_json["wait_for_subworkflow"] = wait_for_subworkflow
            if otel_traceparent:
                request_json["otel_traceparent"] = otel_traceparent

            response = await client.post(
                f"{self.api_url}/api/v1/workflows/batch_run",
                json=request_json,
                headers=headers,
            )
            response.raise_for_status()
            data = response.json()

            # Build ExecutionHandle objects from response
            # The API returns executions in the same order as the request
            handles = []
            executions = data.get("executions", [])
            for i, execution_response in enumerate(executions):
                execution_id = execution_response["execution_id"]
                created_at = execution_response.get("created_at")

                # Get workflow_id from the corresponding workflow request (API doesn't return it)
                # Executions are returned in the same order as the request
                workflow_id = workflows[i]["workflow_id"] if i < len(workflows) else None

                handles.append(
                    ExecutionHandle(
                        id=execution_id,
                        workflow_id=workflow_id,
                        created_at=created_at,
                        parent_execution_id=parent_execution_id,
                        root_execution_id=root_execution_id,
                        session_id=session_id,
                        user_id=user_id,
                        step_key=step_key,
                    )
                )

            return handles
        finally:
            if use_context_manager:
                await client.aclose()


class ExecutionHandle(BaseModel):
    """Handle for a workflow execution that allows monitoring and management."""

    # Primary field name is 'id'
    id: str
    workflow_id: str | None = None
    created_at: str | None = None
    parent_execution_id: str | None = None
    root_execution_id: str | None = None
    session_id: str | None = None
    user_id: str | None = None
    step_key: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert the execution handle to a dictionary."""
        return self.model_dump(mode="json")

    async def get(self, client: PolosClient) -> dict[str, Any]:
        """Get the current status of the execution."""
        headers = client._get_headers()

        # Try to reuse worker's HTTP client if available
        worker_client = get_worker_client()
        if worker_client is not None:
            response = await worker_client.get(
                f"{client.api_url}/api/v1/executions/{self.id}",
                headers=headers,
            )
            response.raise_for_status()
        else:
            async with httpx.AsyncClient() as http_client:
                response = await http_client.get(
                    f"{client.api_url}/api/v1/executions/{self.id}",
                    headers=headers,
                )
                response.raise_for_status()

        execution = response.json()
        self._cached_status = execution
        result = await self._prepare_result(
            execution.get("result"), execution.get("output_schema_name")
        )

        return {
            "status": execution.get("status"),
            "result": result,
            "error": execution.get("error"),
            "created_at": execution.get("created_at"),
            "completed_at": execution.get("completed_at"),
            "parent_execution_id": execution.get("parent_execution_id"),
            "root_execution_id": execution.get("root_execution_id"),
            "output_schema_name": execution.get("output_schema_name"),
            "step_key": execution.get("step_key"),
        }

    async def cancel(self, client: PolosClient) -> bool:
        """Cancel the execution if it's still queued or running.

        Returns:
            True if cancellation was successful, False otherwise
        """
        return await client.cancel_execution(self.id)

    # Pydantic will generate __repr__ automatically, but we can customize it if needed
    def __repr__(self) -> str:
        # Use model_dump to get the dictionary representation.
        fields = self.model_dump(exclude_none=True, mode="json")
        field_str = ", ".join(f"{k}={v!r}" for k, v in fields.items())
        return f"ExecutionHandle({field_str})"

    async def _prepare_result(self, result: Any, output_schema_name: str | None = None) -> Any:
        # Reconstruct Pydantic model if output_schema_name is present
        prepared_result = result

        from ..core.workflow import _WORKFLOW_REGISTRY

        workflow = _WORKFLOW_REGISTRY.get(self.workflow_id)

        if output_schema_name and result and isinstance(result, dict):
            # First, check if the workflow has an output_schema stored (set during execution)
            if workflow and hasattr(workflow, "output_schema") and workflow.output_schema:
                # Use the stored output schema class from the workflow
                try:
                    prepared_result = workflow.output_schema.model_validate(result)
                except (ValueError, TypeError) as e:
                    logger.warning(
                        f"Failed to reconstruct Pydantic model using workflow.output_schema: {e}. "
                        f"Falling back to dynamic import."
                    )

            # Fallback to dynamic import if workflow.output_schema is not available
            try:
                # Dynamically import the Pydantic class
                module_path, class_name = output_schema_name.rsplit(".", 1)
                module = __import__(module_path, fromlist=[class_name])
                model_class = getattr(module, class_name)

                # Validate that it's a Pydantic BaseModel
                if issubclass(model_class, BaseModel):
                    prepared_result = model_class.model_validate(result)
            except (ImportError, AttributeError, ValueError, TypeError) as e:
                # If reconstruction fails, log warning but return dict
                # This allows backward compatibility if the class is not available
                logger.warning(
                    f"Failed to reconstruct Pydantic model '{output_schema_name}': {e}. "
                    f"Returning dict instead."
                )

        # Handle structured output for agents
        from ..agents.agent import Agent

        if workflow and isinstance(workflow, Agent) and isinstance(prepared_result, AgentResult):
            # Convert result to structured output schema
            if workflow.result_output_schema and prepared_result.result is not None:
                prepared_result.result = workflow.result_output_schema.model_validate(
                    prepared_result.result
                )

            # Convert tool results to structured output schema
            for tool_result in prepared_result.tool_results:
                if tool_result.result_schema and tool_result.result is not None:
                    try:
                        # Dynamically import the Pydantic class
                        module_path, class_name = tool_result.result_schema.rsplit(".", 1)
                        module = __import__(module_path, fromlist=[class_name])
                        model_class = getattr(module, class_name)

                        # Validate that it's a Pydantic BaseModel
                        if issubclass(model_class, BaseModel):
                            tool_result.result = model_class.model_validate(tool_result.result)
                    except (ImportError, AttributeError, ValueError, TypeError) as e:
                        # If reconstruction fails, log warning but return dict
                        # This allows backward compatibility if the class is not available
                        logger.warning(
                            f"Failed to reconstruct Pydantic model "
                            f"'{tool_result.result_schema}': {e}. "
                            f"Returning dict instead."
                        )

        return prepared_result


# Internal functions that get client from context
async def store_step_output(
    execution_id: str,
    step_key: str,
    outputs: Any | None = None,
    error: Any | None = None,
    success: bool | None = True,
    source_execution_id: str | None = None,
    output_schema_name: str | None = None,
) -> None:
    """Store step output for recovery.

    Args:
        execution_id: Execution ID
        step_key: Step key identifier (required, must be unique per execution)
        outputs: Step outputs (optional)
        error: Step error (optional)
        success: Whether step succeeded (optional)
        source_execution_id: Source execution ID (optional)
        output_schema_name: Full module path of Pydantic class for deserialization (optional)
    """
    client = get_client_or_raise()
    headers = client._get_headers()

    # Build request payload, only including fields that are not None
    payload = {
        "step_key": step_key,
    }

    if outputs is not None:
        payload["outputs"] = outputs
    if error is not None:
        payload["error"] = error
    if success is not None:
        payload["success"] = success
    if source_execution_id is not None:
        payload["source_execution_id"] = source_execution_id
    if output_schema_name is not None:
        payload["output_schema_name"] = output_schema_name

    # Try to reuse worker's HTTP client if available
    worker_client = get_worker_client()
    if worker_client is not None:
        response = await worker_client.post(
            f"{client.api_url}/internal/executions/{execution_id}/steps",
            json=payload,
            headers=headers,
        )
        response.raise_for_status()
    else:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.post(
                f"{client.api_url}/internal/executions/{execution_id}/steps",
                json=payload,
                headers=headers,
            )
            response.raise_for_status()


async def get_step_output(execution_id: str, step_key: str) -> dict[str, Any] | None:
    """Get step output for recovery.

    Args:
        execution_id: Execution ID
        step_key: Step key identifier (required)

    Returns:
        Step output dictionary or None if not found
    """
    client = get_client_or_raise()
    headers = client._get_headers()

    # Try to reuse worker's HTTP client if available
    worker_client = get_worker_client()
    if worker_client is not None:
        response = await worker_client.get(
            f"{client.api_url}/internal/executions/{execution_id}/steps/{step_key}",
            headers=headers,
        )
        if response.status_code == 404:
            return None
        response.raise_for_status()
        return response.json()
    else:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(
                f"{client.api_url}/internal/executions/{execution_id}/steps/{step_key}",
                headers=headers,
            )
            if response.status_code == 404:
                return None
            response.raise_for_status()
            return response.json()


async def get_all_step_outputs(execution_id: str) -> list:
    """Get all step outputs for an execution (for recovery)."""
    client = get_client_or_raise()
    headers = client._get_headers()

    # Try to reuse worker's HTTP client if available
    worker_client = get_worker_client()
    if worker_client is not None:
        response = await worker_client.get(
            f"{client.api_url}/internal/executions/{execution_id}/steps",
            headers=headers,
        )
        response.raise_for_status()
        data = response.json()
        return data.get("steps", [])
    else:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(
                f"{client.api_url}/internal/executions/{execution_id}/steps",
                headers=headers,
            )
            response.raise_for_status()
            data = response.json()
            return data.get("steps", [])


async def update_execution_otel_span_id(execution_id: str, otel_span_id: str | None) -> None:
    """Update execution's otel_span_id (used when workflow is paused via WaitException)."""
    client = get_client_or_raise()
    headers = client._get_headers()

    # Try to reuse worker's HTTP client if available
    worker_client = get_worker_client()
    if worker_client is not None:
        response = await worker_client.put(
            f"{client.api_url}/internal/executions/{execution_id}/otel-span-id",
            headers=headers,
            json={"otel_span_id": otel_span_id},
        )
        response.raise_for_status()
    else:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.put(
                f"{client.api_url}/internal/executions/{execution_id}/otel-span-id",
                headers=headers,
                json={"otel_span_id": otel_span_id},
            )
            response.raise_for_status()
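
For orientation, a minimal usage sketch of the client API shown above (not part of the packaged file). It assumes POLOS_API_KEY and POLOS_PROJECT_ID are set in the environment and that a worker has registered a workflow; the workflow id "my_workflow" and the payload are illustrative placeholders.

import asyncio

from polos.runtime.client import PolosClient


async def main() -> None:
    # Credentials fall back to POLOS_API_URL / POLOS_API_KEY / POLOS_PROJECT_ID env vars
    client = PolosClient()

    # Fire-and-forget submission: returns an ExecutionHandle immediately
    handle = await client.invoke(
        workflow_id="my_workflow",      # illustrative; must be a registered workflow
        payload={"question": "hello"},  # illustrative payload
    )

    # Poll status and (possibly schema-reconstructed) result through the handle
    status = await handle.get(client)
    print(status["status"], status["result"])

    # Cancel if the execution is still queued or running
    cancelled = await handle.cancel(client)
    print("cancelled:", cancelled)


asyncio.run(main())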