pydantic-ai-slim 1.7.0__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +2 -0
- pydantic_ai/_agent_graph.py +3 -0
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/abstract.py +17 -6
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/messages.py +39 -7
- pydantic_ai/models/__init__.py +42 -1
- pydantic_ai/models/groq.py +9 -1
- pydantic_ai/models/openai.py +2 -3
- pydantic_ai/result.py +19 -7
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +227 -0
- pydantic_ai/ui/ag_ui/app.py +141 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.9.0.dist-info}/METADATA +5 -3
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.9.0.dist-info}/RECORD +33 -19
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.9.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.9.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,591 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
from pydantic_ai import _utils
|
|
11
|
+
|
|
12
|
+
from ..messages import (
|
|
13
|
+
AgentStreamEvent,
|
|
14
|
+
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
|
|
15
|
+
BuiltinToolCallPart,
|
|
16
|
+
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
|
|
17
|
+
BuiltinToolReturnPart,
|
|
18
|
+
FilePart,
|
|
19
|
+
FinalResultEvent,
|
|
20
|
+
FunctionToolCallEvent,
|
|
21
|
+
FunctionToolResultEvent,
|
|
22
|
+
PartDeltaEvent,
|
|
23
|
+
PartEndEvent,
|
|
24
|
+
PartStartEvent,
|
|
25
|
+
TextPart,
|
|
26
|
+
TextPartDelta,
|
|
27
|
+
ThinkingPart,
|
|
28
|
+
ThinkingPartDelta,
|
|
29
|
+
ToolCallPart,
|
|
30
|
+
ToolCallPartDelta,
|
|
31
|
+
ToolReturnPart,
|
|
32
|
+
)
|
|
33
|
+
from ..output import OutputDataT
|
|
34
|
+
from ..run import AgentRunResult, AgentRunResultEvent
|
|
35
|
+
from ..tools import AgentDepsT
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
from starlette.responses import StreamingResponse
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
SSE_CONTENT_TYPE = 'text/event-stream'
"""Content type header value for Server-Sent Events (SSE)."""

EventT = TypeVar('EventT')
"""Type variable for protocol-specific event types."""

RunInputT = TypeVar('RunInputT')
"""Type variable for protocol-specific run input types."""

NativeEvent: TypeAlias = AgentStreamEvent | AgentRunResultEvent[Any]
"""Type alias for the native event type, which is either an `AgentStreamEvent` or an `AgentRunResultEvent`."""

# Union of the three accepted callback shapes, in order: a plain function returning `None`, a function
# returning an awaitable, and a function returning an async iterator of `EventT` events.
OnCompleteFunc: TypeAlias = (
    Callable[[AgentRunResult[Any]], None]
    | Callable[[AgentRunResult[Any]], Awaitable[None]]
    | Callable[[AgentRunResult[Any]], AsyncIterator[EventT]]
)
"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync, async, or an async generator of protocol-specific events."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
|
|
62
|
+
class UIEventStream(ABC, Generic[RunInputT, EventT, AgentDepsT, OutputDataT]):
|
|
63
|
+
"""Base class for UI event stream transformers.
|
|
64
|
+
|
|
65
|
+
This class is responsible for transforming Pydantic AI events into protocol-specific events.
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
run_input: RunInputT
|
|
69
|
+
|
|
70
|
+
accept: str | None = None
|
|
71
|
+
"""The `Accept` header value of the request, used to determine how to encode the protocol-specific events for the streaming response."""
|
|
72
|
+
|
|
73
|
+
message_id: str = field(default_factory=lambda: str(uuid4()))
|
|
74
|
+
"""The message ID to use for the next event."""
|
|
75
|
+
|
|
76
|
+
_turn: Literal['request', 'response'] | None = None
|
|
77
|
+
|
|
78
|
+
_result: AgentRunResult[OutputDataT] | None = None
|
|
79
|
+
_final_result_event: FinalResultEvent | None = None
|
|
80
|
+
|
|
81
|
+
def new_message_id(self) -> str:
|
|
82
|
+
"""Generate and store a new message ID."""
|
|
83
|
+
self.message_id = str(uuid4())
|
|
84
|
+
return self.message_id
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def response_headers(self) -> Mapping[str, str] | None:
|
|
88
|
+
"""Response headers to return to the frontend."""
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def content_type(self) -> str:
|
|
93
|
+
"""Get the content type for the event stream, compatible with the `Accept` header value.
|
|
94
|
+
|
|
95
|
+
By default, this returns the Server-Sent Events content type (`text/event-stream`).
|
|
96
|
+
If a subclass supports other types as well, it should consider `self.accept` in [`encode_event()`][pydantic_ai.ui.UIEventStream.encode_event] and return the resulting content type.
|
|
97
|
+
"""
|
|
98
|
+
return SSE_CONTENT_TYPE
|
|
99
|
+
|
|
100
|
+
@abstractmethod
|
|
101
|
+
def encode_event(self, event: EventT) -> str:
|
|
102
|
+
"""Encode a protocol-specific event as a string."""
|
|
103
|
+
raise NotImplementedError
|
|
104
|
+
|
|
105
|
+
async def encode_stream(self, stream: AsyncIterator[EventT]) -> AsyncIterator[str]:
|
|
106
|
+
"""Encode a stream of protocol-specific events as strings according to the `Accept` header value."""
|
|
107
|
+
async for event in stream:
|
|
108
|
+
yield self.encode_event(event)
|
|
109
|
+
|
|
110
|
+
def streaming_response(self, stream: AsyncIterator[EventT]) -> StreamingResponse:
|
|
111
|
+
"""Generate a streaming response from a stream of protocol-specific events."""
|
|
112
|
+
try:
|
|
113
|
+
from starlette.responses import StreamingResponse
|
|
114
|
+
except ImportError as e: # pragma: no cover
|
|
115
|
+
raise ImportError(
|
|
116
|
+
'Please install the `starlette` package to use the `streaming_response()` method, '
|
|
117
|
+
'you can use the `ui` optional group — `pip install "pydantic-ai-slim[ui]"`'
|
|
118
|
+
) from e
|
|
119
|
+
|
|
120
|
+
return StreamingResponse(
|
|
121
|
+
self.encode_stream(stream),
|
|
122
|
+
headers=self.response_headers,
|
|
123
|
+
media_type=self.content_type,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
async def transform_stream( # noqa: C901
|
|
127
|
+
self, stream: AsyncIterator[NativeEvent], on_complete: OnCompleteFunc[EventT] | None = None
|
|
128
|
+
) -> AsyncIterator[EventT]:
|
|
129
|
+
"""Transform a stream of Pydantic AI events into protocol-specific events.
|
|
130
|
+
|
|
131
|
+
This method dispatches to specific hooks and `handle_*` methods that subclasses can override:
|
|
132
|
+
- [`before_stream()`][pydantic_ai.ui.UIEventStream.before_stream]
|
|
133
|
+
- [`after_stream()`][pydantic_ai.ui.UIEventStream.after_stream]
|
|
134
|
+
- [`on_error()`][pydantic_ai.ui.UIEventStream.on_error]
|
|
135
|
+
- [`before_request()`][pydantic_ai.ui.UIEventStream.before_request]
|
|
136
|
+
- [`after_request()`][pydantic_ai.ui.UIEventStream.after_request]
|
|
137
|
+
- [`before_response()`][pydantic_ai.ui.UIEventStream.before_response]
|
|
138
|
+
- [`after_response()`][pydantic_ai.ui.UIEventStream.after_response]
|
|
139
|
+
- [`handle_event()`][pydantic_ai.ui.UIEventStream.handle_event]
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
stream: The stream of Pydantic AI events to transform.
|
|
143
|
+
on_complete: Optional callback function called when the agent run completes successfully.
|
|
144
|
+
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional protocol-specific events.
|
|
145
|
+
"""
|
|
146
|
+
async for e in self.before_stream():
|
|
147
|
+
yield e
|
|
148
|
+
|
|
149
|
+
try:
|
|
150
|
+
async for event in stream:
|
|
151
|
+
if isinstance(event, PartStartEvent):
|
|
152
|
+
async for e in self._turn_to('response'):
|
|
153
|
+
yield e
|
|
154
|
+
elif isinstance(event, FunctionToolCallEvent):
|
|
155
|
+
async for e in self._turn_to('request'):
|
|
156
|
+
yield e
|
|
157
|
+
elif isinstance(event, AgentRunResultEvent):
|
|
158
|
+
if (
|
|
159
|
+
self._final_result_event
|
|
160
|
+
and (tool_call_id := self._final_result_event.tool_call_id)
|
|
161
|
+
and (tool_name := self._final_result_event.tool_name)
|
|
162
|
+
):
|
|
163
|
+
async for e in self._turn_to('request'):
|
|
164
|
+
yield e
|
|
165
|
+
|
|
166
|
+
self._final_result_event = None
|
|
167
|
+
# Ensure the stream does not end on a dangling tool call without a result.
|
|
168
|
+
output_tool_result_event = FunctionToolResultEvent(
|
|
169
|
+
result=ToolReturnPart(
|
|
170
|
+
tool_call_id=tool_call_id,
|
|
171
|
+
tool_name=tool_name,
|
|
172
|
+
content='Final result processed.',
|
|
173
|
+
)
|
|
174
|
+
)
|
|
175
|
+
async for e in self.handle_function_tool_result(output_tool_result_event):
|
|
176
|
+
yield e
|
|
177
|
+
|
|
178
|
+
result = cast(AgentRunResult[OutputDataT], event.result)
|
|
179
|
+
self._result = result
|
|
180
|
+
|
|
181
|
+
async for e in self._turn_to(None):
|
|
182
|
+
yield e
|
|
183
|
+
|
|
184
|
+
if on_complete is not None:
|
|
185
|
+
if inspect.isasyncgenfunction(on_complete):
|
|
186
|
+
async for e in on_complete(result):
|
|
187
|
+
yield e
|
|
188
|
+
elif _utils.is_async_callable(on_complete):
|
|
189
|
+
await on_complete(result)
|
|
190
|
+
else:
|
|
191
|
+
await _utils.run_in_executor(on_complete, result)
|
|
192
|
+
elif isinstance(event, FinalResultEvent):
|
|
193
|
+
self._final_result_event = event
|
|
194
|
+
|
|
195
|
+
if isinstance(event, BuiltinToolCallEvent | BuiltinToolResultEvent): # pyright: ignore[reportDeprecated]
|
|
196
|
+
# These events were deprecated before this feature was introduced
|
|
197
|
+
continue
|
|
198
|
+
|
|
199
|
+
async for e in self.handle_event(event):
|
|
200
|
+
yield e
|
|
201
|
+
except Exception as e:
|
|
202
|
+
async for e in self.on_error(e):
|
|
203
|
+
yield e
|
|
204
|
+
finally:
|
|
205
|
+
async for e in self._turn_to(None):
|
|
206
|
+
yield e
|
|
207
|
+
|
|
208
|
+
async for e in self.after_stream():
|
|
209
|
+
yield e
|
|
210
|
+
|
|
211
|
+
async def _turn_to(self, to_turn: Literal['request', 'response'] | None) -> AsyncIterator[EventT]:
|
|
212
|
+
"""Fire hooks when turning from request to response or vice versa."""
|
|
213
|
+
if to_turn == self._turn:
|
|
214
|
+
return
|
|
215
|
+
|
|
216
|
+
if self._turn == 'request':
|
|
217
|
+
async for e in self.after_request():
|
|
218
|
+
yield e
|
|
219
|
+
elif self._turn == 'response':
|
|
220
|
+
async for e in self.after_response():
|
|
221
|
+
yield e
|
|
222
|
+
|
|
223
|
+
self._turn = to_turn
|
|
224
|
+
|
|
225
|
+
if to_turn == 'request':
|
|
226
|
+
async for e in self.before_request():
|
|
227
|
+
yield e
|
|
228
|
+
elif to_turn == 'response':
|
|
229
|
+
async for e in self.before_response():
|
|
230
|
+
yield e
|
|
231
|
+
|
|
232
|
+
async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]:
|
|
233
|
+
"""Transform a Pydantic AI event into one or more protocol-specific events.
|
|
234
|
+
|
|
235
|
+
This method dispatches to specific `handle_*` methods based on event type:
|
|
236
|
+
|
|
237
|
+
- [`PartStartEvent`][pydantic_ai.messages.PartStartEvent] -> [`handle_part_start()`][pydantic_ai.ui.UIEventStream.handle_part_start]
|
|
238
|
+
- [`PartDeltaEvent`][pydantic_ai.messages.PartDeltaEvent] -> `handle_part_delta`
|
|
239
|
+
- [`PartEndEvent`][pydantic_ai.messages.PartEndEvent] -> `handle_part_end`
|
|
240
|
+
- [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] -> `handle_final_result`
|
|
241
|
+
- [`FunctionToolCallEvent`][pydantic_ai.messages.FunctionToolCallEvent] -> `handle_function_tool_call`
|
|
242
|
+
- [`FunctionToolResultEvent`][pydantic_ai.messages.FunctionToolResultEvent] -> `handle_function_tool_result`
|
|
243
|
+
- [`AgentRunResultEvent`][pydantic_ai.run.AgentRunResultEvent] -> `handle_run_result`
|
|
244
|
+
|
|
245
|
+
Subclasses are encouraged to override the individual `handle_*` methods rather than this one.
|
|
246
|
+
If you need specific behavior for all events, make sure you call the super method.
|
|
247
|
+
"""
|
|
248
|
+
match event:
|
|
249
|
+
case PartStartEvent():
|
|
250
|
+
async for e in self.handle_part_start(event):
|
|
251
|
+
yield e
|
|
252
|
+
case PartDeltaEvent():
|
|
253
|
+
async for e in self.handle_part_delta(event):
|
|
254
|
+
yield e
|
|
255
|
+
case PartEndEvent():
|
|
256
|
+
async for e in self.handle_part_end(event):
|
|
257
|
+
yield e
|
|
258
|
+
case FinalResultEvent():
|
|
259
|
+
async for e in self.handle_final_result(event):
|
|
260
|
+
yield e
|
|
261
|
+
case FunctionToolCallEvent():
|
|
262
|
+
async for e in self.handle_function_tool_call(event):
|
|
263
|
+
yield e
|
|
264
|
+
case FunctionToolResultEvent():
|
|
265
|
+
async for e in self.handle_function_tool_result(event):
|
|
266
|
+
yield e
|
|
267
|
+
case AgentRunResultEvent():
|
|
268
|
+
async for e in self.handle_run_result(event):
|
|
269
|
+
yield e
|
|
270
|
+
case _:
|
|
271
|
+
pass
|
|
272
|
+
|
|
273
|
+
async def handle_part_start(self, event: PartStartEvent) -> AsyncIterator[EventT]:
|
|
274
|
+
"""Handle a `PartStartEvent`.
|
|
275
|
+
|
|
276
|
+
This method dispatches to specific `handle_*` methods based on part type:
|
|
277
|
+
|
|
278
|
+
- [`TextPart`][pydantic_ai.messages.TextPart] -> [`handle_text_start()`][pydantic_ai.ui.UIEventStream.handle_text_start]
|
|
279
|
+
- [`ThinkingPart`][pydantic_ai.messages.ThinkingPart] -> [`handle_thinking_start()`][pydantic_ai.ui.UIEventStream.handle_thinking_start]
|
|
280
|
+
- [`ToolCallPart`][pydantic_ai.messages.ToolCallPart] -> [`handle_tool_call_start()`][pydantic_ai.ui.UIEventStream.handle_tool_call_start]
|
|
281
|
+
- [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] -> [`handle_builtin_tool_call_start()`][pydantic_ai.ui.UIEventStream.handle_builtin_tool_call_start]
|
|
282
|
+
- [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] -> [`handle_builtin_tool_return()`][pydantic_ai.ui.UIEventStream.handle_builtin_tool_return]
|
|
283
|
+
- [`FilePart`][pydantic_ai.messages.FilePart] -> [`handle_file()`][pydantic_ai.ui.UIEventStream.handle_file]
|
|
284
|
+
|
|
285
|
+
Subclasses are encouraged to override the individual `handle_*` methods rather than this one.
|
|
286
|
+
If you need specific behavior for all part start events, make sure you call the super method.
|
|
287
|
+
|
|
288
|
+
Args:
|
|
289
|
+
event: The part start event.
|
|
290
|
+
"""
|
|
291
|
+
part = event.part
|
|
292
|
+
previous_part_kind = event.previous_part_kind
|
|
293
|
+
match part:
|
|
294
|
+
case TextPart():
|
|
295
|
+
async for e in self.handle_text_start(part, follows_text=previous_part_kind == 'text'):
|
|
296
|
+
yield e
|
|
297
|
+
case ThinkingPart():
|
|
298
|
+
async for e in self.handle_thinking_start(part, follows_thinking=previous_part_kind == 'thinking'):
|
|
299
|
+
yield e
|
|
300
|
+
case ToolCallPart():
|
|
301
|
+
async for e in self.handle_tool_call_start(part):
|
|
302
|
+
yield e
|
|
303
|
+
case BuiltinToolCallPart():
|
|
304
|
+
async for e in self.handle_builtin_tool_call_start(part):
|
|
305
|
+
yield e
|
|
306
|
+
case BuiltinToolReturnPart():
|
|
307
|
+
async for e in self.handle_builtin_tool_return(part):
|
|
308
|
+
yield e
|
|
309
|
+
case FilePart(): # pragma: no branch
|
|
310
|
+
async for e in self.handle_file(part):
|
|
311
|
+
yield e
|
|
312
|
+
|
|
313
|
+
async def handle_part_delta(self, event: PartDeltaEvent) -> AsyncIterator[EventT]:
|
|
314
|
+
"""Handle a PartDeltaEvent.
|
|
315
|
+
|
|
316
|
+
This method dispatches to specific `handle_*_delta` methods based on part delta type:
|
|
317
|
+
|
|
318
|
+
- [`TextPartDelta`][pydantic_ai.messages.TextPartDelta] -> [`handle_text_delta()`][pydantic_ai.ui.UIEventStream.handle_text_delta]
|
|
319
|
+
- [`ThinkingPartDelta`][pydantic_ai.messages.ThinkingPartDelta] -> [`handle_thinking_delta()`][pydantic_ai.ui.UIEventStream.handle_thinking_delta]
|
|
320
|
+
- [`ToolCallPartDelta`][pydantic_ai.messages.ToolCallPartDelta] -> [`handle_tool_call_delta()`][pydantic_ai.ui.UIEventStream.handle_tool_call_delta]
|
|
321
|
+
|
|
322
|
+
Subclasses are encouraged to override the individual `handle_*_delta` methods rather than this one.
|
|
323
|
+
If you need specific behavior for all part delta events, make sure you call the super method.
|
|
324
|
+
|
|
325
|
+
Args:
|
|
326
|
+
event: The PartDeltaEvent.
|
|
327
|
+
"""
|
|
328
|
+
delta = event.delta
|
|
329
|
+
match delta:
|
|
330
|
+
case TextPartDelta():
|
|
331
|
+
async for e in self.handle_text_delta(delta):
|
|
332
|
+
yield e
|
|
333
|
+
case ThinkingPartDelta():
|
|
334
|
+
async for e in self.handle_thinking_delta(delta):
|
|
335
|
+
yield e
|
|
336
|
+
case ToolCallPartDelta(): # pragma: no branch
|
|
337
|
+
async for e in self.handle_tool_call_delta(delta):
|
|
338
|
+
yield e
|
|
339
|
+
|
|
340
|
+
async def handle_part_end(self, event: PartEndEvent) -> AsyncIterator[EventT]:
|
|
341
|
+
"""Handle a `PartEndEvent`.
|
|
342
|
+
|
|
343
|
+
This method dispatches to specific `handle_*_end` methods based on part type:
|
|
344
|
+
|
|
345
|
+
- [`TextPart`][pydantic_ai.messages.TextPart] -> [`handle_text_end()`][pydantic_ai.ui.UIEventStream.handle_text_end]
|
|
346
|
+
- [`ThinkingPart`][pydantic_ai.messages.ThinkingPart] -> [`handle_thinking_end()`][pydantic_ai.ui.UIEventStream.handle_thinking_end]
|
|
347
|
+
- [`ToolCallPart`][pydantic_ai.messages.ToolCallPart] -> [`handle_tool_call_end()`][pydantic_ai.ui.UIEventStream.handle_tool_call_end]
|
|
348
|
+
- [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] -> [`handle_builtin_tool_call_end()`][pydantic_ai.ui.UIEventStream.handle_builtin_tool_call_end]
|
|
349
|
+
|
|
350
|
+
Subclasses are encouraged to override the individual `handle_*_end` methods rather than this one.
|
|
351
|
+
If you need specific behavior for all part end events, make sure you call the super method.
|
|
352
|
+
|
|
353
|
+
Args:
|
|
354
|
+
event: The part end event.
|
|
355
|
+
"""
|
|
356
|
+
part = event.part
|
|
357
|
+
next_part_kind = event.next_part_kind
|
|
358
|
+
match part:
|
|
359
|
+
case TextPart():
|
|
360
|
+
async for e in self.handle_text_end(part, followed_by_text=next_part_kind == 'text'):
|
|
361
|
+
yield e
|
|
362
|
+
case ThinkingPart():
|
|
363
|
+
async for e in self.handle_thinking_end(part, followed_by_thinking=next_part_kind == 'thinking'):
|
|
364
|
+
yield e
|
|
365
|
+
case ToolCallPart():
|
|
366
|
+
async for e in self.handle_tool_call_end(part):
|
|
367
|
+
yield e
|
|
368
|
+
case BuiltinToolCallPart():
|
|
369
|
+
async for e in self.handle_builtin_tool_call_end(part):
|
|
370
|
+
yield e
|
|
371
|
+
case BuiltinToolReturnPart() | FilePart(): # pragma: no cover
|
|
372
|
+
# These don't have deltas, so they don't need to be ended.
|
|
373
|
+
pass
|
|
374
|
+
|
|
375
|
+
async def before_stream(self) -> AsyncIterator[EventT]:
|
|
376
|
+
"""Yield events before agent streaming starts.
|
|
377
|
+
|
|
378
|
+
This hook is called before any agent events are processed.
|
|
379
|
+
Override this to inject custom events at the start of the stream.
|
|
380
|
+
"""
|
|
381
|
+
return # pragma: no cover
|
|
382
|
+
yield # Make this an async generator
|
|
383
|
+
|
|
384
|
+
async def after_stream(self) -> AsyncIterator[EventT]:
|
|
385
|
+
"""Yield events after agent streaming completes.
|
|
386
|
+
|
|
387
|
+
This hook is called after all agent events have been processed.
|
|
388
|
+
Override this to inject custom events at the end of the stream.
|
|
389
|
+
"""
|
|
390
|
+
return # pragma: no cover
|
|
391
|
+
yield # Make this an async generator
|
|
392
|
+
|
|
393
|
+
async def on_error(self, error: Exception) -> AsyncIterator[EventT]:
|
|
394
|
+
"""Handle errors that occur during streaming.
|
|
395
|
+
|
|
396
|
+
Args:
|
|
397
|
+
error: The error that occurred during streaming.
|
|
398
|
+
"""
|
|
399
|
+
return # pragma: no cover
|
|
400
|
+
yield # Make this an async generator
|
|
401
|
+
|
|
402
|
+
async def before_request(self) -> AsyncIterator[EventT]:
|
|
403
|
+
"""Yield events before a model request is processed.
|
|
404
|
+
|
|
405
|
+
Override this to inject custom events at the start of the request.
|
|
406
|
+
"""
|
|
407
|
+
return
|
|
408
|
+
yield # Make this an async generator
|
|
409
|
+
|
|
410
|
+
async def after_request(self) -> AsyncIterator[EventT]:
|
|
411
|
+
"""Yield events after a model request is processed.
|
|
412
|
+
|
|
413
|
+
Override this to inject custom events at the end of the request.
|
|
414
|
+
"""
|
|
415
|
+
return
|
|
416
|
+
yield # Make this an async generator
|
|
417
|
+
|
|
418
|
+
async def before_response(self) -> AsyncIterator[EventT]:
|
|
419
|
+
"""Yield events before a model response is processed.
|
|
420
|
+
|
|
421
|
+
Override this to inject custom events at the start of the response.
|
|
422
|
+
"""
|
|
423
|
+
return
|
|
424
|
+
yield # Make this an async generator
|
|
425
|
+
|
|
426
|
+
async def after_response(self) -> AsyncIterator[EventT]:
|
|
427
|
+
"""Yield events after a model response is processed.
|
|
428
|
+
|
|
429
|
+
Override this to inject custom events at the end of the response.
|
|
430
|
+
"""
|
|
431
|
+
return
|
|
432
|
+
yield # Make this an async generator
|
|
433
|
+
|
|
434
|
+
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[EventT]:
|
|
435
|
+
"""Handle the start of a `TextPart`.
|
|
436
|
+
|
|
437
|
+
Args:
|
|
438
|
+
part: The text part.
|
|
439
|
+
follows_text: Whether the part is directly preceded by another text part. In this case, you may want to yield a "text-delta" event instead of a "text-start" event.
|
|
440
|
+
"""
|
|
441
|
+
return # pragma: no cover
|
|
442
|
+
yield # Make this an async generator
|
|
443
|
+
|
|
444
|
+
async def handle_text_delta(self, delta: TextPartDelta) -> AsyncIterator[EventT]:
|
|
445
|
+
"""Handle a `TextPartDelta`.
|
|
446
|
+
|
|
447
|
+
Args:
|
|
448
|
+
delta: The text part delta.
|
|
449
|
+
"""
|
|
450
|
+
return # pragma: no cover
|
|
451
|
+
yield # Make this an async generator
|
|
452
|
+
|
|
453
|
+
async def handle_text_end(self, part: TextPart, followed_by_text: bool = False) -> AsyncIterator[EventT]:
|
|
454
|
+
"""Handle the end of a `TextPart`.
|
|
455
|
+
|
|
456
|
+
Args:
|
|
457
|
+
part: The text part.
|
|
458
|
+
followed_by_text: Whether the part is directly followed by another text part. In this case, you may not want to yield a "text-end" event yet.
|
|
459
|
+
"""
|
|
460
|
+
return # pragma: no cover
|
|
461
|
+
yield # Make this an async generator
|
|
462
|
+
|
|
463
|
+
async def handle_thinking_start(self, part: ThinkingPart, follows_thinking: bool = False) -> AsyncIterator[EventT]:
|
|
464
|
+
"""Handle the start of a `ThinkingPart`.
|
|
465
|
+
|
|
466
|
+
Args:
|
|
467
|
+
part: The thinking part.
|
|
468
|
+
follows_thinking: Whether the part is directly preceded by another thinking part. In this case, you may want to yield a "thinking-delta" event instead of a "thinking-start" event.
|
|
469
|
+
"""
|
|
470
|
+
return # pragma: no cover
|
|
471
|
+
yield # Make this an async generator
|
|
472
|
+
|
|
473
|
+
async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator[EventT]:
|
|
474
|
+
"""Handle a `ThinkingPartDelta`.
|
|
475
|
+
|
|
476
|
+
Args:
|
|
477
|
+
delta: The thinking part delta.
|
|
478
|
+
"""
|
|
479
|
+
return # pragma: no cover
|
|
480
|
+
yield # Make this an async generator
|
|
481
|
+
|
|
482
|
+
async def handle_thinking_end(
|
|
483
|
+
self, part: ThinkingPart, followed_by_thinking: bool = False
|
|
484
|
+
) -> AsyncIterator[EventT]:
|
|
485
|
+
"""Handle the end of a `ThinkingPart`.
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
part: The thinking part.
|
|
489
|
+
followed_by_thinking: Whether the part is directly followed by another thinking part. In this case, you may not want to yield a "thinking-end" event yet.
|
|
490
|
+
"""
|
|
491
|
+
return # pragma: no cover
|
|
492
|
+
yield # Make this an async generator
|
|
493
|
+
|
|
494
|
+
async def handle_tool_call_start(self, part: ToolCallPart) -> AsyncIterator[EventT]:
|
|
495
|
+
"""Handle the start of a `ToolCallPart`.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
part: The tool call part.
|
|
499
|
+
"""
|
|
500
|
+
return # pragma: no cover
|
|
501
|
+
yield # Make this an async generator
|
|
502
|
+
|
|
503
|
+
async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterator[EventT]:
|
|
504
|
+
"""Handle a `ToolCallPartDelta`.
|
|
505
|
+
|
|
506
|
+
Args:
|
|
507
|
+
delta: The tool call part delta.
|
|
508
|
+
"""
|
|
509
|
+
return # pragma: no cover
|
|
510
|
+
yield # Make this an async generator
|
|
511
|
+
|
|
512
|
+
async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[EventT]:
|
|
513
|
+
"""Handle the end of a `ToolCallPart`.
|
|
514
|
+
|
|
515
|
+
Args:
|
|
516
|
+
part: The tool call part.
|
|
517
|
+
"""
|
|
518
|
+
return # pragma: no cover
|
|
519
|
+
yield # Make this an async generator
|
|
520
|
+
|
|
521
|
+
async def handle_builtin_tool_call_start(self, part: BuiltinToolCallPart) -> AsyncIterator[EventT]:
|
|
522
|
+
"""Handle a `BuiltinToolCallPart` at start.
|
|
523
|
+
|
|
524
|
+
Args:
|
|
525
|
+
part: The builtin tool call part.
|
|
526
|
+
"""
|
|
527
|
+
return # pragma: no cover
|
|
528
|
+
yield # Make this an async generator
|
|
529
|
+
|
|
530
|
+
async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[EventT]:
|
|
531
|
+
"""Handle the end of a `BuiltinToolCallPart`.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
part: The builtin tool call part.
|
|
535
|
+
"""
|
|
536
|
+
return # pragma: no cover
|
|
537
|
+
yield # Make this an async generator
|
|
538
|
+
|
|
539
|
+
async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[EventT]:
|
|
540
|
+
"""Handle a `BuiltinToolReturnPart`.
|
|
541
|
+
|
|
542
|
+
Args:
|
|
543
|
+
part: The builtin tool return part.
|
|
544
|
+
"""
|
|
545
|
+
return # pragma: no cover
|
|
546
|
+
yield # Make this an async generator
|
|
547
|
+
|
|
548
|
+
async def handle_file(self, part: FilePart) -> AsyncIterator[EventT]:
|
|
549
|
+
"""Handle a `FilePart`.
|
|
550
|
+
|
|
551
|
+
Args:
|
|
552
|
+
part: The file part.
|
|
553
|
+
"""
|
|
554
|
+
return # pragma: no cover
|
|
555
|
+
yield # Make this an async generator
|
|
556
|
+
|
|
557
|
+
async def handle_final_result(self, event: FinalResultEvent) -> AsyncIterator[EventT]:
|
|
558
|
+
"""Handle a `FinalResultEvent`.
|
|
559
|
+
|
|
560
|
+
Args:
|
|
561
|
+
event: The final result event.
|
|
562
|
+
"""
|
|
563
|
+
return
|
|
564
|
+
yield # Make this an async generator
|
|
565
|
+
|
|
566
|
+
async def handle_function_tool_call(self, event: FunctionToolCallEvent) -> AsyncIterator[EventT]:
|
|
567
|
+
"""Handle a `FunctionToolCallEvent`.
|
|
568
|
+
|
|
569
|
+
Args:
|
|
570
|
+
event: The function tool call event.
|
|
571
|
+
"""
|
|
572
|
+
return
|
|
573
|
+
yield # Make this an async generator
|
|
574
|
+
|
|
575
|
+
async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[EventT]:
|
|
576
|
+
"""Handle a `FunctionToolResultEvent`.
|
|
577
|
+
|
|
578
|
+
Args:
|
|
579
|
+
event: The function tool result event.
|
|
580
|
+
"""
|
|
581
|
+
return # pragma: no cover
|
|
582
|
+
yield # Make this an async generator
|
|
583
|
+
|
|
584
|
+
async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[EventT]:
|
|
585
|
+
"""Handle an `AgentRunResultEvent`.
|
|
586
|
+
|
|
587
|
+
Args:
|
|
588
|
+
event: The agent run result event.
|
|
589
|
+
"""
|
|
590
|
+
return
|
|
591
|
+
yield # Make this an async generator
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import cast
|
|
3
|
+
|
|
4
|
+
from pydantic_ai._utils import get_union_args
|
|
5
|
+
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class MessagesBuilder:
    """Accumulates request and response parts into a list of Pydantic AI messages."""

    messages: list[ModelMessage] = field(default_factory=list)

    def add(self, part: ModelRequestPart | ModelResponsePart) -> None:
        """Append a part to the trailing message, or start a new message when the kinds differ.

        A request part extends the last message if that message is a `ModelRequest`, and a response part
        extends it if it is a `ModelResponse`. In every other case (including an empty builder) a new
        message of the matching kind is appended.

        Args:
            part: The request or response part to add.
        """
        previous = self.messages[-1] if self.messages else None

        if isinstance(part, get_union_args(ModelRequestPart)):
            request_part = cast(ModelRequestPart, part)
            if not isinstance(previous, ModelRequest):
                self.messages.append(ModelRequest(parts=[request_part]))
                return
            # Rebuild the parts list instead of appending to the existing one in place.
            previous.parts = [*previous.parts, request_part]
            return

        response_part = cast(ModelResponsePart, part)
        if not isinstance(previous, ModelResponse):
            self.messages.append(ModelResponse(parts=[response_part]))
            return
        # Rebuild the parts list instead of appending to the existing one in place.
        previous.parts = [*previous.parts, response_part]
|