lite-agent 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lite-agent has been flagged as potentially problematic; see the registry's advisory page for details.
- lite_agent/__init__.py +8 -0
- lite_agent/agent.py +395 -20
- lite_agent/client.py +34 -0
- lite_agent/loggers.py +1 -1
- lite_agent/message_transfers.py +111 -0
- lite_agent/processors/__init__.py +1 -1
- lite_agent/processors/stream_chunk_processor.py +63 -42
- lite_agent/rich_helpers.py +503 -0
- lite_agent/runner.py +612 -31
- lite_agent/stream_handlers/__init__.py +5 -0
- lite_agent/stream_handlers/litellm.py +106 -0
- lite_agent/templates/handoffs_source_instructions.xml.j2 +10 -0
- lite_agent/templates/handoffs_target_instructions.xml.j2 +9 -0
- lite_agent/templates/wait_for_user_instructions.xml.j2 +6 -0
- lite_agent/types/__init__.py +75 -0
- lite_agent/types/chunks.py +89 -0
- lite_agent/types/messages.py +135 -0
- lite_agent/types/tool_calls.py +15 -0
- lite_agent-0.3.0.dist-info/METADATA +111 -0
- lite_agent-0.3.0.dist-info/RECORD +22 -0
- lite_agent/__main__.py +0 -110
- lite_agent/chunk_handler.py +0 -166
- lite_agent/types.py +0 -152
- lite_agent-0.1.0.dist-info/METADATA +0 -22
- lite_agent-0.1.0.dist-info/RECORD +0 -13
- {lite_agent-0.1.0.dist-info → lite_agent-0.3.0.dist-info}/WHEEL +0 -0
lite_agent/runner.py
CHANGED
|
@@ -1,51 +1,632 @@
|
|
|
1
|
-
|
|
2
|
-
from
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import AsyncGenerator, Sequence
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
3
6
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
+
from lite_agent.agent import Agent
|
|
8
|
+
from lite_agent.loggers import logger
|
|
9
|
+
from lite_agent.types import (
|
|
10
|
+
AgentAssistantMessage,
|
|
11
|
+
AgentChunk,
|
|
12
|
+
AgentChunkType,
|
|
13
|
+
AgentFunctionCallOutput,
|
|
14
|
+
AgentFunctionToolCallMessage,
|
|
15
|
+
AgentSystemMessage,
|
|
16
|
+
AgentUserMessage,
|
|
17
|
+
FlexibleRunnerMessage,
|
|
18
|
+
MessageDict,
|
|
19
|
+
RunnerMessage,
|
|
20
|
+
ToolCall,
|
|
21
|
+
ToolCallFunction,
|
|
22
|
+
UserInput,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from lite_agent.types import AssistantMessage
|
|
27
|
+
|
|
28
|
+
# Chunk types yielded to callers when they pass ``includes=None``.
DEFAULT_INCLUDES: tuple[AgentChunkType, ...] = (
    "completion_raw", "usage", "final_message",
    "tool_call", "tool_call_result",
    "content_delta", "tool_call_delta",
)
|
|
7
37
|
|
|
8
38
|
|
|
9
39
|
class Runner:
    """Drive an :class:`Agent` through a conversation, one completion step at a time.

    The runner owns the message history (``self.messages``, kept in the
    "responses" format: assistant messages, function calls and function call
    outputs as separate items), streams model output as :class:`AgentChunk`
    objects, executes the tool calls the model requests, and switches
    ``self.agent`` when the model calls the ``transfer_to_agent`` /
    ``transfer_to_parent`` handoff tools.
    """

    def __init__(self, agent: Agent) -> None:
        """Create a runner for ``agent`` with an empty message history."""
        self.agent = agent
        self.messages: list[RunnerMessage] = []

    def _normalize_includes(self, includes: Sequence[AgentChunkType] | None) -> Sequence[AgentChunkType]:
        """Normalize includes parameter to default if None."""
        return includes if includes is not None else DEFAULT_INCLUDES

    def _normalize_record_path(self, record_to: PathLike | str | None) -> Path | None:
        """Normalize record_to parameter to Path object if provided (empty/None -> None)."""
        return Path(record_to) if record_to else None

    def _append_function_call_output(self, call_id: str, output: str) -> None:
        """Append a ``function_call_output`` item answering the call ``call_id``."""
        self.messages.append(
            AgentFunctionCallOutput(
                type="function_call_output",
                call_id=call_id,
                output=output,
            ),
        )

    async def _handle_tool_calls(self, tool_calls: "Sequence[ToolCall] | None", includes: Sequence[AgentChunkType], context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]:  # noqa: ANN401, C901, PLR0912
        """Execute the model's tool calls and yield the resulting chunks.

        Handoff tools are special-cased. If the batch contains any
        ``transfer_to_agent`` calls (checked first) or else any
        ``transfer_to_parent`` calls, only the first such call is executed,
        later duplicates are answered with a placeholder output, and every other
        call in the batch is answered as skipped -- so no function call is left
        in the history without an output. Otherwise the calls are delegated to
        ``self.agent.handle_tool_calls`` and every ``tool_call_result`` is
        recorded in ``self.messages``.

        Args:
            tool_calls: Calls requested by the model; None or empty is a no-op.
            includes: Chunk types that should be yielded to the caller.
            context: Opaque value forwarded to ``agent.handle_tool_calls``.
        """
        if not tool_calls:
            return

        for transfer_name in ("transfer_to_agent", "transfer_to_parent"):
            transfer_calls = [tc for tc in tool_calls if tc.function.name == transfer_name]
            if not transfer_calls:
                continue

            first_call, *duplicate_calls = transfer_calls
            if transfer_name == "transfer_to_agent":
                await self._handle_agent_transfer(first_call, includes)
            else:
                await self._handle_parent_transfer(first_call, includes)

            # Answer additional transfer calls without executing them
            for tool_call in duplicate_calls:
                self._append_function_call_output(tool_call.id, "Transfer already executed by previous call")

            # The remaining calls are not executed after a handoff, but each one still
            # needs an output or the next completion request sees unanswered calls.
            for tool_call in tool_calls:
                if tool_call.function.name != transfer_name:
                    self._append_function_call_output(
                        tool_call.id,
                        "Tool call skipped because the conversation was transferred",
                    )
            return  # Stop processing other tool calls after transfer

        async for tool_call_chunk in self.agent.handle_tool_calls(tool_calls, context=context):
            if tool_call_chunk.type == "tool_call" and tool_call_chunk.type in includes:
                yield tool_call_chunk
            if tool_call_chunk.type == "tool_call_result":
                # Record the output before yielding so the history stays consistent
                # even if the consumer stops iterating after this chunk.
                self._append_function_call_output(tool_call_chunk.tool_call_id, tool_call_chunk.content)
                if tool_call_chunk.type in includes:
                    yield tool_call_chunk

    async def _collect_all_chunks(self, stream: AsyncGenerator[AgentChunk, None]) -> list[AgentChunk]:
        """Collect all chunks from an async generator into a list."""
        return [chunk async for chunk in stream]

    def run(
        self,
        user_input: UserInput,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        context: "Any | None" = None,  # noqa: ANN401
        record_to: PathLike | str | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Add ``user_input`` to the history and return the streaming run.

        The input is appended immediately (when ``run`` is called, not when the
        returned generator is first iterated). ``user_input`` may be a plain
        string (wrapped as a user message), a list/tuple of messages, or a single
        message object/dict; messages go through :meth:`append_message`.

        Args:
            user_input: The new user input.
            max_steps: Maximum number of completion steps.
            includes: Chunk types to yield; ``DEFAULT_INCLUDES`` when None.
            context: Opaque value forwarded to tool execution.
            record_to: Optional path passed to the agent as ``record_to_file``.

        Returns:
            An async generator of the selected chunks.
        """
        includes = self._normalize_includes(includes)
        if isinstance(user_input, str):
            self.messages.append(AgentUserMessage(role="user", content=user_input))
        elif isinstance(user_input, (list, tuple)):
            # Handle sequence of messages
            for message in user_input:
                self.append_message(message)
        else:
            # Handle single message (BaseModel, TypedDict, or dict)
            # Type assertion needed due to the complex union type
            self.append_message(user_input)  # type: ignore[arg-type]
        return self._run(max_steps, includes, self._normalize_record_path(record_to), context=context)

    async def _run(self, max_steps: int, includes: Sequence[AgentChunkType], record_to: Path | None = None, context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]:  # noqa: ANN401, C901
        """Run completion steps until the agent finishes or ``max_steps`` is reached.

        Each step requests a completion for the current history, yields the
        chunks whose type is in ``includes``, records the final message, and
        executes any tool calls the model requested. The run returns early,
        without executing anything, when a requested tool requires confirmation
        (``agent.list_require_confirm_tools``); ``run_continue_stream`` can
        resume it later.

        The finish test uses the *currently active* agent's
        ``completion_condition`` (default ``"stop"``): ``"stop"`` finishes on a
        ``"stop"`` finish reason, ``"call"`` finishes once the model's latest
        turn called ``wait_for_user``.
        """
        logger.debug(f"Running agent with messages: {self.messages}")
        steps = 0
        finish_reason = None

        def is_finish() -> bool:
            # Read the condition on every check: a handoff may have replaced self.agent.
            if getattr(self.agent, "completion_condition", "stop") == "call":
                return any(fc.name == "wait_for_user" for fc in self._latest_turn_function_calls())
            return finish_reason == "stop"

        while not is_finish() and steps < max_steps:
            resp = await self.agent.completion(self.messages, record_to_file=record_to)
            async for chunk in resp:
                if chunk.type == "final_message":
                    # Record the message before yielding so an early-stopping consumer
                    # still leaves a complete history behind.
                    await self._convert_final_message_to_responses_format(chunk.message)
                    finish_reason = chunk.finish_reason

                if chunk.type in includes:
                    yield chunk

                if chunk.type != "final_message" or finish_reason != "tool_calls":
                    continue

                pending_function_calls = self._find_pending_function_calls()
                if not pending_function_calls:
                    continue
                # Convert to ToolCall format for the tool-call handler
                tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
                require_confirm_tools = await self.agent.list_require_confirm_tools(tool_calls)
                if require_confirm_tools:
                    return
                async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                    yield tool_chunk
            steps += 1

    async def run_continue_until_complete(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> list[AgentChunk]:
        """Continue the run (see :meth:`run_continue_stream`) and collect every chunk."""
        resp = self.run_continue_stream(max_steps, includes, record_to=record_to, context=context)
        return await self._collect_all_chunks(resp)

    def run_continue_stream(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Return a stream that continues a previously stopped run.

        See :meth:`_run_continue_stream` for the resume rules.
        """
        return self._run_continue_stream(max_steps, includes, record_to=record_to, context=context)

    async def _run_continue_stream(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Continue running the agent, yielding each selected chunk.

        If the history ends with function calls that have no output yet (e.g.
        a run stopped because a tool required confirmation), those calls are
        executed first and the run then proceeds. Otherwise the last message
        must be an assistant message.

        Raises:
            ValueError: If there are no pending calls and the history is empty
                or does not end with an assistant message.
        """
        includes = self._normalize_includes(includes)
        record_path = self._normalize_record_path(record_to)

        pending_function_calls = self._find_pending_function_calls()
        if pending_function_calls:
            # Convert to ToolCall format for the tool-call handler
            tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
            async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                yield tool_chunk
            # _run already filters by includes, so chunks are forwarded as-is
            async for chunk in self._run(max_steps, includes, record_path, context=context):
                yield chunk
            return

        last_message = self.messages[-1] if self.messages else None
        if last_message is None or not (isinstance(last_message, AgentAssistantMessage) or getattr(last_message, "role", None) == "assistant"):
            msg = "Cannot continue running without a valid last message from the assistant."
            raise ValueError(msg)

        async for chunk in self._run(max_steps=max_steps, includes=includes, record_to=record_path, context=context):
            yield chunk

    async def run_until_complete(
        self,
        user_input: UserInput,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> list[AgentChunk]:
        """Run the agent on ``user_input`` until it completes and return every chunk."""
        resp = self.run(user_input, max_steps, includes, context=context, record_to=record_to)
        return await self._collect_all_chunks(resp)

    async def _convert_final_message_to_responses_format(self, message: "AssistantMessage") -> None:
        """Append a completions-format final message to the history in responses format.

        A message without tool calls becomes one assistant message. A message
        with tool calls becomes an assistant message (only if it has content)
        followed by one :class:`AgentFunctionToolCallMessage` per tool call.
        """
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            # Regular assistant message without tool calls
            self.messages.append(AgentAssistantMessage(role="assistant", content=message.content))
            return

        if message.content:
            # Keep the text part as its own assistant message, without tool_calls
            self.messages.append(AgentAssistantMessage(role="assistant", content=message.content))

        for tool_call in tool_calls:
            self.messages.append(
                AgentFunctionToolCallMessage(
                    type="function_call",
                    function_call_id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments or "",
                    content="",
                ),
            )

    def _latest_turn_function_calls(self) -> list[AgentFunctionToolCallMessage]:
        """Return the function calls in the trailing run of function-call items.

        Walks backwards over function calls and their outputs and stops at the
        first message of any other kind (user, assistant, system, ...). The
        result is every call made since that message, answered or not, in
        chronological order.
        """
        calls: list[AgentFunctionToolCallMessage] = []
        for msg in reversed(self.messages):
            if isinstance(msg, AgentFunctionToolCallMessage):
                calls.append(msg)
            elif not isinstance(msg, AgentFunctionCallOutput):
                break
        calls.reverse()
        return calls

    def _find_pending_function_calls(self) -> list[AgentFunctionToolCallMessage]:
        """Find function call messages that don't have corresponding outputs yet.

        Only the history after the most recent :class:`AgentAssistantMessage`
        is considered. Outputs are matched to calls by id *after* the scan:
        outputs follow their calls, so a backwards walk meets them first.
        Pending calls are returned in the order they were made.
        """
        calls: list[AgentFunctionToolCallMessage] = []
        answered_ids: set = set()
        for msg in reversed(self.messages):
            if isinstance(msg, AgentFunctionToolCallMessage):
                calls.append(msg)
            elif isinstance(msg, AgentFunctionCallOutput):
                answered_ids.add(msg.call_id)
            elif isinstance(msg, AgentAssistantMessage):
                # Stop when we hit the assistant message that initiated these calls
                break
        calls.reverse()
        return [fc for fc in calls if fc.function_call_id not in answered_ids]

    def _convert_function_calls_to_tool_calls(self, function_calls: list[AgentFunctionToolCallMessage]) -> list[ToolCall]:
        """Convert function call messages to ToolCall objects, indexed by position."""
        return [
            ToolCall(
                id=fc.function_call_id,
                type="function",
                function=ToolCallFunction(
                    name=fc.name,
                    arguments=fc.arguments,
                ),
                index=index,
            )
            for index, fc in enumerate(function_calls)
        ]

    def set_chat_history(self, messages: Sequence[FlexibleRunnerMessage], root_agent: Agent | None = None) -> None:
        """Set the entire chat history and track the current agent based on function calls.

        This method analyzes the message history to determine which agent should be active
        based on transfer_to_agent and transfer_to_parent function calls.

        Args:
            messages: List of messages to set as the chat history
            root_agent: The root agent to use if no transfers are found. If None, uses self.agent
        """
        self.messages.clear()

        current_agent = root_agent if root_agent is not None else self.agent
        for message in messages:
            self.append_message(message)
            current_agent = self._track_agent_transfer_in_message(message, current_agent)

        # The last transfer seen in the history decides which agent is active
        self.agent = current_agent
        logger.info(f"Chat history set with {len(self.messages)} messages. Current agent: {self.agent.name}")

    def get_messages_dict(self) -> list[dict[str, Any]]:
        """Return the history as a list of JSON-serializable dicts (``model_dump(mode="json")``)."""
        return [msg.model_dump(mode="json") for msg in self.messages]

    def _track_agent_transfer_in_message(self, message: FlexibleRunnerMessage, current_agent: Agent) -> Agent:
        """Track agent transfers in a single message.

        Args:
            message: The message to analyze for transfers
            current_agent: The currently active agent

        Returns:
            The agent that should be active after processing this message
        """
        if isinstance(message, dict):
            return self._track_transfer_from_dict_message(message, current_agent)

        if isinstance(message, AgentFunctionToolCallMessage):
            return self._track_transfer_from_function_call_message(message, current_agent)

        return current_agent

    def _track_transfer_from_dict_message(self, message: dict[str, Any] | MessageDict, current_agent: Agent) -> Agent:
        """Track transfers from dictionary-format messages."""
        if message.get("type") != "function_call":
            return current_agent

        function_name = message.get("name", "")
        if function_name == "transfer_to_agent":
            return self._handle_transfer_to_agent_tracking(message.get("arguments", ""), current_agent)

        if function_name == "transfer_to_parent":
            return self._handle_transfer_to_parent_tracking(current_agent)

        return current_agent

    def _track_transfer_from_function_call_message(self, message: AgentFunctionToolCallMessage, current_agent: Agent) -> Agent:
        """Track transfers from AgentFunctionToolCallMessage objects."""
        if message.name == "transfer_to_agent":
            return self._handle_transfer_to_agent_tracking(message.arguments, current_agent)

        if message.name == "transfer_to_parent":
            return self._handle_transfer_to_parent_tracking(current_agent)

        return current_agent

    def _handle_transfer_to_agent_tracking(self, arguments: str | dict, current_agent: Agent) -> Agent:
        """Resolve a replayed transfer_to_agent call to the agent it selects.

        Returns ``current_agent`` unchanged (and logs a warning) when the
        arguments cannot be parsed, are not a JSON object, or name an agent
        that ``_find_agent_by_name`` cannot find.
        """
        try:
            args_dict = json.loads(arguments) if isinstance(arguments, str) else arguments

            target_agent_name = args_dict.get("name")
            if target_agent_name:
                target_agent = self._find_agent_by_name(current_agent, target_agent_name)
                if target_agent:
                    logger.debug(f"History tracking: Transferring from {current_agent.name} to {target_agent_name}")
                    return target_agent

                logger.warning(f"Target agent '{target_agent_name}' not found in handoffs during history setup")
        # AttributeError: valid JSON that is not an object (null, list, string) has no .get
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            logger.warning(f"Failed to parse transfer_to_agent arguments during history setup: {e}")

        return current_agent

    def _handle_transfer_to_parent_tracking(self, current_agent: Agent) -> Agent:
        """Handle transfer_to_parent function call tracking."""
        if current_agent.parent:
            logger.debug(f"History tracking: Transferring from {current_agent.name} back to parent {current_agent.parent.name}")
            return current_agent.parent

        logger.warning(f"Agent {current_agent.name} has no parent to transfer back to during history setup")
        return current_agent

    def _find_agent_by_name(self, root_agent: Agent, target_name: str) -> Agent | None:
        """Find an agent by name in the handoffs of root_agent or any of its ancestors.

        ``root_agent``'s own handoffs are checked first, then each parent's
        handoffs walking up the chain, so transfers to siblings resolve too.

        Args:
            root_agent: The root agent to start searching from
            target_name: The name of the agent to find

        Returns:
            The agent if found, None otherwise
        """
        agent: Agent | None = root_agent
        while agent is not None:
            for candidate in agent.handoffs or []:
                if candidate.name == target_name:
                    return candidate
            agent = agent.parent
        return None

    def append_message(self, message: FlexibleRunnerMessage) -> None:
        """Append ``message`` to the history, converting dicts to message models.

        Dicts are dispatched on ``type`` (``function_call`` /
        ``function_call_output``) first, then on ``role``. A legacy assistant
        dict carrying ``tool_calls`` is split into an assistant message plus one
        :class:`AgentFunctionToolCallMessage` per tool call. Values that are
        neither a ``RunnerMessage`` nor a dict are ignored.

        Raises:
            ValueError: If a dict has an unsupported role, or neither ``role`` nor ``type``.
        """
        if isinstance(message, RunnerMessage):
            self.messages.append(message)
        elif isinstance(message, dict):
            message_type = message.get("type")
            role = message.get("role")

            if message_type == "function_call":
                self.messages.append(AgentFunctionToolCallMessage.model_validate(message))
            elif message_type == "function_call_output":
                self.messages.append(AgentFunctionCallOutput.model_validate(message))
            elif role == "assistant" and "tool_calls" in message:
                # Legacy assistant message with tool_calls - convert to responses format
                self.messages.append(
                    AgentAssistantMessage(
                        role="assistant",
                        content=message.get("content", ""),
                    ),
                )
                # tool_calls may be present but None; treat that as "no calls"
                for tool_call in message.get("tool_calls") or []:
                    function = tool_call["function"]
                    self.messages.append(
                        AgentFunctionToolCallMessage(
                            type="function_call",
                            function_call_id=tool_call["id"],
                            name=function["name"],
                            arguments=function["arguments"] or "",
                            content="",
                        ),
                    )
            elif role:
                role_to_message_class = {
                    "user": AgentUserMessage,
                    "assistant": AgentAssistantMessage,
                    "system": AgentSystemMessage,
                }
                message_class = role_to_message_class.get(role)
                if message_class is None:
                    msg = f"Unsupported message role: {role}"
                    raise ValueError(msg)
                self.messages.append(message_class.model_validate(message))
            else:
                msg = "Message must have a 'role' or 'type' field."
                raise ValueError(msg)

    async def _handle_agent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
        """Handle agent transfer when transfer_to_agent tool is called.

        Every path appends exactly one function call output for ``tool_call``.
        On success the transfer tool is executed through ``self.agent.fc`` and
        ``self.agent`` is switched to the named handoff target.

        Args:
            tool_call: The transfer_to_agent tool call
            _includes: The types of chunks to include in output (unused)
        """
        try:
            arguments = json.loads(tool_call.function.arguments or "{}")
            target_agent_name = arguments.get("name")
        # AttributeError: valid JSON that is not an object has no .get
        except (json.JSONDecodeError, KeyError, AttributeError):
            logger.error("Failed to parse transfer_to_agent arguments: %s", tool_call.function.arguments)
            self._append_function_call_output(tool_call.id, "Failed to parse transfer arguments")
            return

        if not target_agent_name:
            logger.error("No target agent name provided in transfer_to_agent call")
            self._append_function_call_output(tool_call.id, "No target agent name provided")
            return

        if not self.agent.handoffs:
            logger.error("Current agent has no handoffs configured")
            self._append_function_call_output(tool_call.id, "Current agent has no handoffs configured")
            return

        target_agent = next((agent for agent in self.agent.handoffs if agent.name == target_agent_name), None)
        if not target_agent:
            logger.error("Target agent '%s' not found in handoffs", target_agent_name)
            self._append_function_call_output(tool_call.id, f"Target agent '{target_agent_name}' not found in handoffs")
            return

        # Execute the transfer tool call to get the result
        try:
            result = await self.agent.fc.call_function_async(
                tool_call.function.name,
                tool_call.function.arguments or "",
            )
        except Exception as e:
            logger.exception("Failed to execute transfer_to_agent tool call")
            self._append_function_call_output(tool_call.id, f"Transfer failed: {e!s}")
        else:
            self._append_function_call_output(tool_call.id, str(result))
            logger.info("Transferring conversation from %s to %s", self.agent.name, target_agent_name)
            self.agent = target_agent

    async def _handle_parent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
        """Handle parent transfer when transfer_to_parent tool is called.

        Every path appends exactly one function call output for ``tool_call``.
        On success the transfer tool is executed through ``self.agent.fc`` and
        ``self.agent`` is switched to its parent.

        Args:
            tool_call: The transfer_to_parent tool call
            _includes: The types of chunks to include in output (unused)
        """
        if not self.agent.parent:
            logger.error("Current agent has no parent to transfer back to.")
            self._append_function_call_output(tool_call.id, "Current agent has no parent to transfer back to")
            return

        # Execute the transfer tool call to get the result
        try:
            result = await self.agent.fc.call_function_async(
                tool_call.function.name,
                tool_call.function.arguments or "",
            )
        except Exception as e:
            logger.exception("Failed to execute transfer_to_parent tool call")
            self._append_function_call_output(tool_call.id, f"Transfer to parent failed: {e!s}")
        else:
            self._append_function_call_output(tool_call.id, str(result))
            logger.info("Transferring conversation from %s back to parent %s", self.agent.name, self.agent.parent.name)
            self.agent = self.agent.parent