pydantic-ai-slim 0.4.10__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_ai/ag_ui.py CHANGED
@@ -8,11 +8,10 @@ from __future__ import annotations
8
8
 
9
9
  import json
10
10
  import uuid
11
- from collections.abc import Iterable, Mapping, Sequence
12
- from dataclasses import Field, dataclass, field, replace
11
+ from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
12
+ from dataclasses import Field, dataclass, replace
13
13
  from http import HTTPStatus
14
14
  from typing import (
15
- TYPE_CHECKING,
16
15
  Any,
17
16
  Callable,
18
17
  ClassVar,
@@ -23,10 +22,36 @@ from typing import (
23
22
  runtime_checkable,
24
23
  )
25
24
 
26
- from pydantic_ai.exceptions import UserError
25
+ from pydantic import BaseModel, ValidationError
27
26
 
28
- if TYPE_CHECKING:
29
- pass
27
+ from ._agent_graph import CallToolsNode, ModelRequestNode
28
+ from .agent import Agent, AgentRun
29
+ from .exceptions import UserError
30
+ from .messages import (
31
+ AgentStreamEvent,
32
+ FunctionToolResultEvent,
33
+ ModelMessage,
34
+ ModelRequest,
35
+ ModelResponse,
36
+ PartDeltaEvent,
37
+ PartStartEvent,
38
+ SystemPromptPart,
39
+ TextPart,
40
+ TextPartDelta,
41
+ ThinkingPart,
42
+ ThinkingPartDelta,
43
+ ToolCallPart,
44
+ ToolCallPartDelta,
45
+ ToolReturnPart,
46
+ UserPromptPart,
47
+ )
48
+ from .models import KnownModelName, Model
49
+ from .output import DeferredToolCalls, OutputDataT, OutputSpec
50
+ from .settings import ModelSettings
51
+ from .tools import AgentDepsT, ToolDefinition
52
+ from .toolsets import AbstractToolset
53
+ from .toolsets.deferred import DeferredToolset
54
+ from .usage import Usage, UsageLimits
30
55
 
31
56
  try:
32
57
  from ag_ui.core import (
@@ -74,43 +99,13 @@ except ImportError as e: # pragma: no cover
74
99
  'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`'
75
100
  ) from e
76
101
 
77
- from collections.abc import AsyncGenerator
78
-
79
- from pydantic import BaseModel, ValidationError
80
-
81
- from ._agent_graph import CallToolsNode, ModelRequestNode
82
- from .agent import Agent, AgentRun, RunOutputDataT
83
- from .messages import (
84
- AgentStreamEvent,
85
- FunctionToolResultEvent,
86
- ModelMessage,
87
- ModelRequest,
88
- ModelResponse,
89
- PartDeltaEvent,
90
- PartStartEvent,
91
- SystemPromptPart,
92
- TextPart,
93
- TextPartDelta,
94
- ThinkingPart,
95
- ThinkingPartDelta,
96
- ToolCallPart,
97
- ToolCallPartDelta,
98
- ToolReturnPart,
99
- UserPromptPart,
100
- )
101
- from .models import KnownModelName, Model
102
- from .output import DeferredToolCalls, OutputDataT, OutputSpec
103
- from .settings import ModelSettings
104
- from .tools import AgentDepsT, ToolDefinition
105
- from .toolsets import AbstractToolset
106
- from .toolsets.deferred import DeferredToolset
107
- from .usage import Usage, UsageLimits
108
-
109
102
  __all__ = [
110
103
  'SSE_CONTENT_TYPE',
111
104
  'StateDeps',
112
105
  'StateHandler',
113
106
  'AGUIApp',
107
+ 'handle_ag_ui_request',
108
+ 'run_ag_ui',
114
109
  ]
115
110
 
116
111
  SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
@@ -125,7 +120,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
125
120
  agent: Agent[AgentDepsT, OutputDataT],
126
121
  *,
127
122
  # Agent.iter parameters.
128
- output_type: OutputSpec[OutputDataT] | None = None,
123
+ output_type: OutputSpec[Any] | None = None,
129
124
  model: Model | KnownModelName | str | None = None,
130
125
  deps: AgentDepsT = None,
131
126
  model_settings: ModelSettings | None = None,
@@ -142,10 +137,16 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
142
137
  on_shutdown: Sequence[Callable[[], Any]] | None = None,
143
138
  lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
144
139
  ) -> None:
145
- """Initialise the AG-UI application.
140
+ """An ASGI application that handles every AG-UI request by running the agent.
141
+
142
+ Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
143
+ injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
144
+ To provide different `deps` for each request (e.g. based on the authenticated user),
145
+ use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
146
+ [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.
146
147
 
147
148
  Args:
148
- agent: The Pydantic AI `Agent` to adapt.
149
+ agent: The agent to run.
149
150
 
150
151
  output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
151
152
  no output validators since output validators would expect an argument that matches the agent's
@@ -156,7 +157,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
156
157
  usage_limits: Optional limits on model request count or token usage.
157
158
  usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
158
159
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
159
- toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
160
+ toolsets: Optional additional toolsets for this run.
160
161
 
161
162
  debug: Boolean indicating if debug tracebacks should be returned on errors.
162
163
  routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -185,320 +186,349 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
185
186
  on_shutdown=on_shutdown,
186
187
  lifespan=lifespan,
187
188
  )
188
- adapter = _Adapter(agent=agent)
189
189
 
190
- async def endpoint(request: Request) -> Response | StreamingResponse:
190
+ async def endpoint(request: Request) -> Response:
191
191
  """Endpoint to run the agent with the provided input data."""
192
- accept = request.headers.get('accept', SSE_CONTENT_TYPE)
193
- try:
194
- input_data = RunAgentInput.model_validate(await request.json())
195
- except ValidationError as e: # pragma: no cover
196
- return Response(
197
- content=json.dumps(e.json()),
198
- media_type='application/json',
199
- status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
200
- )
201
-
202
- return StreamingResponse(
203
- adapter.run(
204
- input_data,
205
- accept,
206
- output_type=output_type,
207
- model=model,
208
- deps=deps,
209
- model_settings=model_settings,
210
- usage_limits=usage_limits,
211
- usage=usage,
212
- infer_name=infer_name,
213
- toolsets=toolsets,
214
- ),
215
- media_type=SSE_CONTENT_TYPE,
192
+ return await handle_ag_ui_request(
193
+ agent,
194
+ request,
195
+ output_type=output_type,
196
+ model=model,
197
+ deps=deps,
198
+ model_settings=model_settings,
199
+ usage_limits=usage_limits,
200
+ usage=usage,
201
+ infer_name=infer_name,
202
+ toolsets=toolsets,
216
203
  )
217
204
 
218
205
  self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
219
206
 
220
207
 
221
- @dataclass(repr=False)
222
- class _Adapter(Generic[AgentDepsT, OutputDataT]):
223
- """An agent adapter providing AG-UI protocol support for Pydantic AI agents.
224
-
225
- This class manages the agent runs, tool calls, state storage and providing
226
- an adapter for running agents with Server-Sent Event (SSE) streaming
227
- responses using the AG-UI protocol.
208
+ async def handle_ag_ui_request(
209
+ agent: Agent[AgentDepsT, Any],
210
+ request: Request,
211
+ *,
212
+ output_type: OutputSpec[Any] | None = None,
213
+ model: Model | KnownModelName | str | None = None,
214
+ deps: AgentDepsT = None,
215
+ model_settings: ModelSettings | None = None,
216
+ usage_limits: UsageLimits | None = None,
217
+ usage: Usage | None = None,
218
+ infer_name: bool = True,
219
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
220
+ ) -> Response:
221
+ """Handle an AG-UI request by running the agent and returning a streaming response.
228
222
 
229
223
  Args:
230
- agent: The Pydantic AI `Agent` to adapt.
224
+ agent: The agent to run.
225
+ request: The Starlette request (e.g. from FastAPI) containing the AG-UI run input.
226
+
227
+ output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
228
+ output validators since output validators would expect an argument that matches the agent's output type.
229
+ model: Optional model to use for this run, required if `model` was not set when creating the agent.
230
+ deps: Optional dependencies to use for this run.
231
+ model_settings: Optional settings to use for this model's request.
232
+ usage_limits: Optional limits on model request count or token usage.
233
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
234
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
235
+ toolsets: Optional additional toolsets for this run.
236
+
237
+ Returns:
238
+ A streaming Starlette response with AG-UI protocol events.
231
239
  """
240
+ accept = request.headers.get('accept', SSE_CONTENT_TYPE)
241
+ try:
242
+ input_data = RunAgentInput.model_validate(await request.json())
243
+ except ValidationError as e: # pragma: no cover
244
+ return Response(
245
+ content=json.dumps(e.json()),
246
+ media_type='application/json',
247
+ status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
248
+ )
232
249
 
233
- agent: Agent[AgentDepsT, OutputDataT] = field(repr=False)
234
-
235
- async def run(
236
- self,
237
- run_input: RunAgentInput,
238
- accept: str = SSE_CONTENT_TYPE,
239
- *,
240
- output_type: OutputSpec[RunOutputDataT] | None = None,
241
- model: Model | KnownModelName | str | None = None,
242
- deps: AgentDepsT = None,
243
- model_settings: ModelSettings | None = None,
244
- usage_limits: UsageLimits | None = None,
245
- usage: Usage | None = None,
246
- infer_name: bool = True,
247
- toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
248
- ) -> AsyncGenerator[str, None]:
249
- """Run the agent with streaming response using AG-UI protocol events.
250
+ return StreamingResponse(
251
+ run_ag_ui(
252
+ agent,
253
+ input_data,
254
+ accept,
255
+ output_type=output_type,
256
+ model=model,
257
+ deps=deps,
258
+ model_settings=model_settings,
259
+ usage_limits=usage_limits,
260
+ usage=usage,
261
+ infer_name=infer_name,
262
+ toolsets=toolsets,
263
+ ),
264
+ media_type=accept,
265
+ )
250
266
 
251
- The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method.
252
267
 
253
- Args:
254
- run_input: The AG-UI run input containing thread_id, run_id, messages, etc.
255
- accept: The accept header value for the run.
268
+ async def run_ag_ui(
269
+ agent: Agent[AgentDepsT, Any],
270
+ run_input: RunAgentInput,
271
+ accept: str = SSE_CONTENT_TYPE,
272
+ *,
273
+ output_type: OutputSpec[Any] | None = None,
274
+ model: Model | KnownModelName | str | None = None,
275
+ deps: AgentDepsT = None,
276
+ model_settings: ModelSettings | None = None,
277
+ usage_limits: UsageLimits | None = None,
278
+ usage: Usage | None = None,
279
+ infer_name: bool = True,
280
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
281
+ ) -> AsyncIterator[str]:
282
+ """Run the agent with the AG-UI run input and stream AG-UI protocol events.
256
283
 
257
- output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
258
- output validators since output validators would expect an argument that matches the agent's output type.
259
- model: Optional model to use for this run, required if `model` was not set when creating the agent.
260
- deps: Optional dependencies to use for this run.
261
- model_settings: Optional settings to use for this model's request.
262
- usage_limits: Optional limits on model request count or token usage.
263
- usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
264
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
265
- toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
284
+ Args:
285
+ agent: The agent to run.
286
+ run_input: The AG-UI run input containing thread_id, run_id, messages, etc.
287
+ accept: The accept header value for the run.
288
+
289
+ output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
290
+ output validators since output validators would expect an argument that matches the agent's output type.
291
+ model: Optional model to use for this run, required if `model` was not set when creating the agent.
292
+ deps: Optional dependencies to use for this run.
293
+ model_settings: Optional settings to use for this model's request.
294
+ usage_limits: Optional limits on model request count or token usage.
295
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
296
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
297
+ toolsets: Optional additional toolsets for this run.
298
+
299
+ Yields:
300
+ Streaming event chunks encoded as strings according to the accept header value.
301
+ """
302
+ encoder = EventEncoder(accept=accept)
303
+ if run_input.tools:
304
+ # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
305
+ # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
306
+ # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
307
+ toolset = DeferredToolset[AgentDepsT](
308
+ [
309
+ ToolDefinition(
310
+ name=tool.name,
311
+ description=tool.description,
312
+ parameters_json_schema=tool.parameters,
313
+ )
314
+ for tool in run_input.tools
315
+ ]
316
+ )
317
+ toolsets = [*toolsets, toolset] if toolsets else [toolset]
318
+
319
+ try:
320
+ yield encoder.encode(
321
+ RunStartedEvent(
322
+ thread_id=run_input.thread_id,
323
+ run_id=run_input.run_id,
324
+ ),
325
+ )
266
326
 
267
- Yields:
268
- Streaming SSE-formatted event chunks.
269
- """
270
- encoder = EventEncoder(accept=accept)
271
- if run_input.tools:
272
- # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the
273
- # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any
274
- # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets.
275
- toolset = DeferredToolset[AgentDepsT](
276
- [
277
- ToolDefinition(
278
- name=tool.name,
279
- description=tool.description,
280
- parameters_json_schema=tool.parameters,
281
- )
282
- for tool in run_input.tools
283
- ]
284
- )
285
- toolsets = [*toolsets, toolset] if toolsets else [toolset]
286
-
287
- try:
288
- yield encoder.encode(
289
- RunStartedEvent(
290
- thread_id=run_input.thread_id,
291
- run_id=run_input.run_id,
292
- ),
293
- )
327
+ if not run_input.messages:
328
+ raise _NoMessagesError
294
329
 
295
- if not run_input.messages:
296
- raise _NoMessagesError
297
-
298
- raw_state: dict[str, Any] = run_input.state or {}
299
- if isinstance(deps, StateHandler):
300
- if isinstance(deps.state, BaseModel):
301
- try:
302
- state = type(deps.state).model_validate(raw_state)
303
- except ValidationError as e: # pragma: no cover
304
- raise _InvalidStateError from e
305
- else:
306
- state = raw_state
307
-
308
- deps = replace(deps, state=state)
309
- elif raw_state:
310
- raise UserError(
311
- f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
312
- )
330
+ raw_state: dict[str, Any] = run_input.state or {}
331
+ if isinstance(deps, StateHandler):
332
+ if isinstance(deps.state, BaseModel):
333
+ try:
334
+ state = type(deps.state).model_validate(raw_state)
335
+ except ValidationError as e: # pragma: no cover
336
+ raise _InvalidStateError from e
313
337
  else:
314
- # `deps` not being a `StateHandler` is OK if there is no state.
315
- pass
316
-
317
- messages = _messages_from_ag_ui(run_input.messages)
338
+ state = raw_state
318
339
 
319
- async with self.agent.iter(
320
- user_prompt=None,
321
- output_type=[output_type or self.agent.output_type, DeferredToolCalls],
322
- message_history=messages,
323
- model=model,
324
- deps=deps,
325
- model_settings=model_settings,
326
- usage_limits=usage_limits,
327
- usage=usage,
328
- infer_name=infer_name,
329
- toolsets=toolsets,
330
- ) as run:
331
- async for event in self._agent_stream(run):
332
- yield encoder.encode(event)
333
- except _RunError as e:
334
- yield encoder.encode(
335
- RunErrorEvent(message=e.message, code=e.code),
336
- )
337
- except Exception as e:
338
- yield encoder.encode(
339
- RunErrorEvent(message=str(e)),
340
+ deps = replace(deps, state=state)
341
+ elif raw_state:
342
+ raise UserError(
343
+ f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
340
344
  )
341
- raise e
342
345
  else:
343
- yield encoder.encode(
344
- RunFinishedEvent(
345
- thread_id=run_input.thread_id,
346
- run_id=run_input.run_id,
347
- ),
348
- )
346
+ # `deps` not being a `StateHandler` is OK if there is no state.
347
+ pass
349
348
 
350
- async def _agent_stream(
351
- self,
352
- run: AgentRun[AgentDepsT, Any],
353
- ) -> AsyncGenerator[BaseEvent, None]:
354
- """Run the agent streaming responses using AG-UI protocol events.
349
+ messages = _messages_from_ag_ui(run_input.messages)
350
+
351
+ async with agent.iter(
352
+ user_prompt=None,
353
+ output_type=[output_type or agent.output_type, DeferredToolCalls],
354
+ message_history=messages,
355
+ model=model,
356
+ deps=deps,
357
+ model_settings=model_settings,
358
+ usage_limits=usage_limits,
359
+ usage=usage,
360
+ infer_name=infer_name,
361
+ toolsets=toolsets,
362
+ ) as run:
363
+ async for event in _agent_stream(run):
364
+ yield encoder.encode(event)
365
+ except _RunError as e:
366
+ yield encoder.encode(
367
+ RunErrorEvent(message=e.message, code=e.code),
368
+ )
369
+ except Exception as e:
370
+ yield encoder.encode(
371
+ RunErrorEvent(message=str(e)),
372
+ )
373
+ raise e
374
+ else:
375
+ yield encoder.encode(
376
+ RunFinishedEvent(
377
+ thread_id=run_input.thread_id,
378
+ run_id=run_input.run_id,
379
+ ),
380
+ )
355
381
 
356
- Args:
357
- run: The agent run to process.
358
382
 
359
- Yields:
360
- AG-UI Server-Sent Events (SSE).
361
- """
362
- async for node in run:
363
- stream_ctx = _RequestStreamContext()
364
- if isinstance(node, ModelRequestNode):
365
- async with node.stream(run.ctx) as request_stream:
366
- async for agent_event in request_stream:
367
- async for msg in self._handle_model_request_event(stream_ctx, agent_event):
383
+ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEvent]:
384
+ """Run the agent streaming responses using AG-UI protocol events.
385
+
386
+ Args:
387
+ run: The agent run to process.
388
+
389
+ Yields:
390
+ AG-UI Server-Sent Events (SSE).
391
+ """
392
+ async for node in run:
393
+ stream_ctx = _RequestStreamContext()
394
+ if isinstance(node, ModelRequestNode):
395
+ async with node.stream(run.ctx) as request_stream:
396
+ async for agent_event in request_stream:
397
+ async for msg in _handle_model_request_event(stream_ctx, agent_event):
398
+ yield msg
399
+
400
+ if stream_ctx.part_end: # pragma: no branch
401
+ yield stream_ctx.part_end
402
+ stream_ctx.part_end = None
403
+ elif isinstance(node, CallToolsNode):
404
+ async with node.stream(run.ctx) as handle_stream:
405
+ async for event in handle_stream:
406
+ if isinstance(event, FunctionToolResultEvent):
407
+ async for msg in _handle_tool_result_event(stream_ctx, event):
368
408
  yield msg
369
409
 
370
- if stream_ctx.part_end: # pragma: no branch
371
- yield stream_ctx.part_end
372
- stream_ctx.part_end = None
373
- elif isinstance(node, CallToolsNode):
374
- async with node.stream(run.ctx) as handle_stream:
375
- async for event in handle_stream:
376
- if isinstance(event, FunctionToolResultEvent):
377
- async for msg in self._handle_tool_result_event(stream_ctx, event):
378
- yield msg
379
-
380
- async def _handle_model_request_event(
381
- self,
382
- stream_ctx: _RequestStreamContext,
383
- agent_event: AgentStreamEvent,
384
- ) -> AsyncGenerator[BaseEvent, None]:
385
- """Handle an agent event and yield AG-UI protocol events.
386
410
 
387
- Args:
388
- stream_ctx: The request stream context to manage state.
389
- agent_event: The agent event to process.
411
+ async def _handle_model_request_event(
412
+ stream_ctx: _RequestStreamContext,
413
+ agent_event: AgentStreamEvent,
414
+ ) -> AsyncIterator[BaseEvent]:
415
+ """Handle an agent event and yield AG-UI protocol events.
390
416
 
391
- Yields:
392
- AG-UI Server-Sent Events (SSE) based on the agent event.
393
- """
394
- if isinstance(agent_event, PartStartEvent):
395
- if stream_ctx.part_end:
396
- # End the previous part.
397
- yield stream_ctx.part_end
398
- stream_ctx.part_end = None
399
-
400
- part = agent_event.part
401
- if isinstance(part, TextPart):
402
- message_id = stream_ctx.new_message_id()
403
- yield TextMessageStartEvent(
404
- message_id=message_id,
405
- )
406
- if part.content: # pragma: no branch
407
- yield TextMessageContentEvent(
408
- message_id=message_id,
409
- delta=part.content,
410
- )
411
- stream_ctx.part_end = TextMessageEndEvent(
417
+ Args:
418
+ stream_ctx: The request stream context to manage state.
419
+ agent_event: The agent event to process.
420
+
421
+ Yields:
422
+ AG-UI Server-Sent Events (SSE) based on the agent event.
423
+ """
424
+ if isinstance(agent_event, PartStartEvent):
425
+ if stream_ctx.part_end:
426
+ # End the previous part.
427
+ yield stream_ctx.part_end
428
+ stream_ctx.part_end = None
429
+
430
+ part = agent_event.part
431
+ if isinstance(part, TextPart):
432
+ message_id = stream_ctx.new_message_id()
433
+ yield TextMessageStartEvent(
434
+ message_id=message_id,
435
+ )
436
+ if part.content: # pragma: no branch
437
+ yield TextMessageContentEvent(
412
438
  message_id=message_id,
439
+ delta=part.content,
413
440
  )
414
- elif isinstance(part, ToolCallPart): # pragma: no branch
415
- message_id = stream_ctx.message_id or stream_ctx.new_message_id()
416
- yield ToolCallStartEvent(
417
- tool_call_id=part.tool_call_id,
418
- tool_call_name=part.tool_name,
419
- parent_message_id=message_id,
420
- )
421
- if part.args:
422
- yield ToolCallArgsEvent(
423
- tool_call_id=part.tool_call_id,
424
- delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
425
- )
426
- stream_ctx.part_end = ToolCallEndEvent(
441
+ stream_ctx.part_end = TextMessageEndEvent(
442
+ message_id=message_id,
443
+ )
444
+ elif isinstance(part, ToolCallPart): # pragma: no branch
445
+ message_id = stream_ctx.message_id or stream_ctx.new_message_id()
446
+ yield ToolCallStartEvent(
447
+ tool_call_id=part.tool_call_id,
448
+ tool_call_name=part.tool_name,
449
+ parent_message_id=message_id,
450
+ )
451
+ if part.args:
452
+ yield ToolCallArgsEvent(
427
453
  tool_call_id=part.tool_call_id,
454
+ delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
428
455
  )
456
+ stream_ctx.part_end = ToolCallEndEvent(
457
+ tool_call_id=part.tool_call_id,
458
+ )
429
459
 
430
- elif isinstance(part, ThinkingPart): # pragma: no branch
431
- yield ThinkingTextMessageStartEvent(
432
- type=EventType.THINKING_TEXT_MESSAGE_START,
433
- )
434
- # Always send the content even if it's empty, as it may be
435
- # used to indicate the start of thinking.
460
+ elif isinstance(part, ThinkingPart): # pragma: no branch
461
+ yield ThinkingTextMessageStartEvent(
462
+ type=EventType.THINKING_TEXT_MESSAGE_START,
463
+ )
464
+ # Always send the content even if it's empty, as it may be
465
+ # used to indicate the start of thinking.
466
+ yield ThinkingTextMessageContentEvent(
467
+ type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
468
+ delta=part.content,
469
+ )
470
+ stream_ctx.part_end = ThinkingTextMessageEndEvent(
471
+ type=EventType.THINKING_TEXT_MESSAGE_END,
472
+ )
473
+
474
+ elif isinstance(agent_event, PartDeltaEvent):
475
+ delta = agent_event.delta
476
+ if isinstance(delta, TextPartDelta):
477
+ yield TextMessageContentEvent(
478
+ message_id=stream_ctx.message_id,
479
+ delta=delta.content_delta,
480
+ )
481
+ elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
482
+ assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
483
+ yield ToolCallArgsEvent(
484
+ tool_call_id=delta.tool_call_id,
485
+ delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
486
+ )
487
+ elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
488
+ if delta.content_delta: # pragma: no branch
436
489
  yield ThinkingTextMessageContentEvent(
437
490
  type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
438
- delta=part.content,
439
- )
440
- stream_ctx.part_end = ThinkingTextMessageEndEvent(
441
- type=EventType.THINKING_TEXT_MESSAGE_END,
442
- )
443
-
444
- elif isinstance(agent_event, PartDeltaEvent):
445
- delta = agent_event.delta
446
- if isinstance(delta, TextPartDelta):
447
- yield TextMessageContentEvent(
448
- message_id=stream_ctx.message_id,
449
491
  delta=delta.content_delta,
450
492
  )
451
- elif isinstance(delta, ToolCallPartDelta): # pragma: no branch
452
- assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
453
- yield ToolCallArgsEvent(
454
- tool_call_id=delta.tool_call_id,
455
- delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
456
- )
457
- elif isinstance(delta, ThinkingPartDelta): # pragma: no branch
458
- if delta.content_delta: # pragma: no branch
459
- yield ThinkingTextMessageContentEvent(
460
- type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
461
- delta=delta.content_delta,
462
- )
463
493
 
464
- async def _handle_tool_result_event(
465
- self,
466
- stream_ctx: _RequestStreamContext,
467
- event: FunctionToolResultEvent,
468
- ) -> AsyncGenerator[BaseEvent, None]:
469
- """Convert a tool call result to AG-UI events.
470
494
 
471
- Args:
472
- stream_ctx: The request stream context to manage state.
473
- event: The tool call result event to process.
495
+ async def _handle_tool_result_event(
496
+ stream_ctx: _RequestStreamContext,
497
+ event: FunctionToolResultEvent,
498
+ ) -> AsyncIterator[BaseEvent]:
499
+ """Convert a tool call result to AG-UI events.
474
500
 
475
- Yields:
476
- AG-UI Server-Sent Events (SSE).
477
- """
478
- result = event.result
479
- if not isinstance(result, ToolReturnPart):
480
- return
481
-
482
- message_id = stream_ctx.new_message_id()
483
- yield ToolCallResultEvent(
484
- message_id=message_id,
485
- type=EventType.TOOL_CALL_RESULT,
486
- role='tool',
487
- tool_call_id=result.tool_call_id,
488
- content=result.model_response_str(),
489
- )
501
+ Args:
502
+ stream_ctx: The request stream context to manage state.
503
+ event: The tool call result event to process.
490
504
 
491
- # Now check for AG-UI events returned by the tool calls.
492
- content = result.content
493
- if isinstance(content, BaseEvent):
494
- yield content
495
- elif isinstance(content, (str, bytes)): # pragma: no branch
496
- # Avoid iterable check for strings and bytes.
497
- pass
498
- elif isinstance(content, Iterable): # pragma: no branch
499
- for item in content: # type: ignore[reportUnknownMemberType]
500
- if isinstance(item, BaseEvent): # pragma: no branch
501
- yield item
505
+ Yields:
506
+ AG-UI Server-Sent Events (SSE).
507
+ """
508
+ result = event.result
509
+ if not isinstance(result, ToolReturnPart):
510
+ return
511
+
512
+ message_id = stream_ctx.new_message_id()
513
+ yield ToolCallResultEvent(
514
+ message_id=message_id,
515
+ type=EventType.TOOL_CALL_RESULT,
516
+ role='tool',
517
+ tool_call_id=result.tool_call_id,
518
+ content=result.model_response_str(),
519
+ )
520
+
521
+ # Now check for AG-UI events returned by the tool calls.
522
+ content = result.content
523
+ if isinstance(content, BaseEvent):
524
+ yield content
525
+ elif isinstance(content, (str, bytes)): # pragma: no branch
526
+ # Avoid iterable check for strings and bytes.
527
+ pass
528
+ elif isinstance(content, Iterable): # pragma: no branch
529
+ for item in content: # type: ignore[reportUnknownMemberType]
530
+ if isinstance(item, BaseEvent): # pragma: no branch
531
+ yield item
502
532
 
503
533
 
504
534
  def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: