lite-agent 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lite-agent might be problematic; see the package's registry page for more details.

lite_agent/runner.py CHANGED
@@ -1,51 +1,487 @@
1
- from collections.abc import AsyncGenerator
2
- from typing import Literal
1
+ import json
2
+ from collections.abc import AsyncGenerator, Sequence
3
+ from os import PathLike
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Any
3
6
 
4
- from open_agents.agent import Agent
5
- from open_agents.chunk_handler import AgentChunk
6
- from open_agents.types import AgentToolCallMessage, RunnerMessages
7
+ from lite_agent.agent import Agent
8
+ from lite_agent.loggers import logger
9
+ from lite_agent.types import (
10
+ AgentAssistantMessage,
11
+ AgentChunk,
12
+ AgentChunkType,
13
+ AgentFunctionCallOutput,
14
+ AgentFunctionToolCallMessage,
15
+ AgentSystemMessage,
16
+ AgentUserMessage,
17
+ RunnerMessage,
18
+ RunnerMessages,
19
+ ToolCall,
20
+ ToolCallFunction,
21
+ )
22
+
23
+ if TYPE_CHECKING:
24
+ from lite_agent.types import AssistantMessage
25
+
26
# Chunk types streamed back to the caller when a Runner method receives
# ``includes=None``.
DEFAULT_INCLUDES: tuple[AgentChunkType, ...] = (
    "completion_raw", "usage", "final_message",  # raw stream, accounting, end of turn
    "tool_call", "tool_call_result",  # tool execution
    "content_delta", "tool_call_delta",  # incremental streaming deltas
)
7
35
 
8
36
 
class Runner:
    """Drive an ``Agent`` through a conversation and stream its output chunks.

    The conversation history lives in ``self.messages``. Assistant turns are stored
    in "responses" format: an ``AgentAssistantMessage`` followed by one
    ``AgentFunctionToolCallMessage`` per requested tool call. Each call is answered
    later by an ``AgentFunctionCallOutput`` carrying the matching call id.
    """

    def __init__(self, agent: Agent) -> None:
        self.agent = agent
        self.messages: list[RunnerMessage] = []

    def _normalize_includes(self, includes: Sequence[AgentChunkType] | None) -> Sequence[AgentChunkType]:
        """Return *includes*, or ``DEFAULT_INCLUDES`` when it is ``None``."""
        return includes if includes is not None else DEFAULT_INCLUDES

    def _normalize_record_path(self, record_to: PathLike | str | None) -> Path | None:
        """Return *record_to* as a ``Path``, or ``None`` when it is empty or ``None``."""
        return Path(record_to) if record_to else None

    def _append_function_call_output(self, call_id: str, output: str) -> None:
        """Record *output* in the history as the result of the call identified by *call_id*."""
        self.messages.append(
            AgentFunctionCallOutput(
                type="function_call_output",
                call_id=call_id,
                output=output,
            ),
        )

    async def _handle_tool_calls(
        self,
        tool_calls: Sequence[ToolCall] | None,
        includes: Sequence[AgentChunkType],
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Execute *tool_calls* and yield the resulting chunks whose type is in *includes*.

        Transfer calls take precedence. If any ``transfer_to_agent`` call is present,
        only the first one is executed and later ones are answered with a
        "Transfer already executed" output. No other tool call in the batch runs.
        ``transfer_to_parent`` calls are handled the same way when no
        ``transfer_to_agent`` call is present.

        Otherwise all calls are delegated to ``agent.handle_tool_calls``. Each
        ``tool_call_result`` chunk is appended to the history as a function call
        output.
        """
        if not tool_calls:
            return

        # Checked in this order: an agent transfer wins over a parent transfer.
        transfer_handlers = (
            ("transfer_to_agent", self._handle_agent_transfer),
            ("transfer_to_parent", self._handle_parent_transfer),
        )
        for transfer_name, execute_transfer in transfer_handlers:
            transfer_calls = [tc for tc in tool_calls if tc.function.name == transfer_name]
            if not transfer_calls:
                continue
            first_call, *duplicate_calls = transfer_calls
            await execute_transfer(first_call, includes)
            # Every call id still needs an answer, even the ones we refuse to execute.
            for duplicate in duplicate_calls:
                self._append_function_call_output(duplicate.id, "Transfer already executed by previous call")
            # NOTE(review): non-transfer calls in the same batch receive no output
            # message here; confirm the next completion request tolerates that.
            return

        async for tool_call_chunk in self.agent.handle_tool_calls(tool_calls, context=context):
            if tool_call_chunk.type in ("tool_call", "tool_call_result") and tool_call_chunk.type in includes:
                yield tool_call_chunk
            if tool_call_chunk.type == "tool_call_result":
                self._append_function_call_output(tool_call_chunk.tool_call_id, tool_call_chunk.content)

    async def _collect_all_chunks(self, stream: AsyncGenerator[AgentChunk, None]) -> list[AgentChunk]:
        """Drain *stream* and return every chunk it yielded, in order."""
        return [chunk async for chunk in stream]

    def run(
        self,
        user_input: RunnerMessages | str,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        context: "Any | None" = None,  # noqa: ANN401
        record_to: PathLike | str | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Append *user_input* to the history and return an async generator of chunks.

        A string is added as a user message. A sequence of messages is added one by
        one through ``append_message``. The input is recorded immediately, before
        the returned generator is iterated.

        Iterating the generator runs up to *max_steps* completion rounds (see
        ``_run``). Only chunk types listed in *includes* are yielded
        (``DEFAULT_INCLUDES`` when ``None``). *record_to*, if given, is forwarded to
        ``agent.completion`` as ``record_to_file``.
        """
        includes = self._normalize_includes(includes)
        if isinstance(user_input, str):
            self.messages.append(AgentUserMessage(role="user", content=user_input))
        else:
            for message in user_input:
                self.append_message(message)
        return self._run(max_steps, includes, self._normalize_record_path(record_to), context=context)

    async def _run(
        self,
        max_steps: int,
        includes: Sequence[AgentChunkType],
        record_to: Path | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Stream chunks from up to *max_steps* completion rounds over the current history.

        Each round sends ``self.messages`` to ``agent.completion`` and appends the
        final message to the history. When the round finishes with
        ``"tool_calls"``, the pending calls are executed before the next round.

        The loop stops after a round whose finish reason is ``"stop"``, or after
        *max_steps* rounds. It returns early when ``agent.list_require_confirm_tools``
        reports any pending call needing confirmation. Those calls are then left
        unanswered so ``run_continue_stream`` can execute them later.
        """
        logger.debug(f"Running agent with messages: {self.messages}")
        steps = 0
        finish_reason = None

        while finish_reason != "stop" and steps < max_steps:
            resp = await self.agent.completion(self.messages, record_to_file=record_to)
            async for chunk in resp:
                if chunk.type in includes:
                    yield chunk

                if chunk.type != "final_message":
                    continue
                await self._convert_final_message_to_responses_format(chunk.message)
                finish_reason = chunk.finish_reason
                if finish_reason != "tool_calls":
                    continue
                pending_function_calls = self._find_pending_function_calls()
                if not pending_function_calls:
                    continue
                tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
                if await self.agent.list_require_confirm_tools(tool_calls):
                    return
                async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                    yield tool_chunk
            steps += 1

    async def run_continue_until_complete(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> list[AgentChunk]:
        """Like ``run_continue_stream``, but collect every yielded chunk into a list."""
        resp = self.run_continue_stream(max_steps, includes, record_to=record_to, context=context)
        return await self._collect_all_chunks(resp)

    def run_continue_stream(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Return an async generator that resumes a run stopped on pending tool calls."""
        return self._run_continue_stream(max_steps, includes, record_to=record_to, context=context)

    async def _run_continue_stream(
        self,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Execute the pending function calls, then keep running the agent.

        The same *context* is used both for the pending calls and for the
        subsequent ``_run`` rounds.

        Raises:
            ValueError: (on first iteration) if there are no pending function calls.
                The message distinguishes an empty history or a history that does
                not end with an assistant message from an assistant turn that
                simply has nothing left to execute.
        """
        includes = self._normalize_includes(includes)

        pending_function_calls = self._find_pending_function_calls()
        if not pending_function_calls:
            last_message = self.messages[-1] if self.messages else None
            ends_with_assistant = isinstance(last_message, AgentAssistantMessage) or getattr(last_message, "role", None) == "assistant"
            if not ends_with_assistant:
                msg = "Cannot continue running without a valid last message from the assistant."
                raise ValueError(msg)
            # An assistant message with no pending calls leaves nothing to continue.
            msg = "Cannot continue running without pending function calls."
            raise ValueError(msg)

        tool_calls = self._convert_function_calls_to_tool_calls(pending_function_calls)
        async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
            yield tool_chunk
        # _run already filters by includes; forward context so later tool calls see it too.
        async for chunk in self._run(max_steps, includes, self._normalize_record_path(record_to), context=context):
            yield chunk

    async def run_until_complete(
        self,
        user_input: RunnerMessages | str,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> list[AgentChunk]:
        """Run the agent on *user_input* and return every yielded chunk as a list."""
        resp = self.run(user_input, max_steps, includes, context=context, record_to=record_to)
        return await self._collect_all_chunks(resp)

    async def _convert_final_message_to_responses_format(self, message: "AssistantMessage") -> None:
        """Append a completions-style final *message* to the history in responses format.

        The assistant text is appended as an ``AgentAssistantMessage``. Each of the
        message's tool calls, if any, follows it as an ``AgentFunctionToolCallMessage``.
        """
        self.messages.append(AgentAssistantMessage(role="assistant", content=message.content))
        for tool_call in getattr(message, "tool_calls", None) or []:
            self.messages.append(
                AgentFunctionToolCallMessage(
                    type="function_call",
                    function_call_id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments or "",
                    content="",
                ),
            )

    def _find_pending_function_calls(self) -> list[AgentFunctionToolCallMessage]:
        """Return the latest assistant turn's function calls that have no output yet.

        The history is scanned backwards until the most recent
        ``AgentAssistantMessage``. Outputs are matched to calls by id whichever order
        they appear in. The pending calls are returned in chronological order.
        """
        calls_newest_first: list[AgentFunctionToolCallMessage] = []
        answered_ids: set[str] = set()
        for msg in reversed(self.messages):
            if isinstance(msg, AgentFunctionToolCallMessage):
                calls_newest_first.append(msg)
            elif isinstance(msg, AgentFunctionCallOutput):
                answered_ids.add(msg.call_id)
            elif isinstance(msg, AgentAssistantMessage):
                # The assistant message that issued these calls precedes them.
                break
        return [fc for fc in reversed(calls_newest_first) if fc.function_call_id not in answered_ids]

    def _convert_function_calls_to_tool_calls(self, function_calls: list[AgentFunctionToolCallMessage]) -> list[ToolCall]:
        """Convert function call messages into ``ToolCall`` objects indexed by position."""
        return [
            ToolCall(
                id=fc.function_call_id,
                type="function",
                function=ToolCallFunction(name=fc.name, arguments=fc.arguments),
                index=position,
            )
            for position, fc in enumerate(function_calls)
        ]

    def append_message(self, message: RunnerMessage | dict) -> None:
        """Add *message* to the history, converting plain dicts into message models.

        Dicts are dispatched on ``type`` first (``function_call`` /
        ``function_call_output``), then on ``role``. Legacy chat-completions dicts
        are converted to responses format:

        - an assistant dict with ``tool_calls`` becomes an assistant message plus
          one function call message per tool call;
        - a ``tool`` dict with ``tool_call_id`` becomes a function call output.

        Values that are neither message models nor dicts are ignored.

        Raises:
            ValueError: if a dict has an unsupported role, or neither ``role`` nor ``type``.
        """
        if isinstance(message, RunnerMessage):
            self.messages.append(message)
            return
        if not isinstance(message, dict):
            return

        message_type = message.get("type")
        role = message.get("role")

        if message_type == "function_call":
            self.messages.append(AgentFunctionToolCallMessage.model_validate(message))
        elif message_type == "function_call_output":
            self.messages.append(AgentFunctionCallOutput.model_validate(message))
        elif role == "assistant" and "tool_calls" in message:
            # Legacy chat-completions payloads often send "content": null alongside tool_calls.
            self.messages.append(AgentAssistantMessage(role="assistant", content=message.get("content") or ""))
            for tool_call in message.get("tool_calls") or []:
                self.messages.append(
                    AgentFunctionToolCallMessage(
                        type="function_call",
                        function_call_id=tool_call["id"],
                        name=tool_call["function"]["name"],
                        arguments=tool_call["function"]["arguments"],
                        content="",
                    ),
                )
        elif role == "tool" and "tool_call_id" in message:
            # Legacy tool result: the counterpart of the legacy assistant tool_calls above.
            self._append_function_call_output(message["tool_call_id"], str(message.get("content") or ""))
        elif role:
            role_to_message_class = {
                "user": AgentUserMessage,
                "assistant": AgentAssistantMessage,
                "system": AgentSystemMessage,
            }
            message_class = role_to_message_class.get(role)
            if message_class is None:
                msg = f"Unsupported message role: {role}"
                raise ValueError(msg)
            self.messages.append(message_class.model_validate(message))
        else:
            msg = "Message must have a 'role' or 'type' field."
            raise ValueError(msg)

    async def _handle_agent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
        """Handle a ``transfer_to_agent`` call by switching ``self.agent`` to a handoff.

        The target is the handoff whose ``name`` equals the ``name`` argument. Every
        outcome, success or failure, is recorded as a function call output for
        *tool_call*; the agent is switched only on success.

        Args:
            tool_call: The transfer_to_agent tool call.
            _includes: The types of chunks to include in output (unused).
        """
        try:
            arguments = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError:
            arguments = None
        # Valid JSON that is not an object (e.g. "[]" or "null") is also unusable.
        if not isinstance(arguments, dict):
            logger.error("Failed to parse transfer_to_agent arguments: %s", tool_call.function.arguments)
            self._append_function_call_output(tool_call.id, "Failed to parse transfer arguments")
            return

        target_agent_name = arguments.get("name")
        if not target_agent_name:
            logger.error("No target agent name provided in transfer_to_agent call")
            self._append_function_call_output(tool_call.id, "No target agent name provided")
            return

        if not self.agent.handoffs:
            logger.error("Current agent has no handoffs configured")
            self._append_function_call_output(tool_call.id, "Current agent has no handoffs configured")
            return

        target_agent = next((agent for agent in self.agent.handoffs if agent.name == target_agent_name), None)
        if target_agent is None:
            logger.error("Target agent '%s' not found in handoffs", target_agent_name)
            self._append_function_call_output(tool_call.id, f"Target agent '{target_agent_name}' not found in handoffs")
            return

        try:
            result = await self.agent.fc.call_function_async(
                tool_call.function.name,
                tool_call.function.arguments or "",
            )
        except Exception as e:
            logger.exception("Failed to execute transfer_to_agent tool call")
            self._append_function_call_output(tool_call.id, f"Transfer failed: {e!s}")
            return

        self._append_function_call_output(tool_call.id, str(result))
        logger.info("Transferring conversation from %s to %s", self.agent.name, target_agent_name)
        self.agent = target_agent

    async def _handle_parent_transfer(self, tool_call: ToolCall, _includes: Sequence[AgentChunkType]) -> None:
        """Handle a ``transfer_to_parent`` call by switching ``self.agent`` to its parent.

        Every outcome is recorded as a function call output for *tool_call*; the
        agent is switched only when the transfer function call succeeds.

        Args:
            tool_call: The transfer_to_parent tool call.
            _includes: The types of chunks to include in output (unused).
        """
        if not self.agent.parent:
            logger.error("Current agent has no parent to transfer back to.")
            self._append_function_call_output(tool_call.id, "Current agent has no parent to transfer back to")
            return

        try:
            result = await self.agent.fc.call_function_async(
                tool_call.function.name,
                tool_call.function.arguments or "",
            )
        except Exception as e:
            logger.exception("Failed to execute transfer_to_parent tool call")
            self._append_function_call_output(tool_call.id, f"Transfer to parent failed: {e!s}")
            return

        self._append_function_call_output(tool_call.id, str(result))
        logger.info("Transferring conversation from %s back to parent %s", self.agent.name, self.agent.parent.name)
        self.agent = self.agent.parent
@@ -0,0 +1,5 @@
1
+ from lite_agent.stream_handlers.litellm import litellm_stream_handler
2
+
3
# Public API of the stream_handlers package.
__all__ = ["litellm_stream_handler"]
@@ -0,0 +1,106 @@
1
+ from collections.abc import AsyncGenerator
2
+ from pathlib import Path
3
+
4
+ import aiofiles
5
+ import litellm
6
+ from aiofiles.threadpool.text import AsyncTextIOWrapper
7
+ from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
8
+
9
+ from lite_agent.loggers import logger
10
+ from lite_agent.processors import StreamChunkProcessor
11
+ from lite_agent.types import AgentChunk, CompletionRawChunk, ContentDeltaChunk, FinalMessageChunk, ToolCallDeltaChunk, UsageChunk
12
+
13
+
14
+ def ensure_record_file(record_to: Path | None) -> Path | None:
15
+ if not record_to:
16
+ return None
17
+ if not record_to.parent.exists():
18
+ logger.warning('Record directory "%s" does not exist, creating it.', record_to.parent)
19
+ record_to.parent.mkdir(parents=True, exist_ok=True)
20
+ return record_to
21
+
22
+
23
async def process_chunk(
    processor: StreamChunkProcessor,
    chunk: ModelResponseStream,
    record_file: AsyncTextIOWrapper | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Translate one raw litellm stream chunk into agent chunks.

    A ``completion_raw`` chunk is always yielded first. Before that, the chunk is
    written to *record_file* as one JSON line (and flushed) when a file is given.
    If *processor* reports usage info for this chunk, a ``usage`` chunk is yielded
    and processing stops. Otherwise the first choice's content/tool-call deltas are
    yielded, followed by a ``final_message`` chunk when that choice has a finish
    reason.
    """
    if record_file:
        await record_file.write(chunk.model_dump_json() + "\n")
        await record_file.flush()
    yield CompletionRawChunk(type="completion_raw", raw=chunk)

    usage_chunk = await handle_usage_chunk(processor, chunk)
    if usage_chunk:
        yield usage_chunk
        # NOTE(review): returning here skips any choices carried in the same chunk
        # as the usage info — confirm no provider sends a finish_reason alongside usage.
        return

    if chunk.choices:
        first_choice = chunk.choices[0]
        for agent_chunk in await handle_content_and_tool_calls(processor, chunk, first_choice, first_choice.delta):
            yield agent_chunk
        if first_choice.finish_reason:
            yield FinalMessageChunk(
                type="final_message",
                message=processor.current_message,
                finish_reason=first_choice.finish_reason,
            )
45
+
46
+
47
async def handle_usage_chunk(processor: StreamChunkProcessor, chunk: ModelResponseStream) -> UsageChunk | None:
    """Wrap the usage info *processor* extracts from *chunk* in a ``UsageChunk``.

    Returns ``None`` when the processor reports no (or empty) usage for this chunk.
    """
    usage = processor.handle_usage_info(chunk)
    return UsageChunk(type="usage", usage=usage) if usage else None
52
+
53
+
54
async def handle_content_and_tool_calls(
    processor: StreamChunkProcessor,
    chunk: ModelResponseStream,
    choice: StreamingChoices,
    delta: Delta,
) -> list[AgentChunk]:
    """Feed one streaming *delta* into *processor* and return the chunks to emit.

    Steps:

    1. Initialize *processor* from *chunk*/*choice* if it has not been initialized yet.
    2. Forward text content to the processor and emit a ``content_delta`` chunk.
    3. Forward tool-call deltas to the processor.
    4. Emit a ``tool_call_delta`` chunk for every tool-call delta that carries
       argument text.

    Each ``tool_call_delta`` chunk is attributed to the accumulated tool call with
    the same ``index``. When no index match is found, it falls back to the most
    recently started tool call.
    """
    results: list[AgentChunk] = []
    if not processor.is_initialized:
        processor.initialize_message(chunk, choice)
    if delta.content:
        results.append(ContentDeltaChunk(type="content_delta", delta=delta.content))
        processor.update_content(delta.content)
    if delta.tool_calls is not None:
        processor.update_tool_calls(delta.tool_calls)

    if not delta.tool_calls:
        return results
    accumulated = processor.current_message.tool_calls
    if not accumulated:
        return results

    # A single delta may carry fragments for several parallel tool calls; match each
    # fragment to its own call by index instead of always blaming the last one.
    # NOTE(review): assumes accumulated tool calls expose the streamed `index` —
    # when they do not, every lookup falls back to the last call (previous behavior).
    by_index = {getattr(tc, "index", None): tc for tc in accumulated}
    fallback = accumulated[-1]
    for tc_delta in delta.tool_calls:
        if not tc_delta.function.arguments:
            continue
        target = by_index.get(getattr(tc_delta, "index", None), fallback)
        results.append(
            ToolCallDeltaChunk(
                type="tool_call_delta",
                tool_call_id=target.id,
                name=target.function.name,
                arguments_delta=tc_delta.function.arguments or "",
            ),
        )
    return results
82
+
83
+
84
async def litellm_stream_handler(
    resp: litellm.CustomStreamWrapper,
    record_to: Path | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Convert a litellm streaming response into a stream of agent chunks.

    Every ``ModelResponseStream`` item is translated by ``process_chunk`` using one
    ``StreamChunkProcessor`` shared across the whole stream. Any other item type is
    logged as a warning and skipped.

    When *record_to* is given, its directory is created if needed and the file is
    opened for writing (truncating it). Each raw chunk is then appended to it as a
    JSON line. The file is closed in a ``finally`` block once iteration ends or is
    aborted.
    """
    processor = StreamChunkProcessor()
    record_path = ensure_record_file(record_to)
    record_file: AsyncTextIOWrapper | None = (
        await aiofiles.open(record_path, "w", encoding="utf-8") if record_path else None
    )
    try:
        async for raw in resp:  # type: ignore
            if isinstance(raw, ModelResponseStream):
                async for agent_chunk in process_chunk(processor, raw, record_file):
                    yield agent_chunk
                continue
            logger.warning("unexpected chunk type: %s", type(raw))
            logger.warning("chunk content: %s", raw)
    finally:
        if record_file:
            await record_file.close()
@@ -0,0 +1,55 @@
1
+ # Export all types from submodules
2
+ from .chunks import (
3
+ AgentChunk,
4
+ AgentChunkType,
5
+ CompletionRawChunk,
6
+ ContentDeltaChunk,
7
+ FinalMessageChunk,
8
+ ToolCallChunk,
9
+ ToolCallDeltaChunk,
10
+ ToolCallResultChunk,
11
+ UsageChunk,
12
+ )
13
+ from .messages import (
14
+ AgentAssistantMessage,
15
+ AgentFunctionCallOutput,
16
+ AgentFunctionToolCallMessage,
17
+ AgentMessage,
18
+ AgentSystemMessage,
19
+ AgentUserMessage,
20
+ AssistantMessage,
21
+ Message,
22
+ RunnerMessage,
23
+ RunnerMessages,
24
+ UserMessageContentItemImageURL,
25
+ UserMessageContentItemImageURLImageURL,
26
+ UserMessageContentItemText,
27
+ )
28
+ from .tool_calls import ToolCall, ToolCallFunction
29
+
30
# Public API of lite_agent.types, kept in alphabetical order.
__all__ = [
    # Agent* message models and chunk aliases
    "AgentAssistantMessage", "AgentChunk", "AgentChunkType",
    "AgentFunctionCallOutput", "AgentFunctionToolCallMessage", "AgentMessage",
    "AgentSystemMessage", "AgentUserMessage",
    # Assistant / stream chunk types
    "AssistantMessage", "CompletionRawChunk", "ContentDeltaChunk", "FinalMessageChunk",
    # Generic messages and runner aliases
    "Message", "RunnerMessage", "RunnerMessages",
    # Tool-call models and chunks
    "ToolCall", "ToolCallChunk", "ToolCallDeltaChunk", "ToolCallFunction", "ToolCallResultChunk",
    "UsageChunk",
    # User message content items
    "UserMessageContentItemImageURL", "UserMessageContentItemImageURLImageURL", "UserMessageContentItemText",
]