lite-agent 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lite-agent might be problematic.
- lite_agent/__init__.py +2 -1
- lite_agent/agent.py +249 -58
- lite_agent/chat_display.py +779 -0
- lite_agent/client.py +69 -0
- lite_agent/message_transfers.py +9 -1
- lite_agent/processors/__init__.py +3 -2
- lite_agent/processors/completion_event_processor.py +306 -0
- lite_agent/processors/response_event_processor.py +205 -0
- lite_agent/runner.py +553 -225
- lite_agent/stream_handlers/__init__.py +3 -2
- lite_agent/stream_handlers/litellm.py +37 -68
- lite_agent/templates/handoffs_source_instructions.xml.j2 +10 -0
- lite_agent/templates/handoffs_target_instructions.xml.j2 +9 -0
- lite_agent/templates/wait_for_user_instructions.xml.j2 +6 -0
- lite_agent/types/__init__.py +97 -23
- lite_agent/types/events.py +119 -0
- lite_agent/types/messages.py +308 -33
- {lite_agent-0.2.0.dist-info → lite_agent-0.4.0.dist-info}/METADATA +2 -2
- lite_agent-0.4.0.dist-info/RECORD +23 -0
- lite_agent/processors/stream_chunk_processor.py +0 -106
- lite_agent/types/chunks.py +0 -89
- lite_agent-0.2.0.dist-info/RECORD +0 -17
- {lite_agent-0.2.0.dist-info → lite_agent-0.4.0.dist-info}/WHEEL +0 -0
lite_agent/client.py
ADDED
@@ -0,0 +1,69 @@
+import abc
+import os
+from collections.abc import AsyncGenerator
+from typing import Any, Literal
+
+import litellm
+from litellm.types.llms.openai import ResponsesAPIStreamingResponse
+from openai.types.chat import ChatCompletionToolParam
+from openai.types.responses import FunctionToolParam
+
+
+class BaseLLMClient(abc.ABC):
+    """Base class for LLM clients."""
+
+    def __init__(self, *, model: str, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None):
+        self.model = model
+        self.api_key = api_key
+        self.api_base = api_base
+        self.api_version = api_version
+
+    @abc.abstractmethod
+    async def completion(self, messages: list[Any], tools: list[ChatCompletionToolParam] | None = None, tool_choice: str = "auto") -> Any:  # noqa: ANN401
+        """Perform a completion request to the LLM."""
+
+    @abc.abstractmethod
+    async def responses(
+        self,
+        messages: list[dict[str, Any]],  # Changed from ResponseInputParam
+        tools: list[FunctionToolParam] | None = None,
+        tool_choice: Literal["none", "auto", "required"] = "auto",
+    ) -> AsyncGenerator[ResponsesAPIStreamingResponse, None]:
+        """Perform a response request to the LLM."""
+
+
+class LiteLLMClient(BaseLLMClient):
+    async def completion(self, messages: list[Any], tools: list[ChatCompletionToolParam] | None = None, tool_choice: str = "auto") -> Any:  # noqa: ANN401
+        """Perform a completion request to the Litellm API."""
+        return await litellm.acompletion(
+            model=self.model,
+            messages=messages,
+            tools=tools,
+            tool_choice=tool_choice,
+            api_version=self.api_version,
+            api_key=self.api_key,
+            api_base=self.api_base,
+            stream=True,
+        )
+
+    async def responses(
+        self,
+        messages: list[dict[str, Any]],  # Changed from ResponseInputParam
+        tools: list[FunctionToolParam] | None = None,
+        tool_choice: Literal["none", "auto", "required"] = "auto",
+    ) -> AsyncGenerator[ResponsesAPIStreamingResponse, None]:
+        """Perform a response request to the Litellm API."""
+
+        os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
+
+        return await litellm.aresponses(
+            model=self.model,
+            input=messages,  # type: ignore[arg-type]
+            tools=tools,
+            tool_choice=tool_choice,
+            api_version=self.api_version,
+            api_key=self.api_key,
+            api_base=self.api_base,
+            stream=True,
+            store=False,
+        )
lite_agent/message_transfers.py
CHANGED
@@ -64,7 +64,15 @@ def _process_message_to_xml(message: dict | object) -> list[str]:
     # Handle Pydantic model format messages
     role = getattr(message, "role", "unknown")
     content = getattr(message, "content", "")
-
+
+    # Handle new message format where content is a list
+    if isinstance(content, list):
+        # Extract text from content items
+        text_parts = [item.text for item in content if (hasattr(item, "type") and item.type == "text") or hasattr(item, "text")]
+        content_text = " ".join(text_parts)
+        if content_text:
+            xml_lines.append(f" <message role='{role}'>{content_text}</message>")
+    elif isinstance(content, str):
         xml_lines.append(f" <message role='{role}'>{content}</message>")
     elif hasattr(message, "type"):
         # Handle function call messages
lite_agent/processors/__init__.py
CHANGED
@@ -1,3 +1,4 @@
-from lite_agent.processors.
+from lite_agent.processors.completion_event_processor import CompletionEventProcessor
+from lite_agent.processors.response_event_processor import ResponseEventProcessor
 
-__all__ = ["
+__all__ = ["CompletionEventProcessor", "ResponseEventProcessor"]
lite_agent/processors/completion_event_processor.py
ADDED
@@ -0,0 +1,306 @@
+from collections.abc import AsyncGenerator
+from datetime import datetime, timezone
+from typing import Literal
+
+import litellm
+from aiofiles.threadpool.text import AsyncTextIOWrapper
+from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponseStream, StreamingChoices
+
+from lite_agent.loggers import logger
+from lite_agent.types import (
+    AgentChunk,
+    AssistantMessage,
+    AssistantMessageEvent,
+    AssistantMessageMeta,
+    AssistantTextContent,
+    CompletionRawEvent,
+    ContentDeltaEvent,
+    EventUsage,
+    FunctionCallDeltaEvent,
+    FunctionCallEvent,
+    MessageUsage,
+    NewAssistantMessage,
+    Timing,
+    TimingEvent,
+    ToolCall,
+    ToolCallFunction,
+    UsageEvent,
+)
+
+
+class CompletionEventProcessor:
+    """Processor for handling completion events"""
+
+    def __init__(self) -> None:
+        self._current_message: AssistantMessage | None = None
+        self.processing_chunk: Literal["content", "tool_calls"] | None = None
+        self.processing_function: str | None = None
+        self.last_processed_chunk: ModelResponseStream | None = None
+        self.yielded_content = False
+        self.yielded_function = set()
+        self._start_time: datetime | None = None
+        self._first_output_time: datetime | None = None
+        self._output_complete_time: datetime | None = None
+        self._usage_time: datetime | None = None
+        self._usage_data: dict[str, int] = {}
+
+    async def process_chunk(
+        self,
+        chunk: ModelResponseStream,
+        record_file: AsyncTextIOWrapper | None = None,
+    ) -> AsyncGenerator[AgentChunk, None]:
+        # Mark start time on first chunk
+        if self._start_time is None:
+            self._start_time = datetime.now(timezone.utc)
+
+        if record_file:
+            await record_file.write(chunk.model_dump_json() + "\n")
+            await record_file.flush()
+        yield CompletionRawEvent(raw=chunk)
+        usage_chunks = self.handle_usage_chunk(chunk)
+        if usage_chunks:
+            for usage_chunk in usage_chunks:
+                yield usage_chunk
+            return
+        if not chunk.choices:
+            return
+
+        choice = chunk.choices[0]
+        delta = choice.delta
+        if delta.tool_calls:
+            if not self.yielded_content:
+                self.yielded_content = True
+                end_time = datetime.now(timezone.utc)
+                latency_ms = None
+                output_time_ms = None
+                # latency_ms: time from when output preparation starts to the LLM's first output character
+                if self._start_time and self._first_output_time:
+                    latency_ms = int((self._first_output_time - self._start_time).total_seconds() * 1000)
+                # output_time_ms: time from the first output character to output completion
+                if self._first_output_time and self._output_complete_time:
+                    output_time_ms = int((self._output_complete_time - self._first_output_time).total_seconds() * 1000)
+
+                usage = MessageUsage(
+                    input_tokens=self._usage_data.get("input_tokens"),
+                    output_tokens=self._usage_data.get("output_tokens"),
+                )
+                meta = AssistantMessageMeta(
+                    sent_at=end_time,
+                    latency_ms=latency_ms,
+                    total_time_ms=output_time_ms,
+                    usage=usage,
+                )
+                # Include accumulated text content in the message
+                content = []
+                if self._current_message and self._current_message.content:
+                    content.append(AssistantTextContent(text=self._current_message.content))
+
+                yield AssistantMessageEvent(
+                    message=NewAssistantMessage(
+                        content=content,
+                        meta=meta,
+                    ),
+                )
+            first_tool_call = delta.tool_calls[0]
+            tool_name = first_tool_call.function.name if first_tool_call.function else ""
+            if tool_name:
+                self.processing_function = tool_name
+        delta = choice.delta
+        if (
+            self._current_message
+            and self._current_message.tool_calls
+            and self.processing_function != self._current_message.tool_calls[-1].function.name
+            and self._current_message.tool_calls[-1].function.name not in self.yielded_function
+        ):
+            tool_call = self._current_message.tool_calls[-1]
+            yield FunctionCallEvent(
+                call_id=tool_call.id,
+                name=tool_call.function.name,
+                arguments=tool_call.function.arguments or "",
+            )
+            self.yielded_function.add(tool_call.function.name)
+        if not self.is_initialized:
+            self.initialize_message(chunk, choice)
+        if delta.content and self._current_message:
+            # Mark first output time if not already set
+            if self._first_output_time is None:
+                self._first_output_time = datetime.now(timezone.utc)
+            self._current_message.content += delta.content
+            yield ContentDeltaEvent(delta=delta.content)
+        if delta.tool_calls is not None:
+            self.update_tool_calls(delta.tool_calls)
+            if delta.tool_calls and self.current_message.tool_calls:
+                tool_call = delta.tool_calls[0]
+                message_tool_call = self.current_message.tool_calls[-1]
+                yield FunctionCallDeltaEvent(
+                    tool_call_id=message_tool_call.id,
+                    name=message_tool_call.function.name,
+                    arguments_delta=tool_call.function.arguments or "",
+                )
+        if choice.finish_reason:
+            # Mark output complete time when finish_reason appears
+            if self._output_complete_time is None:
+                self._output_complete_time = datetime.now(timezone.utc)
+
+            if self.current_message.tool_calls:
+                tool_call = self.current_message.tool_calls[-1]
+                yield FunctionCallEvent(
+                    call_id=tool_call.id,
+                    name=tool_call.function.name,
+                    arguments=tool_call.function.arguments or "",
+                )
+            if not self.yielded_content:
+                self.yielded_content = True
+                end_time = datetime.now(timezone.utc)
+                latency_ms = None
+                output_time_ms = None
+                # latency_ms: time from when output preparation starts to the LLM's first output character
+                if self._start_time and self._first_output_time:
+                    latency_ms = int((self._first_output_time - self._start_time).total_seconds() * 1000)
+                # output_time_ms: time from the first output character to output completion
+                if self._first_output_time and self._output_complete_time:
+                    output_time_ms = int((self._output_complete_time - self._first_output_time).total_seconds() * 1000)
+
+                usage = MessageUsage(
+                    input_tokens=self._usage_data.get("input_tokens"),
+                    output_tokens=self._usage_data.get("output_tokens"),
+                )
+                meta = AssistantMessageMeta(
+                    sent_at=end_time,
+                    latency_ms=latency_ms,
+                    total_time_ms=output_time_ms,
+                    usage=usage,
+                )
+                # Include accumulated text content in the message
+                content = []
+                if self._current_message and self._current_message.content:
+                    content.append(AssistantTextContent(text=self._current_message.content))
+
+                yield AssistantMessageEvent(
+                    message=NewAssistantMessage(
+                        content=content,
+                        meta=meta,
+                    ),
+                )
+        self.last_processed_chunk = chunk
+
+    def handle_usage_chunk(self, chunk: ModelResponseStream) -> list[AgentChunk]:
+        usage = getattr(chunk, "usage", None)
+        if usage:
+            # Mark usage time
+            self._usage_time = datetime.now(timezone.utc)
+            # Store usage data for meta information
+            self._usage_data["input_tokens"] = usage["prompt_tokens"]
+            self._usage_data["output_tokens"] = usage["completion_tokens"]
+
+            results = []
+
+            # First yield usage event
+            results.append(UsageEvent(usage=EventUsage(input_tokens=usage["prompt_tokens"], output_tokens=usage["completion_tokens"])))
+
+            # Then yield timing event if we have timing data
+            if self._start_time and self._first_output_time and self._output_complete_time:
+                latency_ms = int((self._first_output_time - self._start_time).total_seconds() * 1000)
+                output_time_ms = int((self._output_complete_time - self._first_output_time).total_seconds() * 1000)
+
+                results.append(
+                    TimingEvent(
+                        timing=Timing(
+                            latency_ms=latency_ms,
+                            output_time_ms=output_time_ms,
+                        ),
+                    ),
+                )
+
+            return results
+        return []
+
+    def initialize_message(self, chunk: ModelResponseStream, choice: StreamingChoices) -> None:
+        """Initialize the message object"""
+        delta = choice.delta
+        if delta.role != "assistant":
+            logger.warning("Skipping chunk with role: %s", delta.role)
+            return
+        self._current_message = AssistantMessage(
+            id=chunk.id,
+            index=choice.index,
+            role=delta.role,
+            content="",
+        )
+        logger.debug('Initialized new message: "%s"', self._current_message.id)
+
+    def update_content(self, content: str) -> None:
+        """Update message content"""
+        if self._current_message and content:
+            self._current_message.content += content
+
+    def _initialize_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
+        """Initialize tool calls"""
+        if not self._current_message:
+            return
+
+        self._current_message.tool_calls = []
+        for call in tool_calls:
+            logger.debug("Create new tool call: %s", call.id)
+
+    def _update_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
+        """Update existing tool calls"""
+        if not self._current_message:
+            return
+        if not hasattr(self._current_message, "tool_calls"):
+            self._current_message.tool_calls = []
+        if not self._current_message.tool_calls:
+            return
+        if not tool_calls:
+            return
+        for current_call, new_call in zip(self._current_message.tool_calls, tool_calls, strict=False):
+            if new_call.function.arguments and current_call.function.arguments:
+                current_call.function.arguments += new_call.function.arguments
+            if new_call.type and new_call.type == "function":
+                current_call.type = new_call.type
+            elif new_call.type:
+                logger.warning("Unexpected tool call type: %s", new_call.type)
+
+    def update_tool_calls(self, tool_calls: list[ChatCompletionDeltaToolCall]) -> None:
+        """Handle tool call updates"""
+        if not tool_calls:
+            return
+        for call in tool_calls:
+            if call.id:
+                if call.type == "function":
+                    new_tool_call = ToolCall(
+                        id=call.id,
+                        type=call.type,
+                        function=ToolCallFunction(
+                            name=call.function.name or "",
+                            arguments=call.function.arguments,
+                        ),
+                        index=call.index,
+                    )
+                    if self._current_message is not None:
+                        if self._current_message.tool_calls is None:
+                            self._current_message.tool_calls = []
+                        self._current_message.tool_calls.append(new_tool_call)
+                else:
+                    logger.warning("Unexpected tool call type: %s", call.type)
+            elif self._current_message is not None and self._current_message.tool_calls is not None and call.index is not None and 0 <= call.index < len(self._current_message.tool_calls):
+                existing_call = self._current_message.tool_calls[call.index]
+                if call.function.arguments:
+                    if existing_call.function.arguments is None:
+                        existing_call.function.arguments = ""
+                    existing_call.function.arguments += call.function.arguments
+            else:
+                logger.warning("Cannot update tool call: current_message or tool_calls is None, or invalid index.")
+
+    @property
+    def is_initialized(self) -> bool:
+        """Check if the current message is initialized"""
+        return self._current_message is not None
+
+    @property
+    def current_message(self) -> AssistantMessage:
+        """Get the current message being processed"""
+        if not self._current_message:
+            msg = "No current message initialized. Call initialize_message first."
+            raise ValueError(msg)
+        return self._current_message
lite_agent/processors/response_event_processor.py
ADDED
@@ -0,0 +1,205 @@
+from collections.abc import AsyncGenerator
+from datetime import datetime, timezone
+from typing import Any
+
+from aiofiles.threadpool.text import AsyncTextIOWrapper
+from litellm.types.llms.openai import (
+    ContentPartAddedEvent,
+    FunctionCallArgumentsDeltaEvent,
+    FunctionCallArgumentsDoneEvent,
+    OutputItemAddedEvent,
+    OutputItemDoneEvent,
+    OutputTextDeltaEvent,
+    ResponseCompletedEvent,
+    ResponsesAPIStreamEvents,
+    ResponsesAPIStreamingResponse,
+)
+
+from lite_agent.types import (
+    AgentChunk,
+    AssistantMessageEvent,
+    AssistantMessageMeta,
+    ContentDeltaEvent,
+    EventUsage,
+    FunctionCallEvent,
+    NewAssistantMessage,
+    ResponseRawEvent,
+    Timing,
+    TimingEvent,
+    UsageEvent,
+)
+
+
+class ResponseEventProcessor:
+    """Processor for handling response events"""
+
+    def __init__(self) -> None:
+        self._messages: list[dict[str, Any]] = []
+        self._start_time: datetime | None = None
+        self._first_output_time: datetime | None = None
+        self._output_complete_time: datetime | None = None
+        self._usage_time: datetime | None = None
+        self._usage_data: dict[str, Any] = {}
+
+    async def process_chunk(
+        self,
+        chunk: ResponsesAPIStreamingResponse,
+        record_file: AsyncTextIOWrapper | None = None,
+    ) -> AsyncGenerator[AgentChunk, None]:
+        # Mark start time on first chunk
+        if self._start_time is None:
+            self._start_time = datetime.now(timezone.utc)
+
+        if record_file:
+            await record_file.write(chunk.model_dump_json() + "\n")
+            await record_file.flush()
+
+        yield ResponseRawEvent(raw=chunk)
+
+        events = self.handle_event(chunk)
+        for event in events:
+            yield event
+
+    def handle_event(self, event: ResponsesAPIStreamingResponse) -> list[AgentChunk]:  # noqa: PLR0911
+        """Handle individual response events"""
+        if event.type in (
+            ResponsesAPIStreamEvents.RESPONSE_CREATED,
+            ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS,
+            ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE,
+            ResponsesAPIStreamEvents.CONTENT_PART_DONE,
+        ):
+            return []
+
+        if isinstance(event, OutputItemAddedEvent):
+            self._messages.append(event.item)  # type: ignore
+            return []
+
+        if isinstance(event, ContentPartAddedEvent):
+            latest_message = self._messages[-1] if self._messages else None
+            if latest_message and isinstance(latest_message.get("content"), list):
+                latest_message["content"].append(event.part)
+            return []
+
+        if isinstance(event, OutputTextDeltaEvent):
+            # Mark first output time if not already set
+            if self._first_output_time is None:
+                self._first_output_time = datetime.now(timezone.utc)
+
+            latest_message = self._messages[-1] if self._messages else None
+            if latest_message and isinstance(latest_message.get("content"), list):
+                latest_content = latest_message["content"][-1]
+                if "text" in latest_content:
+                    latest_content["text"] += event.delta
+                return [ContentDeltaEvent(delta=event.delta)]
+            return []
+
+        if isinstance(event, OutputItemDoneEvent):
+            item = event.item
+            if item.get("type") == "function_call":
+                return [
+                    FunctionCallEvent(
+                        call_id=item["call_id"],
+                        name=item["name"],
+                        arguments=item["arguments"],
+                    ),
+                ]
+            if item.get("type") == "message":
+                # Mark output complete time when message is done
+                if self._output_complete_time is None:
+                    self._output_complete_time = datetime.now(timezone.utc)
+
+                content = item.get("content", [])
+                if content and isinstance(content, list) and len(content) > 0:
+                    end_time = datetime.now(timezone.utc)
+                    latency_ms = None
+                    output_time_ms = None
+                    # latency_ms: time from when output preparation starts to the LLM's first output character
+                    if self._start_time and self._first_output_time:
+                        latency_ms = int((self._first_output_time - self._start_time).total_seconds() * 1000)
+                    # output_time_ms: time from the first output character to output completion
+                    if self._first_output_time and self._output_complete_time:
+                        output_time_ms = int((self._output_complete_time - self._first_output_time).total_seconds() * 1000)
+
+                    meta = AssistantMessageMeta(
+                        sent_at=end_time,
+                        latency_ms=latency_ms,
+                        output_time_ms=output_time_ms,
+                        input_tokens=self._usage_data.get("input_tokens"),
+                        output_tokens=self._usage_data.get("output_tokens"),
+                    )
+                    return [
+                        AssistantMessageEvent(
+                            message=NewAssistantMessage(content=[], meta=meta),
+                        ),
+                    ]
+
+        elif isinstance(event, FunctionCallArgumentsDeltaEvent):
+            if self._messages:
+                latest_message = self._messages[-1]
+                if latest_message.get("type") == "function_call":
+                    if "arguments" not in latest_message:
+                        latest_message["arguments"] = ""
+                    latest_message["arguments"] += event.delta
+            return []
+
+        elif isinstance(event, FunctionCallArgumentsDoneEvent):
+            if self._messages:
+                latest_message = self._messages[-1]
+                if latest_message.get("type") == "function_call":
+                    latest_message["arguments"] = event.arguments
+            return []
+
+        elif isinstance(event, ResponseCompletedEvent):
+            usage = event.response.usage
+            if usage:
+                # Mark usage time
+                self._usage_time = datetime.now(timezone.utc)
+                # Store usage data for meta information
+                self._usage_data["input_tokens"] = usage.input_tokens
+                self._usage_data["output_tokens"] = usage.output_tokens
+                # Also store usage time for later calculation
+                self._usage_data["usage_time"] = self._usage_time
+
+                results = []
+
+                # First yield usage event
+                results.append(
+                    UsageEvent(
+                        usage=EventUsage(
+                            input_tokens=usage.input_tokens,
+                            output_tokens=usage.output_tokens,
+                        ),
+                    ),
+                )
+
+                # Then yield timing event if we have timing data
+                if self._start_time and self._first_output_time and self._output_complete_time:
+                    latency_ms = int((self._first_output_time - self._start_time).total_seconds() * 1000)
+                    output_time_ms = int((self._output_complete_time - self._first_output_time).total_seconds() * 1000)
+
+                    results.append(
+                        TimingEvent(
+                            timing=Timing(
+                                latency_ms=latency_ms,
+                                output_time_ms=output_time_ms,
+                            ),
+                        ),
+                    )
+
+                return results
+
+            return []
+
+        return []
+
+    @property
+    def messages(self) -> list[dict[str, Any]]:
+        """Get the accumulated messages"""
+        return self._messages
+
+    def reset(self) -> None:
+        """Reset the processor state"""
+        self._messages = []
+        self._start_time = None
+        self._first_output_time = None
+        self._output_complete_time = None
+        self._usage_time = None
+        self._usage_data = {}