lite-agent 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lite-agent might be problematic. Click here for more details.
- lite_agent/__init__.py +7 -0
- lite_agent/agent.py +310 -16
- lite_agent/loggers.py +1 -1
- lite_agent/message_transfers.py +111 -0
- lite_agent/processors/__init__.py +1 -1
- lite_agent/processors/stream_chunk_processor.py +63 -42
- lite_agent/runner.py +465 -29
- lite_agent/stream_handlers/__init__.py +5 -0
- lite_agent/stream_handlers/litellm.py +106 -0
- lite_agent/types/__init__.py +55 -0
- lite_agent/types/chunks.py +89 -0
- lite_agent/types/messages.py +68 -0
- lite_agent/types/tool_calls.py +15 -0
- lite_agent-0.2.0.dist-info/METADATA +111 -0
- lite_agent-0.2.0.dist-info/RECORD +17 -0
- lite_agent/__main__.py +0 -110
- lite_agent/chunk_handler.py +0 -166
- lite_agent/types.py +0 -152
- lite_agent-0.1.0.dist-info/METADATA +0 -22
- lite_agent-0.1.0.dist-info/RECORD +0 -13
- {lite_agent-0.1.0.dist-info → lite_agent-0.2.0.dist-info}/WHEEL +0 -0
lite_agent/runner.py
CHANGED
|
@@ -1,51 +1,487 @@
|
|
|
1
|
-
|
|
2
|
-
from
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import AsyncGenerator, Sequence
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
3
6
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
+
from lite_agent.agent import Agent
|
|
8
|
+
from lite_agent.loggers import logger
|
|
9
|
+
from lite_agent.types import (
|
|
10
|
+
AgentAssistantMessage,
|
|
11
|
+
AgentChunk,
|
|
12
|
+
AgentChunkType,
|
|
13
|
+
AgentFunctionCallOutput,
|
|
14
|
+
AgentFunctionToolCallMessage,
|
|
15
|
+
AgentSystemMessage,
|
|
16
|
+
AgentUserMessage,
|
|
17
|
+
RunnerMessage,
|
|
18
|
+
RunnerMessages,
|
|
19
|
+
ToolCall,
|
|
20
|
+
ToolCallFunction,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from lite_agent.types import AssistantMessage
|
|
25
|
+
|
|
26
|
+
# Chunk types used when a caller passes ``includes=None``:
# ``Runner._normalize_includes`` falls back to this tuple.
DEFAULT_INCLUDES: tuple[AgentChunkType, ...] = (
    "completion_raw",
    "usage",
    "final_message",
    "tool_call",
    "tool_call_result",
    "content_delta",
    "tool_call_delta",
)
|
|
7
35
|
|
|
8
36
|
|
|
9
37
|
class Runner:
|
|
10
38
|
def __init__(self, agent: Agent) -> None:
    """Create a runner for ``agent`` with an empty message history.

    Args:
        agent: The agent that will serve completions and tool calls.
    """
    # Conversation history; the run methods append to it as they progress.
    self.messages: list[RunnerMessage] = []
    # The agent currently in charge of the conversation.
    self.agent = agent
|
|
41
|
+
|
|
42
|
+
def _normalize_includes(self, includes: Sequence[AgentChunkType] | None) -> Sequence[AgentChunkType]:
    """Return ``includes`` unchanged, or ``DEFAULT_INCLUDES`` when it is ``None``.

    An explicitly empty sequence is kept as-is (it is not ``None``).
    """
    if includes is None:
        return DEFAULT_INCLUDES
    return includes
|
|
45
|
+
|
|
46
|
+
def _normalize_record_path(self, record_to: PathLike | str | None) -> Path | None:
|
|
47
|
+
"""Normalize record_to parameter to Path object if provided."""
|
|
48
|
+
return Path(record_to) if record_to else None
|
|
49
|
+
|
|
50
|
+
async def _handle_tool_calls(self, tool_calls: "Sequence[ToolCall] | None", includes: Sequence[AgentChunkType], context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]:  # noqa: ANN401, C901, PLR0912
    """Execute ``tool_calls``, record their outputs and yield the selected chunks.

    Transfer requests are handled before anything else: if the batch contains
    ``transfer_to_agent`` calls (or, failing that, ``transfer_to_parent``
    calls), only the first such call is executed, the remaining transfer calls
    of the same kind get a canned output, and no other tool call of the batch
    is executed.

    Args:
        tool_calls: Tool calls to process; ``None`` or empty is a no-op.
        includes: Chunk types the caller wants to receive.
        context: Opaque value forwarded to ``self.agent.handle_tool_calls``.

    Yields:
        ``tool_call`` / ``tool_call_result`` chunks whose type is in ``includes``.
    """
    if not tool_calls:
        return

    # Hand-off to another agent wins over every other call in the batch.
    handoff_calls = [call for call in tool_calls if call.function.name == "transfer_to_agent"]
    if handoff_calls:
        first_handoff, *repeated_handoffs = handoff_calls
        await self._handle_agent_transfer(first_handoff, includes)
        for repeated in repeated_handoffs:
            # Answer the extra requests without executing them.
            self.messages.append(
                AgentFunctionCallOutput(
                    type="function_call_output",
                    call_id=repeated.id,
                    output="Transfer already executed by previous call",
                ),
            )
        return  # nothing else runs once a transfer was requested

    # Same policy for returning control to the parent agent.
    parent_calls = [call for call in tool_calls if call.function.name == "transfer_to_parent"]
    if parent_calls:
        first_return, *repeated_returns = parent_calls
        await self._handle_parent_transfer(first_return, includes)
        for repeated in repeated_returns:
            self.messages.append(
                AgentFunctionCallOutput(
                    type="function_call_output",
                    call_id=repeated.id,
                    output="Transfer already executed by previous call",
                ),
            )
        return

    # Ordinary tools: delegate execution to the agent.
    async for produced in self.agent.handle_tool_calls(tool_calls, context=context):
        kind = produced.type
        if kind == "tool_call":
            if kind in includes:
                yield produced
        elif kind == "tool_call_result":
            if kind in includes:
                yield produced
            # Results are always recorded, whether or not they were yielded.
            self.messages.append(
                AgentFunctionCallOutput(
                    type="function_call_output",
                    call_id=produced.tool_call_id,
                    output=produced.content,
                ),
            )
|
|
13
104
|
|
|
14
|
-
def
|
|
105
|
+
async def _collect_all_chunks(self, stream: AsyncGenerator[AgentChunk, None]) -> list[AgentChunk]:
    """Exhaust ``stream`` and return every chunk it produced, in order."""
    collected: list[AgentChunk] = []
    async for item in stream:
        collected.append(item)
    return collected
|
|
108
|
+
|
|
109
|
+
def run(
    self,
    user_input: RunnerMessages | str,
    max_steps: int = 20,
    includes: Sequence[AgentChunkType] | None = None,
    context: "Any | None" = None,  # noqa: ANN401
    record_to: PathLike | str | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Add ``user_input`` to the history and return the chunk stream of a new run.

    Args:
        user_input: A plain string (stored as a user message) or an iterable
            of messages, each passed to ``append_message``.
        max_steps: Upper bound on completion rounds.
        includes: Chunk types to yield; ``None`` selects ``DEFAULT_INCLUDES``.
        context: Opaque value forwarded to tool-call handling.
        record_to: Optional path handed to the agent's completion call.

    Returns:
        An async generator; nothing is requested from the agent until it is iterated.
    """
    if not isinstance(user_input, str):
        for item in user_input:
            self.append_message(item)
    else:
        self.messages.append(AgentUserMessage(role="user", content=user_input))

    record_path = self._normalize_record_path(record_to)
    chunk_types = self._normalize_includes(includes)
    return self._run(max_steps, chunk_types, record_path, context=context)
|
|
29
125
|
|
|
30
|
-
async def
|
|
126
|
+
async def _run(self, max_steps: int, includes: Sequence[AgentChunkType], record_to: Path | None = None, context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]:  # noqa: ANN401
    """Drive the completion / tool-call loop over ``self.messages``.

    Each step asks ``self.agent`` for a completion and streams its chunks.
    The loop stops once a final message reports ``finish_reason == "stop"``
    or after ``max_steps`` steps.

    Args:
        max_steps: Maximum number of completion rounds.
        includes: Chunk types to yield (already normalized, never ``None``).
        record_to: Passed to the agent as ``record_to_file``.
        context: Opaque value forwarded to ``_handle_tool_calls``.

    Yields:
        Completion chunks whose type is in ``includes``, followed by the
        chunks produced while executing tool calls.
    """
    logger.debug(f"Running agent with messages: {self.messages}")
    steps = 0
    finish_reason = None

    while finish_reason != "stop" and steps < max_steps:
        resp = await self.agent.completion(self.messages, record_to_file=record_to)
        async for chunk in resp:
            if chunk.type in includes:
                yield chunk

            # The final message is processed even when it is not in ``includes``.
            if chunk.type == "final_message":
                message = chunk.message
                # Convert to responses format and add to messages
                await self._convert_final_message_to_responses_format(message)
                finish_reason = chunk.finish_reason
                if finish_reason == "tool_calls":
                    # Find pending function calls in responses format
                    pending_function_calls = self._find_pending_function_calls()
                    if pending_function_calls:
                        # Convert to ToolCall format for existing handler
                        tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
                        require_confirm_tools = await self.agent.list_require_confirm_tools(tool_calls)
                        if require_confirm_tools:
                            # End the run without executing anything; the calls
                            # stay in ``self.messages`` without outputs.
                            return
                        async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                            yield tool_chunk
        steps += 1
|
|
155
|
+
|
|
156
|
+
async def run_continue_until_complete(
    self,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    context: "Any | None" = None,  # noqa: ANN401
) -> list[AgentChunk]:
    """Resume a paused run and collect every chunk it produces.

    Args:
        max_steps: Upper bound on completion rounds after resuming.
        includes: Chunk types to collect; ``None`` selects the defaults.
        record_to: Optional path handed to the agent's completion call.
        context: Opaque value forwarded to tool-call handling, mirroring
            ``run_continue_stream``. Added last so existing positional
            callers are unaffected.

    Returns:
        All chunks yielded by ``run_continue_stream``, in order.

    Raises:
        ValueError: Propagated from the stream when there is nothing to continue.
    """
    resp = self.run_continue_stream(max_steps, includes, record_to=record_to, context=context)
    return await self._collect_all_chunks(resp)
|
|
164
|
+
|
|
165
|
+
def run_continue_stream(
    self,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    context: "Any | None" = None,  # noqa: ANN401
) -> AsyncGenerator[AgentChunk, None]:
    """Return the chunk stream that resumes a paused run.

    This is a synchronous wrapper: it only builds the async generator, so
    validation errors surface when the stream is first iterated.
    """
    stream = self._run_continue_stream(max_steps, includes, record_to=record_to, context=context)
    return stream
|
|
173
|
+
|
|
174
|
+
async def _run_continue_stream(
    self,
    max_steps: int = 20,
    includes: Sequence[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    context: "Any | None" = None,  # noqa: ANN401
) -> AsyncGenerator[AgentChunk, None]:
    """Execute the pending function calls, then continue the normal run loop.

    Args:
        max_steps: Upper bound on completion rounds after the pending calls.
        includes: Chunk types to yield; ``None`` selects ``DEFAULT_INCLUDES``.
        record_to: Optional path handed to the agent's completion call.
        context: Opaque value forwarded to tool-call handling, both for the
            pending calls and for any calls made later in the resumed run.

    Yields:
        Chunks from the pending tool calls followed by chunks of the resumed run.

    Raises:
        ValueError: If the history is empty or does not end with an assistant
            message, or if there are no pending function calls.
    """
    includes = self._normalize_includes(includes)

    pending_function_calls = self._find_pending_function_calls()
    if not pending_function_calls:
        # Nothing to resume: say whether the history itself is the problem.
        if not self.messages:
            msg = "Cannot continue running without a valid last message from the assistant."
            raise ValueError(msg)

        last_message = self.messages[-1]
        ends_with_assistant = isinstance(last_message, AgentAssistantMessage) or getattr(last_message, "role", None) == "assistant"
        if not ends_with_assistant:
            msg = "Cannot continue running without a valid last message from the assistant."
            raise ValueError(msg)

        msg = "Cannot continue running without pending function calls."
        raise ValueError(msg)

    # Convert to ToolCall format for the existing handler.
    tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
    async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
        yield tool_chunk

    # ``context`` must reach the resumed loop too, otherwise later tool calls
    # would run without it. ``_run`` already filters by ``includes``.
    async for chunk in self._run(max_steps, includes, self._normalize_record_path(record_to), context=context):
        yield chunk
|
|
209
|
+
|
|
210
|
+
async def run_until_complete(
    self,
    user_input: RunnerMessages | str,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    context: "Any | None" = None,  # noqa: ANN401
) -> list[AgentChunk]:
    """Run the agent to completion and return every chunk it produced.

    Args:
        user_input: A string or an iterable of messages to add to the history.
        max_steps: Upper bound on completion rounds.
        includes: Chunk types to collect; ``None`` selects the defaults.
        record_to: Optional path handed to the agent's completion call.
        context: Opaque value forwarded to tool-call handling, mirroring
            ``run``. Added last so existing positional callers are unaffected.

    Returns:
        The list of all chunks yielded by ``run`` (not just the final message).
    """
    resp = self.run(user_input, max_steps, includes, context=context, record_to=record_to)
    return await self._collect_all_chunks(resp)
|
|
220
|
+
|
|
221
|
+
async def _convert_final_message_to_responses_format(self, message: "AssistantMessage") -> None:
    """Append ``message`` to the history in responses format.

    The assistant text is stored as an ``AgentAssistantMessage`` without tool
    calls; each tool call of the message (if any) follows as its own
    ``AgentFunctionToolCallMessage``.
    """
    # The text part is recorded first whether or not tool calls are present.
    self.messages.append(
        AgentAssistantMessage(
            role="assistant",
            content=message.content,
        ),
    )

    requested_calls = getattr(message, "tool_calls", None)
    if not requested_calls:
        return

    for call in requested_calls:
        self.messages.append(
            AgentFunctionToolCallMessage(
                type="function_call",
                function_call_id=call.id,
                name=call.function.name,
                arguments=call.function.arguments or "",
                content="",
            ),
        )
|
|
250
|
+
|
|
251
|
+
def _find_pending_function_calls(self) -> list:
    """Return the function calls of the latest assistant turn that have no output yet.

    The history is scanned from the newest message back to the most recent
    ``AgentAssistantMessage``. Outputs normally sit after their call, so in
    this backward scan an output is met before its call. Calls and answered
    ids are therefore collected independently and matched only after the
    scan; removing ids while scanning would miss every answered call.

    Returns:
        Unanswered ``AgentFunctionToolCallMessage`` objects, oldest first
        (the order in which they appear in the history).
    """
    seen_calls: list[AgentFunctionToolCallMessage] = []
    answered_call_ids = set()

    for msg in reversed(self.messages):
        if isinstance(msg, AgentFunctionToolCallMessage):
            seen_calls.append(msg)
        elif isinstance(msg, AgentFunctionCallOutput):
            answered_call_ids.add(msg.call_id)
        elif isinstance(msg, AgentAssistantMessage):
            # Stop at the assistant message that initiated these calls.
            break

    # ``seen_calls`` is newest-first; flip it back to history order.
    return [fc for fc in reversed(seen_calls) if fc.function_call_id not in answered_call_ids]
|
|
270
|
+
|
|
271
|
+
def _convert_function_calls_to_tool_calls(self, function_calls: list[AgentFunctionToolCallMessage]) -> list[ToolCall]:
    """Build one ``ToolCall`` per function-call message, preserving order.

    Each ``ToolCall`` takes its id, name and arguments from the message and
    its ``index`` from the message's position in ``function_calls``.
    """
    return [
        ToolCall(
            id=call_msg.function_call_id,
            type="function",
            function=ToolCallFunction(
                name=call_msg.name,
                arguments=call_msg.arguments,
            ),
            index=position,
        )
        for position, call_msg in enumerate(function_calls)
    ]
|
|
287
|
+
|
|
288
|
+
def append_message(self, message: RunnerMessage | dict) -> None:
    """Append ``message`` to the history, validating dicts into message objects.

    ``RunnerMessage`` instances are stored as-is. A dict is dispatched on its
    ``type`` (``function_call`` / ``function_call_output``) or, failing that,
    on its ``role``; a legacy assistant dict carrying ``tool_calls`` is split
    into an assistant message plus one function-call message per tool call.
    Values that are neither a ``RunnerMessage`` nor a dict are ignored.

    Raises:
        ValueError: If a dict has no usable ``role``/``type`` or an unsupported role.
    """
    if isinstance(message, RunnerMessage):
        self.messages.append(message)
        return
    if not isinstance(message, dict):
        return

    kind = message.get("type")
    role = message.get("role")

    if kind == "function_call":
        self.messages.append(AgentFunctionToolCallMessage.model_validate(message))
        return

    if kind == "function_call_output":
        self.messages.append(AgentFunctionCallOutput.model_validate(message))
        return

    if role == "assistant" and "tool_calls" in message:
        # Legacy completions-style message: text first, then one entry per call.
        self.messages.append(
            AgentAssistantMessage(
                role="assistant",
                content=message.get("content", ""),
            ),
        )
        for legacy_call in message.get("tool_calls", []):
            self.messages.append(
                AgentFunctionToolCallMessage(
                    type="function_call",
                    function_call_id=legacy_call["id"],
                    name=legacy_call["function"]["name"],
                    arguments=legacy_call["function"]["arguments"],
                    content="",
                ),
            )
        return

    if not role:
        msg = "Message must have a 'role' or 'type' field."
        raise ValueError(msg)

    # Plain role-based message.
    classes_by_role = {
        "user": AgentUserMessage,
        "assistant": AgentAssistantMessage,
        "system": AgentSystemMessage,
    }
    message_class = classes_by_role.get(role)
    if not message_class:
        msg = f"Unsupported message role: {role}"
        raise ValueError(msg)
    self.messages.append(message_class.model_validate(message))
|
|
338
|
+
|
|
339
|
+
async def _handle_agent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
    """Handle a ``transfer_to_agent`` tool call.

    On success the tool's result is recorded as the call's output and
    ``self.agent`` is replaced by the target agent. Every failure path
    records an explanatory output for the call instead of raising, so the
    call never stays unanswered.

    Args:
        tool_call: The transfer_to_agent tool call.
        _includes: The types of chunks to include in output (unused).
    """

    def record_output(output: str) -> None:
        # Answer this tool call in the history.
        self.messages.append(
            AgentFunctionCallOutput(
                type="function_call_output",
                call_id=tool_call.id,
                output=output,
            ),
        )

    # Parse the arguments to get the target agent name.
    try:
        arguments = json.loads(tool_call.function.arguments or "{}")
    except (json.JSONDecodeError, TypeError):
        arguments = None
    # Valid JSON that is not an object (e.g. "[]" or "3") has no ``.get`` and
    # must be rejected like unparsable input rather than crash the run.
    if not isinstance(arguments, dict):
        logger.error("Failed to parse transfer_to_agent arguments: %s", tool_call.function.arguments)
        record_output("Failed to parse transfer arguments")
        return

    target_agent_name = arguments.get("name")
    if not target_agent_name:
        logger.error("No target agent name provided in transfer_to_agent call")
        record_output("No target agent name provided")
        return

    # Find the target agent in handoffs.
    if not self.agent.handoffs:
        logger.error("Current agent has no handoffs configured")
        record_output("Current agent has no handoffs configured")
        return

    target_agent = next((agent for agent in self.agent.handoffs if agent.name == target_agent_name), None)
    if not target_agent:
        logger.error("Target agent '%s' not found in handoffs", target_agent_name)
        record_output(f"Target agent '{target_agent_name}' not found in handoffs")
        return

    # Execute the transfer tool call to get the result.
    try:
        result = await self.agent.fc.call_function_async(
            tool_call.function.name,
            tool_call.function.arguments or "",
        )
        record_output(str(result))

        # Switch to the target agent only after the tool call succeeded.
        logger.info("Transferring conversation from %s to %s", self.agent.name, target_agent_name)
        self.agent = target_agent
    except Exception as e:
        # Deliberately broad: any failure is reported back as the call's
        # output and the current agent stays in charge.
        logger.exception("Failed to execute transfer_to_agent tool call")
        record_output(f"Transfer failed: {e!s}")
|
|
436
|
+
|
|
437
|
+
async def _handle_parent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
    """Handle a ``transfer_to_parent`` tool call.

    With a parent configured, the tool is executed, its result recorded as
    the call's output and ``self.agent`` switched to the parent. Without a
    parent, or when execution fails, an explanatory output is recorded and
    the current agent is kept.

    Args:
        tool_call: The transfer_to_parent tool call.
        _includes: The types of chunks to include in output (unused).
    """
    if self.agent.parent:
        try:
            result = await self.agent.fc.call_function_async(
                tool_call.function.name,
                tool_call.function.arguments or "",
            )
            answer = AgentFunctionCallOutput(
                type="function_call_output",
                call_id=tool_call.id,
                output=str(result),
            )
            self.messages.append(answer)

            # Hand control back to the parent.
            parent_agent = self.agent.parent
            logger.info("Transferring conversation from %s back to parent %s", self.agent.name, parent_agent.name)
            self.agent = parent_agent
        except Exception as e:
            logger.exception("Failed to execute transfer_to_parent tool call")
            failure = AgentFunctionCallOutput(
                type="function_call_output",
                call_id=tool_call.id,
                output=f"Transfer to parent failed: {e!s}",
            )
            self.messages.append(failure)
        return

    # No parent: answer the call with an error instead of transferring.
    logger.error("Current agent has no parent to transfer back to.")
    self.messages.append(
        AgentFunctionCallOutput(
            type="function_call_output",
            call_id=tool_call.id,
            output="Current agent has no parent to transfer back to",
        ),
    )
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from collections.abc import AsyncGenerator
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import aiofiles
|
|
5
|
+
import litellm
|
|
6
|
+
from aiofiles.threadpool.text import AsyncTextIOWrapper
|
|
7
|
+
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
|
|
8
|
+
|
|
9
|
+
from lite_agent.loggers import logger
|
|
10
|
+
from lite_agent.processors import StreamChunkProcessor
|
|
11
|
+
from lite_agent.types import AgentChunk, CompletionRawChunk, ContentDeltaChunk, FinalMessageChunk, ToolCallDeltaChunk, UsageChunk
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def ensure_record_file(record_to: Path | None) -> Path | None:
|
|
15
|
+
if not record_to:
|
|
16
|
+
return None
|
|
17
|
+
if not record_to.parent.exists():
|
|
18
|
+
logger.warning('Record directory "%s" does not exist, creating it.', record_to.parent)
|
|
19
|
+
record_to.parent.mkdir(parents=True, exist_ok=True)
|
|
20
|
+
return record_to
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
async def process_chunk(
    processor: StreamChunkProcessor,
    chunk: ModelResponseStream,
    record_file: AsyncTextIOWrapper | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Translate one raw stream chunk into agent chunks.

    The raw chunk is first written to ``record_file`` (one JSON document per
    line, flushed immediately) when a file is given, then always re-emitted
    as a ``completion_raw`` chunk. A chunk carrying usage information yields
    only a ``usage`` chunk after that. Otherwise content / tool-call deltas
    are yielded for the first choice, followed by a ``final_message`` chunk
    when that choice has a finish reason.
    """
    if record_file:
        await record_file.write(chunk.model_dump_json() + "\n")
        await record_file.flush()

    yield CompletionRawChunk(type="completion_raw", raw=chunk)

    usage = await handle_usage_chunk(processor, chunk)
    if usage:
        yield usage
    elif chunk.choices:
        first_choice = chunk.choices[0]
        deltas = await handle_content_and_tool_calls(processor, chunk, first_choice, first_choice.delta)
        for item in deltas:
            yield item
        if first_choice.finish_reason:
            yield FinalMessageChunk(
                type="final_message",
                message=processor.current_message,
                finish_reason=first_choice.finish_reason,
            )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
async def handle_usage_chunk(processor: StreamChunkProcessor, chunk: ModelResponseStream) -> UsageChunk | None:
    """Return a ``usage`` chunk when the processor finds usage info in ``chunk``, else ``None``."""
    usage = processor.handle_usage_info(chunk)
    return UsageChunk(type="usage", usage=usage) if usage else None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def handle_content_and_tool_calls(
    processor: StreamChunkProcessor,
    chunk: ModelResponseStream,
    choice: StreamingChoices,
    delta: Delta,
) -> list[AgentChunk]:
    """Feed ``delta`` into ``processor`` and return the delta chunks it gives rise to.

    The processor's message is initialized on first use. Text content yields
    a ``content_delta`` chunk; each tool-call fragment that carries
    arguments yields a ``tool_call_delta`` chunk.

    NOTE(review): every tool-call delta is labelled with the id and name of
    the *last* tool call accumulated so far — confirm that this is right
    when several tool calls are streamed in one response.
    """
    if not processor.is_initialized:
        processor.initialize_message(chunk, choice)

    emitted: list[AgentChunk] = []

    if delta.content:
        emitted.append(ContentDeltaChunk(type="content_delta", delta=delta.content))
        processor.update_content(delta.content)

    if delta.tool_calls is not None:
        processor.update_tool_calls(delta.tool_calls)
        if delta.tool_calls and processor.current_message.tool_calls:
            for fragment in delta.tool_calls:
                if not fragment.function.arguments:
                    continue
                emitted.append(
                    ToolCallDeltaChunk(
                        type="tool_call_delta",
                        tool_call_id=processor.current_message.tool_calls[-1].id,
                        name=processor.current_message.tool_calls[-1].function.name,
                        arguments_delta=fragment.function.arguments or "",
                    ),
                )

    return emitted
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def litellm_stream_handler(
    resp: litellm.CustomStreamWrapper,
    record_to: Path | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Convert a litellm response stream into agent chunks.

    Args:
        resp: The streaming response to iterate.
        record_to: Optional file to which every raw chunk is written; its
            parent directory is created if needed and the file is
            truncated on open.

    Yields:
        The agent chunks produced by ``process_chunk`` for each stream item.
        Items that are not ``ModelResponseStream`` are logged and skipped.
    """
    processor = StreamChunkProcessor()
    record_path = ensure_record_file(record_to)
    record_file: AsyncTextIOWrapper | None = (
        await aiofiles.open(record_path, "w", encoding="utf-8") if record_path else None
    )
    try:
        async for chunk in resp:  # type: ignore
            if isinstance(chunk, ModelResponseStream):
                async for result in process_chunk(processor, chunk, record_file):
                    yield result
                continue
            logger.warning("unexpected chunk type: %s", type(chunk))
            logger.warning("chunk content: %s", chunk)
    finally:
        # Close the recording even if the consumer abandons the stream early.
        if record_file:
            await record_file.close()
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Export all types from submodules
|
|
2
|
+
from .chunks import (
|
|
3
|
+
AgentChunk,
|
|
4
|
+
AgentChunkType,
|
|
5
|
+
CompletionRawChunk,
|
|
6
|
+
ContentDeltaChunk,
|
|
7
|
+
FinalMessageChunk,
|
|
8
|
+
ToolCallChunk,
|
|
9
|
+
ToolCallDeltaChunk,
|
|
10
|
+
ToolCallResultChunk,
|
|
11
|
+
UsageChunk,
|
|
12
|
+
)
|
|
13
|
+
from .messages import (
|
|
14
|
+
AgentAssistantMessage,
|
|
15
|
+
AgentFunctionCallOutput,
|
|
16
|
+
AgentFunctionToolCallMessage,
|
|
17
|
+
AgentMessage,
|
|
18
|
+
AgentSystemMessage,
|
|
19
|
+
AgentUserMessage,
|
|
20
|
+
AssistantMessage,
|
|
21
|
+
Message,
|
|
22
|
+
RunnerMessage,
|
|
23
|
+
RunnerMessages,
|
|
24
|
+
UserMessageContentItemImageURL,
|
|
25
|
+
UserMessageContentItemImageURLImageURL,
|
|
26
|
+
UserMessageContentItemText,
|
|
27
|
+
)
|
|
28
|
+
from .tool_calls import ToolCall, ToolCallFunction
|
|
29
|
+
|
|
30
|
+
# Public API of the package: every name imported from the submodules above
# is re-exported here, in alphabetical order.
__all__ = [
    "AgentAssistantMessage",
    "AgentChunk",
    "AgentChunkType",
    "AgentFunctionCallOutput",
    "AgentFunctionToolCallMessage",
    "AgentMessage",
    "AgentSystemMessage",
    "AgentUserMessage",
    "AssistantMessage",
    "CompletionRawChunk",
    "ContentDeltaChunk",
    "FinalMessageChunk",
    "Message",
    "RunnerMessage",
    "RunnerMessages",
    "ToolCall",
    "ToolCallChunk",
    "ToolCallDeltaChunk",
    "ToolCallFunction",
    "ToolCallResultChunk",
    "UsageChunk",
    "UserMessageContentItemImageURL",
    "UserMessageContentItemImageURLImageURL",
    "UserMessageContentItemText",
]
|