pydantic-ai-slim 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic.

@@ -32,6 +32,16 @@ from .tools import (
     ToolDefinition,
 )

+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 # while waiting for https://github.com/pydantic/logfire/issues/745
@@ -56,21 +66,6 @@ DepsT = TypeVar('DepsT')
 ResultT = TypeVar('ResultT')


-@dataclasses.dataclass
-class MarkFinalResult(Generic[ResultDataT]):
-    """Marker class to indicate that the result is the final result.
-
-    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
-    It also avoids problems in the case where the result type is itself `None`, but is set.
-    """
-
-    data: ResultDataT
-    """The final result data."""
-    tool_name: str | None
-    """Name of the final result tool, None if the result is a string."""
-
-
 @dataclasses.dataclass
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
@@ -94,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     new_message_index: int

     model: models.Model
@@ -113,17 +108,22 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):


 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-    user_prompt: str
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+    user_prompt: str | Sequence[_messages.UserContent]

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> _messages.ModelRequest:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
         ctx.state.message_history = history
         run_context.messages = history
@@ -135,7 +135,10 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return next_message

     async def _prepare_messages(
-        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
+        self,
+        user_prompt: str | Sequence[_messages.UserContent],
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
             ctx_messages = get_captured_run_messages()
@@ -188,29 +191,13 @@ class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
         return messages


-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
-@dataclasses.dataclass
-class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> StreamModelRequestNode[DepsT, NodeRunEndT]:
-        return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     function_tool_defs: list[ToolDefinition] = []

-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)

     async def add_tool(tool: Tool[DepsT]) -> None:
         ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
@@ -222,20 +209,81 @@ async def _prepare_request_parameters(
     result_schema = ctx.deps.result_schema
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_result=_allow_text_result(result_schema),
+        allow_text_result=allow_text_result(result_schema),
         result_tools=result_schema.tool_defs() if result_schema is not None else [],
     )


 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
+class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
     """Make a request to the model using the last message in state.message_history."""

     request: _messages.ModelRequest

+    _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+    _did_stream: bool = field(default=False, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        if self._did_stream:
+            # `self._result` gets set when exiting the `stream` contextmanager, so hitting this
+            # means that the stream was started but not finished before `run()` was called
+            raise exceptions.AgentRunError('You must finish streaming before calling run()')
+
+        return await self._make_request(ctx)
+
+    @asynccontextmanager
+    async def _stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[models.StreamedResponse]:
+        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            async with ctx.deps.model.request_stream(
+                ctx.state.message_history, model_settings, model_request_parameters
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.incr(_usage.Usage(), requests=1)
+                yield streamed_response
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in streamed_response:
+                    pass
+                model_response = streamed_response.get()
+                request_usage = streamed_response.usage()
+                span.set_attribute('response', model_response)
+                span.set_attribute('usage', request_usage)
+
+        self._finish_handling(ctx, model_response, request_usage)
+        assert self._result is not None  # this should be set by the previous line
+
+    async def _make_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        model_settings, model_request_parameters = await self._prepare_request(ctx)
+        with _logfire.span('model request', run_step=ctx.state.run_step) as span:
+            model_response, request_usage = await ctx.deps.model.request(
+                ctx.state.message_history, model_settings, model_request_parameters
+            )
+            ctx.state.usage.incr(_usage.Usage(), requests=1)
+            span.set_attribute('response', model_response)
+            span.set_attribute('usage', request_usage)
+
+        return self._finish_handling(ctx, model_response, request_usage)
+
+    async def _prepare_request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
         ctx.state.message_history.append(self.request)

         # Check usage
@@ -245,71 +293,124 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod
         # Increment run_step
         ctx.state.run_step += 1

+        model_settings = merge_model_settings(ctx.deps.model_settings, None)
         with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step):
             model_request_parameters = await _prepare_request_parameters(ctx)
+        return model_settings, model_request_parameters

-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request') as span:
-            model_response, request_usage = await ctx.deps.model.request(
-                ctx.state.message_history, model_settings, model_request_parameters
-            )
-            span.set_attribute('response', model_response)
-            span.set_attribute('usage', request_usage)
-
+    def _finish_handling(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        response: _messages.ModelResponse,
+        usage: _usage.Usage,
+    ) -> HandleResponseNode[DepsT, NodeRunEndT]:
         # Update usage
-        ctx.state.usage.incr(request_usage, requests=1)
+        ctx.state.usage.incr(usage, requests=0)
         if ctx.deps.usage_limits:
             ctx.deps.usage_limits.check_tokens(ctx.state.usage)

         # Append the model response to state.message_history
-        ctx.state.message_history.append(model_response)
-        return HandleResponseNode(model_response)
+        ctx.state.message_history.append(response)
+
+        # Set the `_result` attribute since we can't use `return` in an async iterator
+        self._result = HandleResponseNode(response)
+
+        return self._result


 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Process e response from a model, decide whether to end the run or make a new request."""
+class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """Process a model response, and decide whether to end the run or make a new request."""

     model_response: _messages.ModelResponse

+    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
+        default=None, repr=False
+    )
+    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]:  # noqa UP007
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]:  # noqa UP007
+        async with self.stream(ctx):
+            pass
+
+        assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends'
+        return next_node
+
+    @asynccontextmanager
+    async def stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
+        """Process the model response and yield events for the start and end of each function tool call."""
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-            texts: list[str] = []
-            tool_calls: list[_messages.ToolCallPart] = []
-            for part in self.model_response.parts:
-                if isinstance(part, _messages.TextPart):
-                    # ignore empty content for text parts, see #437
-                    if part.content:
-                        texts.append(part.content)
-                elif isinstance(part, _messages.ToolCallPart):
-                    tool_calls.append(part)
+            stream = self._run_stream(ctx)
+            yield stream
+
+            # Run the stream to completion if it was not finished:
+            async for _event in stream:
+                pass
+
+            # Set the next node based on the final state of the stream
+            next_node = self._next_node
+            if isinstance(next_node, End):
+                handle_span.set_attribute('result', next_node.data)
+                handle_span.message = 'handle model response -> final result'
+            elif tool_responses := self._tool_responses:
+                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
+                # I'm thinking it might be better to just create a span for the handling of each tool
+                # than to set an attribute here.
+                handle_span.set_attribute('tool_responses', tool_responses)
+                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                handle_span.message = f'handle model response -> {tool_responses_str}'
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        if self._events_iterator is None:
+            # Ensure that the stream is only run once
+
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
                 else:
-                    assert_never(part)
-
-            # At the moment, we prioritize at least executing tool calls if they are present.
-            # In the future, we'd consider making this configurable at the agent or run level.
-            # This accounts for cases like anthropic returns that might contain a text response
-            # and a tool call response, where the text response just indicates the tool call will happen.
-            if tool_calls:
-                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
-            elif texts:
-                return await self._handle_text_response(ctx, texts, handle_span)
-            else:
-                raise exceptions.UnexpectedModelBehavior('Received empty model response')
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')

-    async def _handle_tool_calls_response(
+            self._events_iterator = _run_stream()
+
+        async for event in self._events_iterator:
+            yield event
+
+    async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
-        handle_span: logfire_api.LogfireSpan,
-    ):
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         result_schema = ctx.deps.result_schema

         # first look for the result tool call
-        final_result: MarkFinalResult[NodeRunEndT] | None = None
+        final_result: result.FinalResult[NodeRunEndT] | None = None
         parts: list[_messages.ModelRequestPart] = []
         if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
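
The hunks above replace the old single-shot `run()` body of `HandleResponseNode` with a `stream()` async context manager that yields events as tool calls start and finish. Below is a minimal sketch of how calling code might consume those events; `observe_tool_events`, `node`, and `ctx` are hypothetical names, and the `pydantic_ai.messages` import path is assumed from the `_messages` alias used in the diff.

```python
from pydantic_ai import messages as _messages


async def observe_tool_events(node, ctx) -> None:
    # `stream()` yields an async iterator of HandleResponseEvent; iterating it drives
    # tool execution, and exiting the context manager finishes any remaining events.
    async with node.stream(ctx) as event_stream:
        async for event in event_stream:
            if isinstance(event, _messages.FunctionToolCallEvent):
                print('tool call started:', event.call_id)
            elif isinstance(event, _messages.FunctionToolResultEvent):
                print('tool call finished:', event.call_id)
```

Exiting the `stream()` context manager also sets `_next_node`, which is how the non-streaming `run()` path above is implemented on top of it.
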
@@ -323,33 +424,51 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                     ctx.state.increment_retries(ctx.deps.max_result_retries)
                     parts.append(e.tool_retry)
                 else:
-                    final_result = MarkFinalResult(result_data, call.tool_name)
+                    final_result = result.FinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
+        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
+        async for event in process_function_tools(
+            tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+        ):
+            yield event

         if final_result:
-            handle_span.set_attribute('result', final_result.data)
-            handle_span.message = 'handle model response -> final result'
-            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
+            self._next_node = self._handle_final_result(ctx, final_result, tool_responses)
         else:
             if tool_responses:
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
                 parts.extend(tool_responses)
-            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+
+    def _handle_final_result(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        final_result: result.FinalResult[NodeRunEndT],
+        tool_responses: list[_messages.ModelRequestPart],
+    ) -> End[result.FinalResult[NodeRunEndT]]:
+        run_span = ctx.deps.run_span
+        usage = ctx.state.usage
+        messages = ctx.state.message_history
+
+        # For backwards compatibility, append a new ModelRequest using the tool returns and retries
+        if tool_responses:
+            messages.append(_messages.ModelRequest(parts=tool_responses))
+
+        run_span.set_attribute('usage', usage)
+        run_span.set_attribute('all_messages', messages)
+
+        # End the run with self.data
+        return End(final_result)

     async def _handle_text_response(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
-        handle_span: logfire_api.LogfireSpan,
-    ):
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         result_schema = ctx.deps.result_schema

         text = '\n\n'.join(texts)
-        if _allow_text_result(result_schema):
+        if allow_text_result(result_schema):
             result_data_input = cast(NodeRunEndT, text)
             try:
                 result_data = await _validate_result(result_data_input, ctx, None)
@@ -357,9 +476,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
-                handle_span.set_attribute('result', result_data)
-                handle_span.message = 'handle model response -> final result'
-                return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
+                # The following cast is safe because we know `str` is an allowed result type
+                return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
             return ModelRequestNode[DepsT, NodeRunEndT](
@@ -373,166 +491,8 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], N
             )


-@dataclasses.dataclass
-class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Make a request to the model using the last message in state.message_history (or a specified request)."""
-
-    request: _messages.ModelRequest
-    _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = (
-        field(default=None, repr=False)
-    )
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:  # noqa UP007
-        if self._result is not None:
-            return self._result
-
-        async with self.run_to_result(ctx) as final_node:
-            return final_node
-
-    @asynccontextmanager
-    async def run_to_result(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
-        result_schema = ctx.deps.result_schema
-
-        ctx.state.message_history.append(self.request)
-
-        # Check usage
-        if ctx.deps.usage_limits:
-            ctx.deps.usage_limits.check_before_request(ctx.state.usage)
-
-        # Increment run_step
-        ctx.state.run_step += 1
-
-        with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step):
-            model_request_parameters = await _prepare_request_parameters(ctx)
-
-        # Actually make the model request
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
-        with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span:
-            async with ctx.deps.model.request_stream(
-                ctx.state.message_history, model_settings, model_request_parameters
-            ) as streamed_response:
-                ctx.state.usage.requests += 1
-                model_req_span.set_attribute('response_type', streamed_response.__class__.__name__)
-                # We want to end the "model request" span here, but we can't exit the context manager
-                # in the traditional way
-                model_req_span.__exit__(None, None, None)
-
-                with _logfire.span('handle model response') as handle_span:
-                    received_text = False
-
-                    async for maybe_part_event in streamed_response:
-                        if isinstance(maybe_part_event, _messages.PartStartEvent):
-                            new_part = maybe_part_event.part
-                            if isinstance(new_part, _messages.TextPart):
-                                received_text = True
-                                if _allow_text_result(result_schema):
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            elif isinstance(new_part, _messages.ToolCallPart):
-                                if result_schema is not None and (match := result_schema.find_tool([new_part])):
-                                    call, _ = match
-                                    handle_span.message = 'handle model response -> final result'
-                                    streamed_run_result = _build_streamed_run_result(
-                                        streamed_response, call.tool_name, ctx
-                                    )
-                                    self._result = End(streamed_run_result)
-                                    yield self._result
-                                    return
-                            else:
-                                assert_never(new_part)
-
-                    tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-                    parts: list[_messages.ModelRequestPart] = []
-                    model_response = streamed_response.get()
-                    if not model_response.parts:
-                        raise exceptions.UnexpectedModelBehavior('Received empty model response')
-                    ctx.state.message_history.append(model_response)
-
-                    run_context = _build_run_context(ctx)
-                    for p in model_response.parts:
-                        if isinstance(p, _messages.ToolCallPart):
-                            if tool := ctx.deps.function_tools.get(p.tool_name):
-                                tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
-                            else:
-                                parts.append(_unknown_tool(p.tool_name, ctx))
-
-                    if received_text and not tasks and not parts:
-                        # Can only get here if self._allow_text_result returns `False` for the provided result_schema
-                        ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        self._result = StreamModelRequestNode[DepsT, NodeRunEndT](
-                            _messages.ModelRequest(
-                                parts=[
-                                    _messages.RetryPromptPart(
-                                        content='Plain text responses are not permitted, please call one of the functions instead.',
-                                    )
-                                ]
-                            )
-                        )
-                        yield self._result
-                        return
-
-                    with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                        task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                        parts.extend(task_results)
-
-                    next_request = _messages.ModelRequest(parts=parts)
-                    if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                        try:
-                            ctx.state.increment_retries(ctx.deps.max_result_retries)
-                        except:
-                            # TODO: This is janky, so I think we should probably change it, but how?
-                            ctx.state.message_history.append(next_request)
-                            raise
-
-                    handle_span.set_attribute('tool_responses', parts)
-                    tool_responses_str = ' '.join(r.part_kind for r in parts)
-                    handle_span.message = f'handle model response -> {tool_responses_str}'
-                    # the model_response should have been fully streamed by now, we can add its usage
-                    streamed_response_usage = streamed_response.usage()
-                    run_context.usage.incr(streamed_response_usage)
-                    ctx.deps.usage_limits.check_tokens(run_context.usage)
-                    self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request)
-                    yield self._result
-                    return
-
-
-@dataclasses.dataclass
-class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]):
-    """Produce the final result of the run."""
-
-    data: MarkFinalResult[NodeRunEndT]
-    """The final result data."""
-    extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list)
-
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> End[MarkFinalResult[NodeRunEndT]]:
-        run_span = ctx.deps.run_span
-        usage = ctx.state.usage
-        messages = ctx.state.message_history
-
-        # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries
-        if self.extra_parts:
-            messages.append(_messages.ModelRequest(parts=self.extra_parts))
-
-        # TODO: Set this attribute somewhere
-        # handle_span = self.handle_model_response_span
-        # handle_span.set_attribute('final_data', self.data)
-        run_span.set_attribute('usage', usage)
-        run_span.set_attribute('all_messages', messages)
-
-        # End the run with self.data
-        return End(self.data)
-
-
-def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+    """Build a `RunContext` object from the current agent graph run context."""
     return RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
@@ -543,76 +503,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps
     )


-def _build_streamed_run_result(
-    result_stream: models.StreamedResponse,
-    result_tool_name: str | None,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
-    new_message_index = ctx.deps.new_message_index
-    result_schema = ctx.deps.result_schema
-    run_span = ctx.deps.run_span
-    usage_limits = ctx.deps.usage_limits
-    messages = ctx.state.message_history
-    run_context = _build_run_context(ctx)
-
-    async def on_complete():
-        """Called when the stream has completed.
-
-        The model response will have been added to messages by now
-        by `StreamedRunResult._marked_completed`.
-        """
-        last_message = messages[-1]
-        assert isinstance(last_message, _messages.ModelResponse)
-        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
-        parts = await _process_function_tools(
-            tool_calls,
-            result_tool_name,
-            ctx,
-        )
-        # TODO: Should we do something here related to the retry count?
-        # Maybe we should move the incrementing of the retry count to where we actually make a request?
-        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-        #     ctx.state.increment_retries(ctx.deps.max_result_retries)
-        if parts:
-            messages.append(_messages.ModelRequest(parts))
-        run_span.set_attribute('all_messages', messages)
-
-    return result.StreamedRunResult[DepsT, NodeRunEndT](
-        messages,
-        new_message_index,
-        usage_limits,
-        result_stream,
-        result_schema,
-        run_context,
-        ctx.deps.result_validators,
-        result_tool_name,
-        on_complete,
-    )
-
-
-async def _process_function_tools(
+async def process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> list[_messages.ModelRequestPart]:
-    """Process function (non-result) tool calls in parallel.
+    output_parts: list[_messages.ModelRequestPart],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
+    """Process function (i.e., non-result) tool calls in parallel.

     Also add stub return parts for any other tools that need it.
-    """
-    parts: list[_messages.ModelRequestPart] = []
-    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []

+    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    """
     stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
     result_schema = ctx.deps.result_schema

     # we rely on the fact that if we found a result, it's the first result tool in the last
     found_used_result_tool = False
-    run_context = _build_run_context(ctx)
+    run_context = build_run_context(ctx)

+    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    call_index_to_event_id: dict[int, str] = {}
     for call in tool_calls:
         if call.tool_name == result_tool_name and not found_used_result_tool:
             found_used_result_tool = True
-            parts.append(
+            output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
@@ -621,7 +536,7 @@ async def _process_function_tools(
             )
         elif tool := ctx.deps.function_tools.get(call.tool_name):
             if stub_function_tools:
-                parts.append(
+                output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
                         content='Tool not executed - a final result was already processed.',
@@ -629,33 +544,47 @@ async def _process_function_tools(
                     )
                 )
             else:
-                tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
-                parts.append(
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Result tool not used - a final result was already processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Result tool not used - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
                 )
+                output_parts.append(part)
         else:
-            parts.append(_unknown_tool(call.tool_name, ctx))
+            output_parts.append(_unknown_tool(call.tool_name, ctx))
+
+    if not calls_to_run:
+        return

     # Run all tool tasks in parallel
-    if tasks:
-        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-            task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
-            for result in task_results:
-                if isinstance(result, _messages.ToolReturnPart):
-                    parts.append(result)
-                elif isinstance(result, _messages.RetryPromptPart):
-                    parts.append(result)
+    results_by_index: dict[int, _messages.ModelRequestPart] = {}
+    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+        # TODO: Should we wrap each individual tool call in a dedicated span?
+        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                result = task.result()
+                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+                    results_by_index[index] = result
                 else:
                     assert_never(result)
-    return parts
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(results_by_index):
+        output_parts.append(results_by_index[k])


 def _unknown_tool(
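
The rewritten `process_function_tools` above swaps a single `asyncio.gather` for `asyncio.wait(..., return_when=FIRST_COMPLETED)`, so it can yield a `FunctionToolResultEvent` as each tool finishes while still appending results in the original call order. A self-contained sketch of that pattern, with a hypothetical `fake_tool` coroutine standing in for real tool execution:

```python
import asyncio


async def fake_tool(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return f'{name} done'


async def run_tools_keep_order() -> list[str]:
    calls = [('a', 0.3), ('b', 0.1), ('c', 0.2)]
    tasks = [asyncio.create_task(fake_tool(name, delay), name=name) for name, delay in calls]
    results_by_index: dict[int, str] = {}

    pending = set(tasks)
    while pending:
        # React to each completion as it happens (this is where an event would be emitted)...
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            index = tasks.index(task)
            print('finished:', task.get_name())
            results_by_index[index] = task.result()

    # ...but collect the outputs in the original call order for a deterministic result.
    return [results_by_index[i] for i in sorted(results_by_index)]


if __name__ == '__main__':
    print(asyncio.run(run_tools_keep_order()))  # ['a done', 'b done', 'c done']
```

Collecting into `results_by_index` and appending in sorted order is what keeps the output deterministic even though completion order varies.
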
@@ -681,12 +610,13 @@ async def _validate_result(
     tool_call: _messages.ToolCallPart | None,
 ) -> T:
     for validator in ctx.deps.result_validators:
-        run_context = _build_run_context(ctx)
+        run_context = build_run_context(ctx)
         result_data = await validator.validate(result_data, tool_call, run_context)
     return result_data


-def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool:
+    """Check if the result schema allows text results."""
     return result_schema is None or result_schema.allow_text_result


@@ -740,35 +670,18 @@ def get_captured_run_messages() -> _RunMessages:

 def build_agent_graph(
     name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]:
-    # We'll define the known node classes:
+) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
+    """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         HandleResponseNode[DepsT],
-        FinalResultNode[DepsT, ResultT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]](
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
-        run_end_type=MarkFinalResult[result_type],
+        run_end_type=result.FinalResult[result_type],
         auto_instrument=False,
     )
     return graph
-
-
-def build_agent_stream_graph(
-    name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]:
-    nodes = [
-        StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-        StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]],
-    ]
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]](
-        nodes=nodes,
-        name=name or 'Agent',
-        state_type=GraphAgentState,
-        run_end_type=result.StreamedRunResult[DepsT, result_type],
-    )
-    return graph
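
Taken together, these hunks remove `MarkFinalResult`, `FinalResultNode`, `StreamUserPromptNode`, `StreamModelRequestNode`, and `build_agent_stream_graph`; a single graph now ends with `result.FinalResult[ResultT]`. A rough sketch of what downstream code inspecting the graph's end value might look like after the rename (the `handle_run_end` helper is hypothetical, and the `data`/`tool_name` fields are assumed to carry over unchanged from `MarkFinalResult`):

```python
from __future__ import annotations

from pydantic_ai import result


def handle_run_end(end_value: result.FinalResult[str]) -> str:
    # `FinalResult` replaces the removed `MarkFinalResult`; per the diff it is constructed
    # as FinalResult(result_data, tool_name), with tool_name=None for plain-text results.
    if end_value.tool_name is None:
        return f'plain-text result: {end_value.data}'
    return f'result from tool {end_value.tool_name!r}: {end_value.data}'
```
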