agent-runtime-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +84 -0
- agent_runtime/builder.py +317 -0
- agent_runtime/config/__init__.py +29 -0
- agent_runtime/config/definitions.py +144 -0
- agent_runtime/config/policies.py +63 -0
- agent_runtime/config/storage.py +117 -0
- agent_runtime/context.py +10 -0
- agent_runtime/definitions.py +33 -0
- agent_runtime/discovery.py +16 -0
- agent_runtime/exceptions.py +74 -0
- agent_runtime/mcp/__init__.py +28 -0
- agent_runtime/mcp/discovery.py +146 -0
- agent_runtime/mcp/metadata.py +68 -0
- agent_runtime/mcp/utils.py +52 -0
- agent_runtime/model_registry.py +40 -0
- agent_runtime/plugins/__init__.py +4 -0
- agent_runtime/plugins/base.py +90 -0
- agent_runtime/plugins/default.py +19 -0
- agent_runtime/plugins/instructions.py +38 -0
- agent_runtime/plugins/loader.py +59 -0
- agent_runtime/policies.py +15 -0
- agent_runtime/runtime.py +110 -0
- agent_runtime/runtime_engine/__init__.py +22 -0
- agent_runtime/runtime_engine/a2a_bridge.py +190 -0
- agent_runtime/runtime_engine/a2a_task_io.py +165 -0
- agent_runtime/runtime_engine/agent_build.py +315 -0
- agent_runtime/runtime_engine/context.py +469 -0
- agent_runtime/runtime_engine/loading.py +170 -0
- agent_runtime/runtime_engine/observability.py +154 -0
- agent_runtime/runtime_engine/policy_registry.py +98 -0
- agent_runtime/runtime_engine/protocol_tools.py +94 -0
- agent_runtime/runtime_engine/task_flow.py +897 -0
- agent_runtime/runtime_engine/tool_flow.py +332 -0
- agent_runtime/sdk_agent.py +548 -0
- agent_runtime/server/__init__.py +15 -0
- agent_runtime/server/app_factory.py +37 -0
- agent_runtime/server/bootstrap.py +48 -0
- agent_runtime/server/endpoint_utils.py +37 -0
- agent_runtime/server/management.py +107 -0
- agent_runtime/smol/__init__.py +4 -0
- agent_runtime/smol/agents.py +431 -0
- agent_runtime/smol/llm_models.py +212 -0
- agent_runtime/smol/memory.py +111 -0
- agent_runtime/smol/models.py +69 -0
- agent_runtime/standalone.py +57 -0
- agent_runtime/storage.py +5 -0
- agent_runtime/tools.py +5 -0
- agent_runtime_sdk-0.1.0.dist-info/METADATA +125 -0
- agent_runtime_sdk-0.1.0.dist-info/RECORD +51 -0
- agent_runtime_sdk-0.1.0.dist-info/WHEEL +5 -0
- agent_runtime_sdk-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from fastapi import APIRouter, HTTPException
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from ..config.definitions import ToolPolicy
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..runtime import ManagedAgentRuntime
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InstructionsBody(BaseModel):
    """Request body for ``PUT /management/instructions``."""

    # Defined at module level on purpose: with ``from __future__ import annotations``
    # every endpoint annotation is a string, and FastAPI resolves those strings
    # against the endpoint function's module globals. A class defined inside
    # ``create_management_router`` would not be found there.
    instructions: str


def create_management_router(runtime: ManagedAgentRuntime) -> APIRouter:
    """Create a FastAPI router for runtime tool/policy management.

    All routes are mounted under ``/management``:

    * ``GET    /tools`` - every tool known to the runtime with its policy.
    * ``GET    /mcps`` - configured MCPs with tool/policy counts.
    * ``GET    /mcps/{mcp_name}/tools`` - tools of one MCP (404 if unknown).
    * ``GET    /mcps/{mcp_name}/tools/{tool_name}`` - a single tool (404 if absent).
    * ``PUT    /mcps/{mcp_name}/tools/{tool_name}/policy`` - set a tool policy.
    * ``DELETE /mcps/{mcp_name}/tools/{tool_name}/policy`` - remove a tool policy.
    * ``POST   /reload`` - re-run discovery without reloading plugins.
    * ``GET/PUT /instructions`` - read or replace the agent instructions.
    * ``GET    /config`` - the runtime definition with ``api_key_env`` stripped.

    Args:
        runtime: The managed runtime whose tools, policies and definition are
            exposed by the router.

    Returns:
        The configured ``APIRouter``.
    """

    router = APIRouter(prefix="/management", tags=["management"])

    def _mcp_exists(mcp_name: str) -> bool:
        """Return True when an MCP named ``mcp_name`` is in the runtime definition."""
        return any(mcp.name == mcp_name for mcp in runtime.definition.mcps)

    def _tool_payload(tool, policy: ToolPolicy | None = None):
        """Serialize a tool plus its policy (a default ``ToolPolicy`` when unset)."""
        return {
            "name": tool.name,
            "source_mcp": tool.source_mcp,
            "description": tool.description,
            "inputs": tool.inputs,
            "policy": (policy or ToolPolicy()).model_dump(),
        }

    @router.get("/tools")
    def list_tools():
        tools = runtime.list_tools()
        policies = runtime.list_tool_policies()
        return [_tool_payload(t, policies.get(t.name)) for t in tools]

    @router.get("/mcps")
    def list_mcps():
        return [
            {
                "name": mcp.name,
                "enabled": mcp.enabled,
                "url": mcp.url,
                "transport": mcp.transport,
                "tool_count": len(runtime.list_tools(mcp.name)),
                "policy_count": len(runtime.list_tool_policies(mcp.name)),
            }
            for mcp in runtime.definition.mcps
        ]

    @router.get("/mcps/{mcp_name}/tools")
    def list_mcp_tools(mcp_name: str):
        # Validate the MCP name before querying the runtime so that an unknown MCP
        # always produces a 404 rather than whatever ``list_tools`` does for it.
        if not _mcp_exists(mcp_name):
            raise HTTPException(404, f"mcp '{mcp_name}' not found")
        tools = runtime.list_tools(mcp_name)
        policies = runtime.list_tool_policies(mcp_name)
        return [_tool_payload(tool, policies.get(tool.name)) for tool in tools]

    @router.get("/mcps/{mcp_name}/tools/{tool_name}")
    def get_mcp_tool(mcp_name: str, tool_name: str):
        tool = runtime.get_tool(tool_name, mcp_name)
        if not tool:
            raise HTTPException(404, f"tool '{tool_name}' not found in mcp '{mcp_name}'")
        policy = runtime.get_tool_policy(tool_name, mcp_name)
        return _tool_payload(tool, policy)

    @router.put("/mcps/{mcp_name}/tools/{tool_name}/policy")
    def update_tool_policy(mcp_name: str, tool_name: str, body: ToolPolicy):
        try:
            runtime.set_tool_policy(tool_name, body, mcp_name=mcp_name)
        except ValueError as exc:
            # ``set_tool_policy`` signals an unknown tool/MCP with ValueError.
            raise HTTPException(404, str(exc)) from exc
        return {"ok": True, "mcp_name": mcp_name, "tool_name": tool_name, "policy": body.model_dump()}

    @router.delete("/mcps/{mcp_name}/tools/{tool_name}/policy")
    def delete_tool_policy(mcp_name: str, tool_name: str):
        removed = runtime.remove_tool_policy(tool_name, mcp_name=mcp_name)
        if not removed:
            raise HTTPException(404, f"no policy for '{tool_name}' in mcp '{mcp_name}'")
        return {"ok": True}

    @router.post("/reload")
    def reload_tools():
        runtime.reload(discover=True, skip_plugin_load=True)
        return {"tool_count": len(runtime.discovered_tools)}

    # ── Instructions ──

    @router.get("/instructions")
    def get_instructions():
        return {"instructions": runtime.get_instructions()}

    @router.put("/instructions")
    def set_instructions(body: InstructionsBody):
        runtime.set_instructions(body.instructions)
        return {"ok": True, "instructions": body.instructions}

    @router.get("/config")
    def get_config():
        dump = runtime.definition.model_dump()
        # Hide ``runtime.model.api_key_env``. Either section may be missing or
        # dumped as None, so only pop when both levels are actual dicts.
        runtime_section = dump.get("runtime")
        model_section = (
            runtime_section.get("model") if isinstance(runtime_section, dict) else None
        )
        if isinstance(model_section, dict):
            model_section.pop("api_key_env", None)
        return dump

    return router
|
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from collections.abc import Callable, Generator
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
from rich.live import Live
|
|
10
|
+
from rich.markdown import Markdown
|
|
11
|
+
from rich.text import Text
|
|
12
|
+
from smolagents import (
|
|
13
|
+
ActionOutput,
|
|
14
|
+
AgentError,
|
|
15
|
+
AgentGenerationError,
|
|
16
|
+
AgentParsingError,
|
|
17
|
+
AgentToolExecutionError,
|
|
18
|
+
ChatMessage,
|
|
19
|
+
ChatMessageStreamDelta,
|
|
20
|
+
FinalAnswerStep,
|
|
21
|
+
LogLevel,
|
|
22
|
+
PlanningStep,
|
|
23
|
+
Timing,
|
|
24
|
+
ToolCall,
|
|
25
|
+
ToolCallingAgent,
|
|
26
|
+
ToolOutput,
|
|
27
|
+
YELLOW_HEX,
|
|
28
|
+
agglomerate_stream_deltas,
|
|
29
|
+
handle_agent_output_types,
|
|
30
|
+
parse_json_if_needed,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from ..runtime_engine.context import (
|
|
34
|
+
ActiveToolCall,
|
|
35
|
+
current_task_id,
|
|
36
|
+
current_task_pool,
|
|
37
|
+
)
|
|
38
|
+
from .memory import ActionStep, ToolCallbackRegistry
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
import PIL.Image
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ToolCallingCheckAgent(ToolCallingAgent):
    """smolagents ``ToolCallingAgent`` extended with runtime task integration.

    On top of the base agent, this class:

    * runs the registered ``tool_callbacks`` on each model message before its
      tool calls are processed;
    * calls ``after_tool_hook`` with every ``ToolOutput`` produced in a step;
    * when a task is bound to the current context, refuses to start a tool once
      the task has requested a stop, and records the running tool as the task's
      active tool;
    * after each action step, publishes the step's tool calls and observations
      as A2A artifacts through the task's updater.
    """

    def __init__(
        self,
        tool_callbacks: dict[str, Callable | list[Callable]] | None = None,
        after_tool_hook: Callable[[ToolOutput], None] | None = None,
        **kwargs,
    ):
        """Create the agent.

        Args:
            tool_callbacks: Mapping of tool name to one callback or a list of
                callbacks, registered in a ``ToolCallbackRegistry``.
            after_tool_hook: Optional callable invoked with each ``ToolOutput``.
            **kwargs: Forwarded unchanged to ``ToolCallingAgent.__init__``.

        Raises:
            ValueError: If ``tool_callbacks`` is truthy but not a dict.
        """
        # Always initialise the registry and the hook: ``_step_stream`` reads both
        # unconditionally, so they must exist even when no callbacks are given.
        # Building the registry first also rejects a non-dict argument with a
        # clear ValueError instead of an AttributeError from ``.keys()``.
        self._setup_tool_callbacks(tool_callbacks)
        if tool_callbacks:
            self.interrupt_before_tool_name = list(tool_callbacks.keys())
        self.after_tool_hook = after_tool_hook
        super().__init__(**kwargs)

    def _setup_tool_callbacks(self, tool_callbacks):
        """Populate ``self.tool_callbacks`` from a ``{tool_name: callback(s)}`` mapping.

        A fresh, empty ``ToolCallbackRegistry`` is always created. A falsy
        argument leaves it empty.

        Raises:
            ValueError: If ``tool_callbacks`` is truthy but not a dict.
        """
        self.tool_callbacks = ToolCallbackRegistry()
        if not tool_callbacks:
            return
        if not isinstance(tool_callbacks, dict):
            raise ValueError("tool_callbacks must be a dict")
        for tool_name, callbacks in tool_callbacks.items():
            callback_list = callbacks if isinstance(callbacks, list) else [callbacks]
            for callback in callback_list:
                self.tool_callbacks.register(tool_name, callback)

    @staticmethod
    def _current_task_context():
        """Return ``(task_id, task_info)`` for the task bound to the current context.

        Returns ``(None, None)`` when either context variable is unset.
        ``task_info`` is whatever ``pool.get(task_id)`` returns (``None`` when the
        pool has no entry).
        """
        try:
            task_id = current_task_id.get()
            pool = current_task_pool.get()
        except LookupError:
            return None, None
        return task_id, pool.get(task_id)

    def execute_tool_call(self, tool_name: str, arguments: dict[str, str] | str) -> Any:
        """Execute a tool call, honouring task stop requests and tracking the active tool.

        Raises:
            The task's ``control.stop_error`` (or a RuntimeError) if the current
            task has already requested a stop; otherwise whatever the base
            implementation raises.
        """
        _, task_info = self._current_task_context()
        if task_info is not None and task_info.control.stop_requested:
            raise task_info.control.stop_error or RuntimeError(
                "task cancelled before tool execution"
            )

        available_tools = {**self.tools, **self.managed_agents}
        tool = available_tools.get(tool_name)
        active_tool: ActiveToolCall | None = None
        if task_info is not None and tool is not None:
            # Expose the running tool (and its optional cancel hook) on the task so
            # that a concurrent cancel request can reach it.
            active_tool = ActiveToolCall(
                tool_name=tool_name,
                mcp_name=getattr(tool, "_runtime_mcp_name", None),
                cancel_hook=getattr(tool, "_runtime_cancel_hook", None),
            )
            task_info.set_active_tool(active_tool)

        try:
            return super().execute_tool_call(tool_name, arguments)
        finally:
            # Only clear the slot if it still holds the tool we registered.
            if task_info is not None and task_info.active_tool is active_tool:
                task_info.clear_active_tool()

    def _run_stream(
        self,
        task: str,
        max_steps: int,
        images: list["PIL.Image.Image"] | None = None,
    ) -> Generator[
        ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta
    ]:
        """Run the agent loop, yielding planning/action steps and the final answer.

        This override of ``ToolCallingAgent._run_stream`` additionally calls
        ``_emit_step_artifacts`` after every action step.

        Raises:
            AgentError: When ``interrupt_switch`` is set.
            AgentGenerationError: Propagated from ``_step_stream``. Other
                ``AgentError``s are stored on the step and the loop continues.
        """
        self.step_number = 1
        returned_final_answer = False
        # Initialised so that ``max_steps < 1`` (loop never runs) does not end in
        # an UnboundLocalError below.
        action_step: ActionStep | None = None
        final_answer = None
        while not returned_final_answer and self.step_number <= max_steps:
            if self.interrupt_switch:
                raise AgentError("Agent interrupted.", self.logger)

            if self.planning_interval is not None and (
                self.step_number == 1
                or (self.step_number - 1) % self.planning_interval == 0
            ):
                planning_start_time = time.time()
                planning_step = None
                for element in self._generate_planning_step(
                    task,
                    is_first_step=len(self.memory.steps) == 1,
                    step=self.step_number,
                ):
                    yield element
                    planning_step = element
                # The planning generator's last element is the PlanningStep itself.
                assert isinstance(planning_step, PlanningStep)
                planning_end_time = time.time()
                planning_step.timing = Timing(
                    start_time=planning_start_time, end_time=planning_end_time
                )
                self._finalize_step(planning_step)
                self.memory.steps.append(planning_step)

            action_step_start_time = time.time()
            action_step = ActionStep(
                step_number=self.step_number,
                timing=Timing(start_time=action_step_start_time),
                observations_images=images,
            )
            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
            try:
                for output in self._step_stream(action_step):
                    yield output
                    if isinstance(output, ActionOutput) and output.is_final_answer:
                        final_answer = output.output
                        self.logger.log(
                            Text(
                                f"Final answer: {final_answer}",
                                style=f"bold {YELLOW_HEX}",
                            ),
                            level=LogLevel.INFO,
                        )
                        if self.final_answer_checks:
                            self._validate_final_answer(final_answer)
                        returned_final_answer = True
                        action_step.is_final_answer = True
            except AgentGenerationError:
                # Generation errors are implementation errors: abort the run.
                raise
            except AgentError as exc:
                # Other agent errors are fed back to the model on the next step.
                action_step.error = exc
            finally:
                self._finalize_step(action_step)
                self._emit_step_artifacts(action_step)
                self.memory.steps.append(action_step)
                yield action_step
                self.step_number += 1

        if not returned_final_answer and self.step_number == max_steps + 1:
            final_answer = self._handle_max_steps_reached(task)
            if action_step is not None:
                yield action_step

        yield FinalAnswerStep(handle_agent_output_types(final_answer))

    @staticmethod
    def _tool_call_fields(tool_call: Any) -> tuple[str, Any, Any]:
        """Extract ``(name, arguments, id)`` from a tool call of any supported shape.

        Handles objects with a nested ``.function`` (name/arguments) and flat
        objects exposing ``.name``/``.arguments``. When no name was found, it
        falls back to the ``.dict()`` representation. A missing name is
        returned as ``""``.
        """
        if hasattr(tool_call, "function"):
            func = tool_call.function
            tool_name = getattr(func, "name", "") or ""
            tool_args = getattr(func, "arguments", None)
        else:
            tool_name = getattr(tool_call, "name", "") or ""
            tool_args = getattr(tool_call, "arguments", None)
        tool_id = getattr(tool_call, "id", None)

        if not tool_name and hasattr(tool_call, "dict"):
            tool_call_dict = tool_call.dict()
            function_dict = tool_call_dict.get("function", {}) or {}
            tool_name = function_dict.get("name", "") or tool_name
            tool_args = function_dict.get("arguments", tool_args)
            tool_id = tool_call_dict.get("id", tool_id)
        return tool_name, tool_args, tool_id

    @staticmethod
    def _normalize_tool_args(tool_args: Any) -> Any:
        """Decode JSON-string arguments. Anything else, or invalid JSON, is returned as-is."""
        if not isinstance(tool_args, str):
            return tool_args
        try:
            return json.loads(tool_args)
        except (ValueError, RecursionError):
            # Not JSON (or pathologically nested): report the raw string instead.
            return tool_args

    def _emit_step_artifacts(self, action_step: ActionStep) -> None:
        """Publish an action step's tool calls and observations as A2A artifacts.

        This is best effort. It returns silently when no task is bound to the
        context, when the task has no updater or loop, or when the loop is
        closed. Each artifact send waits up to 5 seconds. A failed or
        timed-out send is cancelled and skipped.

        For every non-``final_answer`` tool call, it emits a ``tool_call``
        artifact, then an ``observation`` artifact when the step recorded
        non-empty observation text for that call id. If the step did not
        include a ``final_answer`` call, any remaining recorded observations
        are emitted as well.
        """
        _, task_info = self._current_task_context()
        if not task_info:
            return

        updater = task_info.updater
        loop = task_info.loop
        if not updater or not loop or loop.is_closed():
            return

        from a2a.types import Part, TaskState

        step_no = action_step.step_number
        working_state = TaskState.working.value

        def _emit_artifact(data: dict[str, Any], event: str) -> bool:
            """Send one data-part artifact named ``event``. Return False on failure."""
            parts = [Part.model_validate({"kind": "data", "data": data})]
            metadata = {"event": event, "state": working_state, "step": step_no}
            future = None
            try:
                # The updater lives on ``loop``; this method presumably runs on
                # another (agent worker) thread, hence the thread-safe submit.
                future = asyncio.run_coroutine_threadsafe(
                    updater.add_artifact(parts=parts, name=event, metadata=metadata),
                    loop,
                )
                future.result(timeout=5)
                return True
            except Exception:
                # Never let artifact streaming break the agent step. Cancel a
                # still-pending send so it cannot be delivered late and out of order.
                if future is not None:
                    future.cancel()
                return False

        tool_observations: dict[str, dict[str, Any]] = (
            getattr(action_step, "tool_observations", {}) or {}
        )
        has_final_answer = False
        emitted_observation_ids: set[str] = set()

        for tool_call in action_step.tool_calls or []:
            tool_name, tool_args, tool_id = self._tool_call_fields(tool_call)
            if tool_name == "final_answer":
                has_final_answer = True
                continue
            if not tool_name:
                continue

            _emit_artifact(
                {
                    "tool": {
                        "name": tool_name,
                        "args": self._normalize_tool_args(tool_args),
                        "id": tool_id,
                    }
                },
                "tool_call",
            )

            obs_payload = tool_observations.get(tool_id or "")
            if not obs_payload:
                continue
            obs_text = str(obs_payload.get("observation") or "").strip()
            if not obs_text:
                continue
            _emit_artifact(
                {"tool": {"name": tool_name, "id": tool_id}, "observation": obs_text},
                "observation",
            )
            if tool_id:
                emitted_observation_ids.add(tool_id)

        if has_final_answer:
            return

        # Observations whose ids did not match any tool call above.
        for obs_id, obs_payload in tool_observations.items():
            if obs_id in emitted_observation_ids:
                continue
            obs_text = str(obs_payload.get("observation") or "").strip()
            if not obs_text:
                continue
            _emit_artifact(
                {
                    "tool": {"name": obs_payload.get("tool_name"), "id": obs_id},
                    "observation": obs_text,
                },
                "observation",
            )

    def _step_stream(
        self,
        memory_step: ActionStep,
    ) -> Generator[ChatMessageStreamDelta | ToolCall | ToolOutput | ActionOutput]:
        """Run one action step.

        Steps, in order:

        1. Query the model, streaming deltas when ``stream_outputs`` is set and
           the model supports it.
        2. Parse the tool calls.
        3. Run the registered tool callbacks.
        4. Execute the tool calls.
        5. Finally yield an ``ActionOutput``.

        Per-call observations are stored on ``memory_step.tool_observations``,
        also when tool processing raises.

        Raises:
            AgentGenerationError: If the model call fails.
            AgentParsingError: If tool calls cannot be parsed from the output.
            AgentToolExecutionError: If more than one final answer is returned.
        """
        input_messages = self.write_memory_to_messages().copy()
        memory_step.model_input_messages = input_messages
        try:
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                output_stream = self.model.generate_stream(
                    input_messages,
                    stop_sequences=["Observation:", "Calling tools:"],
                    tools_to_call_from=self.tools_and_managed_agents,
                )
                chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
                # The streamed output is rendered live, so it is not logged again.
                with Live(
                    "", console=self.logger.console, vertical_overflow="visible"
                ) as live:
                    for event in output_stream:
                        chat_message_stream_deltas.append(event)
                        live.update(
                            Markdown(
                                agglomerate_stream_deltas(
                                    chat_message_stream_deltas
                                ).render_as_markdown()
                            )
                        )
                        yield event
                chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
            else:
                chat_message = self.model.generate(
                    input_messages,
                    stop_sequences=["Observation:", "Calling tools:"],
                    tools_to_call_from=self.tools_and_managed_agents,
                )
                # Prefer the text content, fall back to the raw response, and log
                # an empty string (not the literal "None") when both are missing.
                if chat_message.content is not None:
                    log_content = str(chat_message.content)
                elif chat_message.raw is not None:
                    log_content = str(chat_message.raw)
                else:
                    log_content = ""
                self.logger.log_markdown(
                    content=log_content,
                    title="Output message of the LLM:",
                    level=LogLevel.DEBUG,
                )
            memory_step.model_output_message = chat_message
            memory_step.model_output = chat_message.content
            memory_step.token_usage = chat_message.token_usage
        except Exception as exc:
            raise AgentGenerationError(
                f"Error while generating output:\n{exc}", self.logger
            ) from exc

        if not chat_message.tool_calls:
            try:
                chat_message = self.model.parse_tool_calls(chat_message)
            except Exception as exc:
                raise AgentParsingError(
                    f"Error while parsing tool call from model output: {exc}",
                    self.logger,
                ) from exc
        else:
            for tool_call in chat_message.tool_calls:
                tool_call.function.arguments = parse_json_if_needed(
                    tool_call.function.arguments
                )

        final_answer, got_final_answer = None, False
        tool_observations: dict[str, dict[str, Any]] = {}
        self.tool_callbacks.callback(chat_message, agent=self)

        try:
            for output in self.process_tool_calls(chat_message, memory_step):
                yield output
                if not isinstance(output, ToolOutput):
                    continue
                if output.id:
                    tool_observations[output.id] = {
                        "tool_name": getattr(output.tool_call, "name", None),
                        "observation": output.observation,
                    }
                if self.after_tool_hook is not None:
                    self.after_tool_hook(output)
                if output.is_final_answer:
                    if got_final_answer:
                        raise AgentToolExecutionError(
                            "You returned multiple final answers. Please return only one single final answer!",
                            self.logger,
                        )
                    final_answer = output.output
                    got_final_answer = True
                    # A final answer naming a state variable resolves to its value.
                    if (
                        isinstance(final_answer, str)
                        and final_answer in self.state.keys()
                    ):
                        final_answer = self.state[final_answer]
        finally:
            # Attach even on failure so that ``_emit_step_artifacts`` can still
            # report the observations of tools that completed before the error.
            memory_step.tool_observations = tool_observations
        yield ActionOutput(output=final_answer, is_final_answer=got_final_answer)
|