polos-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/agents/stream.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
1
|
+
"""Agent stream function (called from agent._agent_execute)."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
from ..core.context import AgentContext
|
|
11
|
+
from ..core.workflow import _WORKFLOW_REGISTRY
|
|
12
|
+
from ..llm import _llm_generate, _llm_stream
|
|
13
|
+
from ..middleware.hook import HookAction, HookContext
|
|
14
|
+
from ..middleware.hook_executor import execute_hooks
|
|
15
|
+
from ..types.types import (
|
|
16
|
+
AgentResult,
|
|
17
|
+
BatchStepResult,
|
|
18
|
+
BatchWorkflowInput,
|
|
19
|
+
Step,
|
|
20
|
+
ToolCall,
|
|
21
|
+
ToolResult,
|
|
22
|
+
Usage,
|
|
23
|
+
)
|
|
24
|
+
from ..utils.serializer import json_serialize, serialize
|
|
25
|
+
from .conversation_history import add_conversation_history, get_conversation_history
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) -> dict[str, Any]:
    """
    Agent stream function (called from agent._agent_execute).

    This function orchestrates the agent conversation:
    1. Executes LLM stream/generate and gets results
    2. If tool calls are present, executes the tools as one batch
    3. Makes successive LLM calls with tool results until no more tool calls
       (or a stop condition / the step limit ends the loop)
    4. Validates structured output (with one corrective retry), stores the
       conversation history and returns the final result

    Payload:
    {
        "agent_config": {                  # required
            "provider": str,
            "model": str,
            "tools": List[Dict],
            "system_prompt": Optional[str],
            "max_output_tokens": Optional[int],
            "temperature": Optional[float]
        },
        "input": Union[str, List[Dict]],   # required; string or array of input items
        "streaming": bool                  # optional, defaults to True
    }

    The agent run id is taken from ``ctx.execution_id``, not from the payload.

    Returns:
        An ``AgentResult`` instance (the ``dict`` return annotation is kept
        unchanged for existing callers).

    Raises:
        KeyError: If ``agent_config`` is missing from the payload.
        ValueError: If ``input`` is missing from the payload.
        StepExecutionError: If any hook returns ``HookAction.FAIL``.
        Exception: If the LLM returns no ``raw_output``, or if the structured
            output is still invalid after the corrective retry.
    """
    # Imported inside the function, as the original code did, rather than at
    # module import time.
    from ..core.workflow import StepExecutionError
    from .stop_conditions import StopConditionContext
    from .stop_conditions import max_steps as max_steps_fn

    agent_run_id = ctx.execution_id  # Use execution_id from context
    agent_config = payload["agent_config"]
    streaming = payload.get("streaming", True)  # Default to True for backward compatibility
    tool_stop_action = False

    # Get agent instance for hooks
    agent = _WORKFLOW_REGISTRY.get(ctx.agent_id)

    # Main agent streaming logic starts here
    input_data = payload.get("input")
    if input_data is None:
        raise ValueError("Input is required in payload")

    # Retrieve conversation history if enabled (read once, before the loop)
    cached_history_messages = None
    if agent and agent.conversation_history > 0 and ctx.conversation_id:
        history_records = await get_conversation_history(
            conversation_id=ctx.conversation_id,
            agent_id=ctx.agent_id,
            deployment_id=ctx.deployment_id,
            limit=agent.conversation_history,
        )
        # Only role and content are forwarded to the LLM
        cached_history_messages = [
            {"role": record.get("role"), "content": record.get("content")}
            for record in history_records
        ]

    # Build conversation history (for successive LLM calls)
    conversation_messages = []

    # Prepend cached conversation history if available
    if cached_history_messages:
        conversation_messages.extend(cached_history_messages)

    # Add current input to conversation
    if isinstance(input_data, str):
        conversation_messages.append({"role": "user", "content": input_data})
    elif isinstance(input_data, list):
        conversation_messages.extend(input_data)

    # Loop: LLM call -> tool execution -> LLM call with results
    agent_step = 1
    final_input_tokens = 0
    final_output_tokens = 0
    final_total_tokens = 0
    last_llm_result_content = None
    all_tool_results = []
    steps: list[Step] = []
    end_steps = False
    # Get stop conditions from agent object (all are callables)
    stop_conditions = agent.stop_conditions if agent else []
    tool_results = None
    parsed_result = None
    checked_structured_output = False

    # The default step limit applies unless an explicit max_steps stop condition
    # is configured. A configured condition exposes the factory it came from as
    # __stop_condition_fn__; the bare max_steps function is also recognised.
    has_max_steps_condition = any(
        getattr(sc, "__stop_condition_fn__", sc) is max_steps_fn for sc in stop_conditions
    )

    if has_max_steps_condition:
        max_steps = None
    else:
        max_steps = int(os.environ.get("POLOS_AGENT_MAX_STEPS", "10"))  # Configurable safety limit

    while not end_steps and (max_steps is None or agent_step <= max_steps):
        current_iteration_tool_results = []

        # Execute on_agent_step_start hooks
        if agent and agent.on_agent_step_start:
            hook_context = HookContext(
                workflow_id=ctx.agent_id,
                session_id=ctx.session_id,
                user_id=ctx.user_id,
                agent_config=agent_config,
                steps=steps.copy(),
                current_payload={"step": agent_step, "messages": conversation_messages},
            )
            hook_result = await execute_hooks(
                f"{agent_step}.hook.on_agent_step_start",
                agent.on_agent_step_start,
                hook_context,
                ctx,
            )

            # Apply modifications
            if hook_result.modified_payload and "messages" in hook_result.modified_payload:
                conversation_messages = hook_result.modified_payload["messages"]

            # A failing hook aborts the run, exactly like the other hooks below.
            # (Previously a `break` here made this raise unreachable and turned
            # a hook failure into a silent, seemingly successful run.)
            if hook_result.action == HookAction.FAIL:
                raise StepExecutionError(hook_result.error_message or "Hook execution failed")

        # Get guardrails from agent
        guardrails = agent.guardrails if agent else None
        if agent:
            guardrail_max_retries = agent.guardrail_max_retries
        elif isinstance(agent_config, dict):
            # agent_config is a plain dict from the payload, so it has to be
            # read with .get(); attribute access raised AttributeError here.
            guardrail_max_retries = agent_config.get("guardrail_max_retries") or 2
        else:
            guardrail_max_retries = getattr(agent_config, "guardrail_max_retries", None) or 2

        # Use _llm_generate if streaming=False OR guardrails are present
        # Guardrails need the full response to validate, so we can't stream incrementally
        # If streaming=False, we want the complete result, not incremental chunks
        use_llm_generate = not streaming or guardrails

        if use_llm_generate:
            llm_result = await _llm_generate(
                ctx,
                {
                    "agent_run_id": agent_run_id,
                    "agent_config": agent_config,
                    "input": conversation_messages,
                    "agent_step": agent_step,
                    "guardrails": guardrails,
                    "guardrail_max_retries": guardrail_max_retries,
                    "tool_results": tool_results,  # Tool results from previous iteration
                },
            )
        else:
            # No guardrails - use streaming
            llm_result = await _llm_stream(
                ctx,
                {
                    "agent_run_id": agent_run_id,
                    "agent_config": agent_config,
                    "input": conversation_messages,
                    "agent_step": agent_step,
                    "tool_results": tool_results,  # Tool results from previous iteration
                },
            )

        tool_results = None  # Reset tool results for next iteration

        usage_dict = llm_result.get("usage")
        if usage_dict:
            final_input_tokens += usage_dict.get("input_tokens", 0)
            final_output_tokens += usage_dict.get("output_tokens", 0)
            final_total_tokens += usage_dict.get("total_tokens", 0)

        last_llm_result_content = llm_result.get("content")
        tool_calls = llm_result.get("tool_calls") or []
        if not llm_result.get("raw_output"):
            raise Exception(
                f"LLM failed to generate output: agent_id={ctx.agent_id}, agent_step={agent_step}"
            )

        # Execute tools in batch and publish results

        # batch_workflows and tool_call_list are index-aligned: entry i of one
        # describes entry i of the other.
        batch_workflows = []
        tool_call_list = []
        tool_results_recorded_list = []

        for idx, tool_call in enumerate(tool_calls):
            # Tool call format: {"id": "...", "type": "function",
            # "function": {"name": "...", "arguments": "..."}}
            if not (
                isinstance(tool_call, dict)
                and "function" in tool_call
                and isinstance(tool_call["function"], dict)
            ):
                continue

            tool_name = tool_call["function"].get("name")
            tool_args_str = tool_call["function"].get("arguments", "{}")
            tool_call_id = tool_call.get("id")
            tool_call_call_id = tool_call.get("call_id")

            if not tool_name:
                continue

            # Find the tool workflow in registry
            tool_workflow = _WORKFLOW_REGISTRY.get(tool_name)
            if not tool_workflow:
                logger.warning("Tool '%s' not found in registry", tool_name)
                continue

            # Parse tool arguments; malformed arguments fall back to an empty
            # dict (best effort), but are now logged instead of silently dropped
            try:
                tool_args = (
                    json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str
                )
            except Exception:
                logger.warning(
                    "Failed to parse arguments for tool '%s'; using empty arguments", tool_name
                )
                tool_args = {}

            # Execute on_tool_start hooks
            if agent and agent.on_tool_start:
                hook_context = HookContext(
                    workflow_id=ctx.agent_id,
                    session_id=ctx.session_id,
                    user_id=ctx.user_id,
                    agent_config=agent_config,
                    steps=steps.copy(),
                    current_tool=tool_name,
                    current_payload=tool_args,
                )
                hook_result = await execute_hooks(
                    f"{agent_step}.hook.on_tool_start.{idx}", agent.on_tool_start, hook_context, ctx
                )

                # Apply modifications
                if hook_result.modified_payload:
                    tool_args.update(hook_result.modified_payload)

                # Check hook action
                if hook_result.action == HookAction.FAIL:
                    raise StepExecutionError(hook_result.error_message or "Hook execution failed")

            # Add to batch
            batch_workflows.append(BatchWorkflowInput(id=tool_name, payload=tool_args))
            tool_call_list.append(
                {
                    "tool_call_id": tool_call_id,
                    "tool_call_call_id": tool_call_call_id,
                    "tool_name": tool_name,
                    "tool_call": tool_call,
                    # Position in tool_calls; keeps the on_tool_end step key
                    # paired with the on_tool_start key of the same call.
                    "tool_call_index": idx,
                }
            )

        # Execute all tools in batch
        if tool_stop_action is False and len(batch_workflows) > 0:
            batch_results: list[BatchStepResult] = await ctx.step.batch_invoke_and_wait(
                f"execute_tools:step_{agent_step}", batch_workflows
            )

            # Publish results and build conversation
            for i, batch_tool_result in enumerate(batch_results):
                tool_result = (
                    batch_tool_result.result
                    if batch_tool_result.success
                    else f"Error: {batch_tool_result.error}"
                )
                tool_spec = batch_workflows[i]
                tool_name = tool_spec.id
                tool_call_info = tool_call_list[i]

                tool_call_id = tool_call_info.get("tool_call_id")
                tool_call_call_id = tool_call_info.get("tool_call_call_id")

                # Record the result's class path, except for builtin types
                tool_result_schema = (
                    (f"{tool_result.__class__.__module__}.{tool_result.__class__.__name__}")
                    if batch_tool_result.success
                    else None
                )
                if tool_result_schema and tool_result_schema.startswith("builtins."):
                    tool_result_schema = None

                # Execute on_tool_end hooks
                if agent and agent.on_tool_end:
                    hook_context = HookContext(
                        workflow_id=ctx.agent_id,
                        session_id=ctx.session_id,
                        user_id=ctx.user_id,
                        agent_config=agent_config,
                        steps=steps.copy(),
                        current_tool=tool_name,
                        current_payload=tool_spec.payload,
                        current_output=tool_result,
                    )
                    # Each tool result gets its own step key. The key used to be
                    # built from the stale `idx` left over from the loop above,
                    # so every result of a batch shared one key.
                    hook_result = await execute_hooks(
                        f"{agent_step}.hook.on_tool_end.{tool_call_info['tool_call_index']}",
                        agent.on_tool_end,
                        hook_context,
                        ctx,
                    )

                    # Apply modifications
                    if hook_result.modified_output is not None:
                        tool_result = hook_result.modified_output

                    # Check hook action
                    if hook_result.action == HookAction.FAIL:
                        raise StepExecutionError(
                            hook_result.error_message or "Hook execution failed"
                        )

                # Serialize and add tool result to conversation
                tool_output = serialize(tool_result)
                tool_json_output = json_serialize(tool_result)

                current_iteration_tool_results.append(
                    {
                        "type": "function_call_output",
                        "call_id": tool_call_call_id,
                        "output": tool_json_output,
                    }
                )

                # NOTE(review): failed tools are also recorded as "completed"
                # (their result is the "Error: ..." string) - confirm whether
                # ToolResult supports a failure status.
                tool_results_recorded_list.append(
                    {
                        "tool_name": tool_name,
                        "status": "completed",
                        "result": tool_output,
                        "result_schema": tool_result_schema,
                        "tool_call_id": tool_call_id,
                        "tool_call_call_id": tool_call_call_id,
                    }
                )

            all_tool_results.extend(tool_results_recorded_list)

            # Set tool_results for next iteration
            tool_results = current_iteration_tool_results

        # Convert tool_calls to ToolCall objects
        tool_calls_list = [
            ToolCall.model_validate(tc)
            for tc in tool_calls
            if isinstance(tc, dict) and "function" in tc and isinstance(tc["function"], dict)
        ]

        # Convert tool_results to ToolResult objects
        step_tool_results = [
            ToolResult.model_validate(tr)
            for tr in tool_results_recorded_list
            if isinstance(tr, dict)
        ]

        # Convert usage to Usage object
        usage_obj = Usage.model_validate(usage_dict) if usage_dict else None

        steps.append(
            Step(
                step=agent_step,
                content=last_llm_result_content,
                tool_calls=tool_calls_list,
                tool_results=step_tool_results,
                usage=usage_obj,
                raw_output=llm_result.get("raw_output"),
            )
        )

        # Execute on_agent_step_end hooks
        if agent and agent.on_agent_step_end:
            hook_context = HookContext(
                workflow_id=ctx.agent_id,
                session_id=ctx.session_id,
                user_id=ctx.user_id,
                agent_config=agent_config,
                steps=steps.copy(),
                current_payload={"step": agent_step, "messages": conversation_messages},
                current_output=steps[-1],
            )
            hook_result = await execute_hooks(
                f"{agent_step}.hook.on_agent_step_end", agent.on_agent_step_end, hook_context, ctx
            )

            # Apply modifications
            if hook_result.modified_output:
                steps[-1] = hook_result.modified_output

            # Check hook action
            if hook_result.action == HookAction.FAIL:
                raise StepExecutionError(hook_result.error_message or "Hook execution failed")

        # No tool results, we're done
        if tool_results is None or len(tool_results) == 0 or tool_stop_action:
            end_steps = True

        # Evaluate stop conditions (if any)
        if stop_conditions and not end_steps:
            # Create stop condition context
            stop_ctx = StopConditionContext(
                steps=steps.copy(),
                agent_id=ctx.agent_id,
                agent_run_id=agent_run_id,
            )

            for idx, stop_condition in enumerate(stop_conditions):
                # Call stop condition using step.run() for durable execution
                func_name = getattr(stop_condition, "__stop_condition_name__", "unknown")
                should_stop = await ctx.step.run(
                    f"{agent_step}.stop_condition.{func_name}.{idx}", stop_condition, stop_ctx
                )

                if should_stop:
                    # Stop condition met; break loop
                    end_steps = True
                    break

        if end_steps:
            # Parse structured output
            parsed_result, parse_success = await _parse_structured_output(
                last_llm_result_content, agent.result_output_schema if agent else None
            )

            if checked_structured_output and not parse_success:
                # LLM failed to generate valid output again, raise an exception
                raise Exception(
                    f"LLM failed to generate valid structured output: "
                    f"agent_id={ctx.agent_id}, agent_step={agent_step}"
                )

            checked_structured_output = True

            # If parsing failed and output_schema is present, try to fix it with llm_generate
            if not parse_success and agent and agent.result_output_schema:
                # Reset end_steps and try to fix the output
                end_steps = False

                # Simply include the last incorrect output in the conversation messages
                conversation_messages = llm_result.get("raw_output")

                # Add a user message asking to fix the output
                schema_json = json.dumps(agent.result_output_schema.model_json_schema(), indent=2)
                fix_prompt = (
                    f"The previous response was not valid JSON matching the "
                    f"required schema. Please reformat your response to be valid "
                    f"JSON that strictly conforms to this schema:\n\n{schema_json}\n\n"
                    f"Please provide ONLY valid JSON that matches the schema, "
                    f"with no additional text or formatting."
                )
                conversation_messages.append({"role": "user", "content": fix_prompt})

        if not end_steps:
            # If it's a structured output correction step, we've already created
            # the conversation messages
            # So we don't need to add the raw output again
            if checked_structured_output is False:
                conversation_messages = llm_result.get("raw_output")

            # Increment agent_step for next LLM call
            agent_step += 1

    if not end_steps:
        # The only way to leave the loop without end_steps is the step limit
        logger.warning(
            "Agent '%s' stopped after reaching the step limit (%s)", ctx.agent_id, max_steps
        )

    if parsed_result and agent and agent.result_output_schema:
        parsed_result_schema = (
            f"{parsed_result.__class__.__module__}.{parsed_result.__class__.__name__}"
        )
    else:
        parsed_result_schema = None

    # Store conversation history if enabled
    if agent and ctx.conversation_id:
        # Store user message (input_data is already in a JSON-serializable format)
        await ctx.step.run(
            "add_conversation_history_user",
            add_conversation_history,
            ctx=ctx,
            conversation_id=ctx.conversation_id,
            agent_id=ctx.agent_id,
            role="user",
            content=input_data,
            agent_run_id=str(agent_run_id),
            conversation_history_limit=agent.conversation_history,
        )

        # Store assistant response
        if last_llm_result_content:
            await ctx.step.run(
                "add_conversation_history_assistant",
                add_conversation_history,
                ctx=ctx,
                conversation_id=ctx.conversation_id,
                agent_id=ctx.agent_id,
                role="assistant",
                content=last_llm_result_content,
                agent_run_id=str(agent_run_id),
                conversation_history_limit=agent.conversation_history,
            )

    # Return typed AgentResult for SDK callers
    typed_tool_results = [
        ToolResult.model_validate(tr) for tr in all_tool_results if isinstance(tr, dict)
    ]
    usage_obj = Usage(
        input_tokens=final_input_tokens,
        output_tokens=final_output_tokens,
        total_tokens=final_total_tokens,
    )

    return AgentResult(
        agent_run_id=str(agent_run_id),
        conversation_id=str(ctx.conversation_id),
        result=parsed_result,
        result_schema=parsed_result_schema,
        tool_results=typed_tool_results,
        # Number of steps actually executed; agent_step is one past that when
        # the loop exits on the step limit.
        total_steps=len(steps),
        usage=usage_obj,
    )
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
async def _parse_structured_output(output: str, output_schema: type[BaseModel] | None = None):
|
|
607
|
+
"""
|
|
608
|
+
Parse structured output if output_schema is provided.
|
|
609
|
+
|
|
610
|
+
Returns:
|
|
611
|
+
Tuple of (parsed_output, success_flag) where success_flag is True if parsing succeeded,
|
|
612
|
+
or False if parsing failed and output_schema is present.
|
|
613
|
+
"""
|
|
614
|
+
parsed_output = output
|
|
615
|
+
success = True
|
|
616
|
+
|
|
617
|
+
if output_schema and output:
|
|
618
|
+
if isinstance(output, str):
|
|
619
|
+
try:
|
|
620
|
+
# Parse JSON dict into Pydantic model instance
|
|
621
|
+
parsed_output = output_schema.model_validate_json(output)
|
|
622
|
+
except Exception as e:
|
|
623
|
+
# If parsing fails and output_schema is present, return False
|
|
624
|
+
logger.warning("Failed to parse structured output: %s", e)
|
|
625
|
+
success = False
|
|
626
|
+
elif isinstance(output, dict):
|
|
627
|
+
try:
|
|
628
|
+
# Parse JSON dict into Pydantic model instance
|
|
629
|
+
parsed_output = output_schema.model_validate(output)
|
|
630
|
+
except Exception as e:
|
|
631
|
+
# If parsing fails and output_schema is present, return False
|
|
632
|
+
logger.warning("Failed to parse structured output: %s", e)
|
|
633
|
+
success = False
|
|
634
|
+
|
|
635
|
+
return parsed_output, success
|
polos/core/__init__.py
ADDED
|
File without changes
|