pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.


@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import dataclasses
 from abc import ABC
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
@@ -32,6 +32,16 @@ from .tools import (
     ToolDefinition,
 )
 
+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 # while waiting for https://github.com/pydantic/logfire/issues/745
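Note: `capture_run_messages`, now listed in `__all__`, is the documented way to inspect the messages exchanged during a run, for example when a run fails partway through. A minimal usage sketch, assuming the documented top-level re-export from `pydantic_ai` (the model name and prompt are illustrative):

from pydantic_ai import Agent, capture_run_messages

agent = Agent('openai:gpt-4o')

with capture_run_messages() as messages:
    try:
        result = agent.run_sync('What is 2 + 2?')
        print(result.data)
    finally:
        # `messages` collects every ModelRequest/ModelResponse from the run,
        # which is handy when the run raises before producing a result.
        print(messages)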
@@ -56,21 +66,6 @@ DepsT = TypeVar('DepsT')
 ResultT = TypeVar('ResultT')
 
 
-@dataclasses.dataclass
-class MarkFinalResult(Generic[ResultDataT]):
-    """Marker class to indicate that the result is the final result.
-
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
-    It also avoids problems in the case where the result type is itself `None`, but is set.
-    """
-
-    data: ResultDataT
-    """The final result data."""
-    tool_name: str | None
-    """Name of the final result tool, None if the result is a string."""
-
-
 @dataclasses.dataclass
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
@@ -113,17 +108,22 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
     user_prompt: str
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> _messages.ModelRequest:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
         ctx.state.message_history = history
         run_context.messages = history
@@ -188,29 +188,13 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return messages
 
 
-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
-@dataclasses.dataclass
-class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
-        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []
 
-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)
 
     async def add_tool(tool: Tool[DepsT]) -> None:
         ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
@@ -222,7 +206,7 @@ async def _prepare_request_parameters(
     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_result=_allow_text_result(result_schema),
+        allow_text_result=allow_text_result(result_schema),
         result_tools=result_schema.tool_defs() if result_schema is not None else [],
     )
 
@@ -233,9 +217,70 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
 
     request: _messages.ModelRequest
 
+    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _did_stream: bool = field(default=False, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        if self._did_stream:
+            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
+            # means that the stream was started but not finished before `run()` was called
+            raise exceptions.AgentRunError('You must finish streaming before calling run()')
+
+        return await self._make_request(ctx)
+
+    @asynccontextmanager
+    async def _stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[models.StreamedResponse]:
+        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.incr(_usage.Usage(), requests=1)
+                yield streamed_response
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in streamed_response:
+                    pass
+            model_response = streamed_response.get()
+            request_usage = streamed_response.usage()
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        self._finish_handling(ctx, model_response, request_usage)
+        assert self._result is not None  # this should be set by the previous line
+
+    async def _make_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        return self._finish_handling(ctx, model_response, request_usage)
+
+    async def _prepare_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
         ctx.state.message_history.append(self.request)
 
         # Check usage
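The new `_stream` context manager above yields the streamed response and then drains whatever the caller left unconsumed, so usage accounting still sees the whole stream. A standalone sketch of that drain-on-exit pattern (illustrative names, not pydantic-ai API):

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


@asynccontextmanager
async def drained(stream: AsyncIterator[str]) -> AsyncIterator[AsyncIterator[str]]:
    yield stream
    # Consume anything the caller did not read, so per-item accounting still runs.
    async for _ in stream:
        pass


async def main() -> None:
    async def chunks() -> AsyncIterator[str]:
        for piece in ('hel', 'lo', ' world'):
            yield piece

    async with drained(chunks()) as stream:
        async for piece in stream:
            print(piece)  # read only the first chunk, then stop early
            break


asyncio.run(main())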
@@ -245,71 +290,124 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1
 
+        model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
             model_request_parameters = await _prepare_request_parameters(ctx)
+        return model_settings, model_request_parameters
 
-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request') as span:
-            model_response, request_usage = await ctx.deps.model.request(
-                ctx.state.message_history, model_settings, model_request_parameters
-            )
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
-
+    def _finish_handling(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        response: _messages.ModelResponse,
+        usage: _usage.Usage,
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
         # Update usage
-        ctx.state.usage.incr(request_usage, requests=1)
+        ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
             ctx.deps.usage_limits.check_tokens(ctx.state.usage)
 
         # Append the model response to state.message_history
-        ctx.state.message_history.append(model_response)
-        return HandleResponseNode(model_response)
+        ctx.state.message_history.append(response)
+
+        # Set the `_result` attribute since we can't use `return` in an async iterator
+        self._result = HandleResponseNode(response)
+
+        return self._result
 
 
 @dataclasses.dataclass
 class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Process e response from a model, decide whether to end the run or make a new request."""
+    """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
 
+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
+        default=None, repr=False
+    )
+    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]:  # noqa UP007
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
+        async with self.stream(ctx):
+            pass
+
+        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
+        return next_node
+
+    @asynccontextmanager
+    async def stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
+        """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-            texts: list[str] = []
-            tool_calls: list[_messages.ToolCallPart] = []
-            for part in self.model_response.parts:
-                if isinstance(part, _messages.TextPart):
-                    # ignore empty content for text parts, see #437
-                    if part.content:
-                        texts.append(part.content)
-                elif isinstance(part, _messages.ToolCallPart):
-                    tool_calls.append(part)
+            stream = self._run_stream(ctx)
+            yield stream
+
+            # Run the stream to completion if it was not finished:
+            async for _event in stream:
+                pass
+
+            # Set the next node based on the final state of the stream
+            next_node = self._next_node
+            if isinstance(next_node, End):
+                handle_span.set_attribute('result', next_node.data)
+                handle_span.message = 'handle model response -> final result'
+            elif tool_responses := self._tool_responses:
+                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
+                # I'm thinking it might be better to just create a span for the handling of each tool
+                # than to set an attribute here.
+                handle_span.set_attribute('tool_responses', tool_responses)
+                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                handle_span.message = f'handle model response -> {tool_responses_str}'
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        if self._events_iterator is None:
+            # Ensure that the stream is only run once
+
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
                 else:
-                    assert_never(part)
-
-            # At the moment, we prioritize at least executing tool calls if they are present.
-            # In the future, we'd consider making this configurable at the agent or run level.
-            # This accounts for cases like anthropic returns that might contain a text response
-            # and a tool call response, where the text response just indicates the tool call will happen.
-            if tool_calls:
-                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
-            elif texts:
-                return await self._handle_text_response(ctx, texts, handle_span)
-            else:
-                raise exceptions.UnexpectedModelBehavior('Received empty model response')
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')
 
-    async def _handle_tool_calls_response(
+            self._events_iterator = _run_stream()
+
+        async for event in self._events_iterator:
+            yield event
+
+    async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
-        handle_span: logfire_api.LogfireSpan,
-    ):
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         result_schema = ctx.deps.result_schema
 
         # first look for the result tool call
-        final_result: MarkFinalResult[NodeRunEndT] | None = None
+        final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
323
421
  ctx.state.increment_retries(ctx.deps.max_result_retries)
324
422
  parts.append(e.tool_retry)
325
423
  else:
326
- final_result = MarkFinalResult(result_data, call.tool_name)
424
+ final_result = result.FinalResult(result_data, call.tool_name)
327
425
 
328
426
  # Then build the other request parts based on end strategy
329
- tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
427
+ tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
428
+ async for event in process_function_tools(
429
+ tool_calls, final_result and final_result.tool_name, ctx, tool_responses
430
+ ):
431
+ yield event
330
432
 
331
433
  if final_result:
332
- handle_span.set_attribute('result', final_result.data)
333
- handle_span.message = 'handle model response -> final result'
334
- return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
434
+ self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
335
435
  else:
336
436
  if tool_responses:
337
- handle_span.set_attribute('tool_responses', tool_responses)
338
- tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
339
- handle_span.message = f'handle model response -> {tool_responses_str}'
340
437
  parts.extend(tool_responses)
341
- return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
438
+ self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
439
+
440
+ def _handle_final_result(
441
+ self,
442
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
443
+ final_result: result.FinalResult[NodeRunEndT],
444
+ tool_responses: list[_messages.ModelRequestPart],
445
+ ) -> End[result.FinalResult[NodeRunEndT]]:
446
+ run_span = ctx.deps.run_span
447
+ usage = ctx.state.usage
448
+ messages = ctx.state.message_history
449
+
450
+ # For backwards compatibility, append a new ModelRequest using the tool returns and retries
451
+ if tool_responses:
452
+ messages.append(_messages.ModelRequest(parts=tool_responses))
453
+
454
+ run_span.set_attribute('usage', usage)
455
+ run_span.set_attribute('all_messages', messages)
456
+
457
+ # End the run with self.data
458
+ return End(final_result)
342
459
 
343
460
  async def _handle_text_response(
344
461
  self,
345
462
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
346
463
  texts: list[str],
347
- handle_span: logfire_api.LogfireSpan,
348
- ):
464
+ ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
349
465
  result_schema = ctx.deps.result_schema
350
466
 
351
467
  text = '\n\n'.join(texts)
352
- if _allow_text_result(result_schema):
468
+ if allow_text_result(result_schema):
353
469
  result_data_input = cast(NodeRunEndT, text)
354
470
  try:
355
471
  result_data = await _validate_result(result_data_input, ctx, None)
@@ -357,9 +473,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
-                handle_span.set_attribute('result', result_data)
-                handle_span.message = 'handle model response -> final result'
-                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
+                # The following cast is safe because we know `str` is an allowed result type
+                return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
             return ModelRequestNode[DepsT, NodeRunEndT](
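HandleResponseNode above caches its event stream in `_events_iterator` so that `run()` and `stream()` can both drain it without re-processing the response. A minimal sketch of that lazily-created, run-once iterator pattern (illustrative names, not pydantic-ai API):

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field


@dataclass
class Events:
    items: list[str]
    _iterator: AsyncIterator[str] | None = field(default=None, repr=False)

    async def stream(self) -> AsyncIterator[str]:
        if self._iterator is None:
            # Create the underlying iterator only once; later calls resume/drain it.
            async def _iter() -> AsyncIterator[str]:
                for item in self.items:
                    yield item

            self._iterator = _iter()
        async for item in self._iterator:
            yield item


async def main() -> None:
    events = Events(['tool call', 'tool result'])
    async for item in events.stream():
        print(item)  # consumes both events
    async for item in events.stream():
        print('left over:', item)  # nothing: the shared iterator is already exhausted


asyncio.run(main())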
@@ -373,166 +488,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
             )
 
 
-@dataclasses.dataclass
-class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Make a request to the model using the last message in state.message_history (or a specified request)."""
-
-    request: _messages.ModelRequest
-    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
-        field(default=None, repr=False)
-    )
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:  # noqa UP007
-        if self._result is not None:
-            return self._result
-
-        async with self.run_to_result(ctx) as final_node:
-            return final_node
-
-    @asynccontextmanager
-    async def run_to_result(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
-        result_schema = ctx.deps.result_schema
-
-        ctx.state.message_history.append(self.request)
-
-        # Check usage
-        if ctx.deps.usage_limits:
-            ctx.deps.usage_limits.check_before_request(ctx.state.usage)
-
-        # Increment run_step
-        ctx.state.run_step += 1
-
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            model_request_parameters = await _prepare_request_parameters(ctx)
-
-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                ctx.state.usage.requests += 1
-                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
-                # We want to end the "model request" span here, but we can't exit the context manager
-                # in the traditional way
-                model_req_span.__exit__(None, None, None)
-
-                with _logfire.span('handle model response') as handle_span:
-                    received_text = False
-
-                    async for maybe_part_event in streamed_response:
-                        if isinstance(maybe_part_event, _messages.PartStartEvent):
-                            new_part = maybe_part_event.part
-                            if isinstance(new_part, _messages.TextPart):
-                                received_text = True
-                                if _allow_text_result(result_schema):
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            elif isinstance(new_part, _messages.ToolCallPart):
-                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                                    call, _ = match
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(
-                                        streamed_response, call.tool_name, ctx
-                                    )
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            else:
-                                assert_never(new_part)
-
-                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-                    parts: list[_messages.ModelRequestPart] = []
-                    model_response = streamed_response.get()
-                    if not model_response.parts:
-                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
-                    ctx.state.message_history.append(model_response)
-
-                    run_context = _build_run_context(ctx)
-                    for p in model_response.parts:
-                        if isinstance(p, _messages.ToolCallPart):
-                            if tool := ctx.deps.function_tools.get(p.tool_name):
-                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
-                            else:
-                                parts.append(_unknown_tool(p.tool_name, ctx))
-
-                    if received_text and not tasks and not parts:
-                        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
-                        ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
-                            _messages.ModelRequest(
-                                parts=[
-                                    _messages.RetryPromptPart(
-                                        content='Plain text responses are not permitted, please call one of the functions instead.',
-                                    )
-                                ]
-                            )
-                        )
-                        yield self._result
-                        return
-
-                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                        parts.extend(task_results)
-
-                    next_request = _messages.ModelRequest(parts=parts)
-                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                        try:
-                            ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        except:
-                            # TODO: This is janky, so I think we should probably change it, but how?
-                            ctx.state.message_history.append(next_request)
-                            raise
-
-                    handle_span.set_attribute('tool_responses', parts)
-                    tool_responses_str = ' '.join(r.part_kind for r in parts)
-                    handle_span.message = f'handle model response -> {tool_responses_str}'
-                    # the model_response should have been fully streamed by now, we can add its usage
-                    streamed_response_usage = streamed_response.usage()
-                    run_context.usage.incr(streamed_response_usage)
-                    ctx.deps.usage_limits.check_tokens(run_context.usage)
-                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
-                    yield self._result
-                    return
-
-
-@dataclasses.dataclass
-class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
-    """Produce the final result of the run."""
-
-    data: MarkFinalResult[NodeRunEndT]
-    """The final result data."""
-    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> End[MarkFinalResult[NodeRunEndT]]:
-        run_span = ctx.deps.run_span
-        usage = ctx.state.usage
-        messages = ctx.state.message_history
-
-        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
-        if self.extra_parts:
-            messages.append(_messages.ModelRequest(parts=self.extra_parts))
-
-        # TODO: Set this attribute somewhere
-        # handle_span = self.handle_model_response_span
-        # handle_span.set_attribute('final_data', self.data)
-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute('all_messages', messages)
-
-        # End the run with self.data
-        return End(self.data)
-
-
-def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+    """Build a `RunContext` object from the current agent graph run context."""
     return RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
@@ -543,76 +500,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps
     )
 
 
-def _build_streamed_run_result(
-    result_stream: models.StreamedResponse,
-    result_tool_name: str | None,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
-    new_message_index = ctx.deps.new_message_index
-    result_schema = ctx.deps.result_schema
-    run_span = ctx.deps.run_span
-    usage_limits = ctx.deps.usage_limits
-    messages = ctx.state.message_history
-    run_context = _build_run_context(ctx)
-
-    async def on_complete():
-        """Called when the stream has completed.
-
-        The model response will have been added to messages by now
-        by `StreamedRunResult._marked_completed`.
-        """
-        last_message = messages[-1]
-        assert isinstance(last_message, _messages.ModelResponse)
-        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
-        parts = await _process_function_tools(
-            tool_calls,
-            result_tool_name,
-            ctx,
-        )
-        # TODO: Should we do something here related to the retry count?
-        # Maybe we should move the incrementing of the retry count to where we actually make a request?
-        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
-        if parts:
-            messages.append(_messages.ModelRequest(parts))
-        run_span.set_attribute('all_messages', messages)
-
-    return result.StreamedRunResult[DepsT, NodeRunEndT](
-        messages,
-        new_message_index,
-        usage_limits,
-        result_stream,
-        result_schema,
-        run_context,
-        ctx.deps.result_validators,
-        result_tool_name,
-        on_complete,
-    )
-
-
-async def _process_function_tools(
+async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> list[_messages.ModelRequestPart]:
-    """Process function (non-result) tool calls in parallel.
+    output_parts: list[_messages.ModelRequestPart],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    """Process function (i.e., non-result) tool calls in parallel.
 
     Also add stub return parts for any other tools that need it.
-    """
-    parts: list[_messages.ModelRequestPart] = []
-    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []
 
+    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    """
     stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
     result_schema = ctx.deps.result_schema
 
     # we rely on the fact that if we found a result, it's the first result tool in the last
     found_used_result_tool = False
-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)
 
+    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if call.tool_name == result_tool_name and not found_used_result_tool:
             found_used_result_tool = True
-            parts.append(
+            output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
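As the new docstring notes, an async generator cannot `return` a value, so `process_function_tools` reports its tool-return parts through the `output_parts` argument while yielding events. A tiny standalone illustration of that output-argument convention (names are illustrative, not pydantic-ai API):

import asyncio
from collections.abc import AsyncIterator


async def squares(numbers: list[int], output: list[int]) -> AsyncIterator[str]:
    for n in numbers:
        yield f'processing {n}'  # event for the caller
        output.append(n * n)     # the "result" accumulates in the output argument


async def main() -> None:
    collected: list[int] = []
    async for event in squares([1, 2, 3], collected):
        print(event)
    print(collected)  # [1, 4, 9]


asyncio.run(main())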
@@ -621,7 +533,7 @@ async def _process_function_tools(
             )
         elif tool := ctx.deps.function_tools.get(call.tool_name):
             if stub_function_tools:
-                parts.append(
+                output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
                         content='Tool not executed - a final result was already processed.',
@@ -629,33 +541,47 @@ async def _process_function_tools(
                     )
                 )
             else:
-                tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
-                parts.append(
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Result tool not used - a final result was already processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Result tool not used - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
                 )
+                output_parts.append(part)
             else:
-                parts.append(_unknown_tool(call.tool_name, ctx))
+                output_parts.append(_unknown_tool(call.tool_name, ctx))
+
+    if not calls_to_run:
+        return
 
     # Run all tool tasks in parallel
-    if tasks:
-        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-            task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
-            for result in task_results:
-                if isinstance(result, _messages.ToolReturnPart):
-                    parts.append(result)
-                elif isinstance(result, _messages.RetryPromptPart):
-                    parts.append(result)
+    results_by_index: dict[int, _messages.ModelRequestPart] = {}
+    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+        # TODO: Should we wrap each individual tool call in a dedicated span?
+        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                result = task.result()
+                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+                    results_by_index[index] = result
                 else:
                     assert_never(result)
-    return parts
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(results_by_index):
+        output_parts.append(results_by_index[k])
 
 
 def _unknown_tool(
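The rewritten parallel section above yields a result event as soon as each task finishes, but appends the parts in the original call order at the end. A self-contained sketch of that scheduling pattern (illustrative names, not pydantic-ai API):

import asyncio


async def delayed(delay: float, value: str) -> str:
    await asyncio.sleep(delay)
    return value


async def run_in_call_order(coros) -> list[str]:
    tasks = [asyncio.create_task(c) for c in coros]
    results_by_index: dict[int, str] = {}
    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            index = tasks.index(task)
            results_by_index[index] = task.result()
            print(f'finished task {index}')  # "event" emitted in completion order
    # Collect output in submission order, regardless of completion order.
    return [results_by_index[i] for i in sorted(results_by_index)]


async def main() -> None:
    ordered = await run_in_call_order([delayed(0.2, 'slow'), delayed(0.0, 'fast')])
    print(ordered)  # ['slow', 'fast'] even though 'fast' completed first


asyncio.run(main())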
@@ -681,12 +607,13 @@ async def _validate_result(
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
     for validator in ctx.deps.result_validators:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data
 
 
-def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+    """Check if the result schema allows text results."""
     return result_schema is None or result_schema.allow_text_result
 
 
@@ -740,35 +667,18 @@ def get_captured_run_messages() -> _RunMessages:
 
 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
-    # We'll define the known node classes:
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
+    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         HandleResponseNode[DepsT],
-        FinalResultNode[DepsT, ResultT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=MarkFinalResult[result_type],
+        run_end_type=result.FinalResult[result_type],
         auto_instrument=False,
     )
     return graph
-
-
-def build_agent_stream_graph(
-    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
-    nodes = [
-        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-    ]
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
-        nodes=nodes,
-        name=name or 'Agent',
-        state_type=GraphAgentState,
-        run_end_type=result.StreamedRunResult[DepsT, result_type],
-    )
-    return graph