pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/_pydantic.py CHANGED
@@ -17,10 +17,10 @@ from pydantic.plugin._schema_validator import create_schema_validator
17
17
  from pydantic_core import SchemaValidator, core_schema
18
18
 
19
19
  from ._griffe import doc_descriptions
20
- from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
20
+ from ._utils import check_object_json_schema, is_model_like
21
21
 
22
22
  if TYPE_CHECKING:
23
- pass
23
+ from .tools import ObjectJsonSchema
24
24
 
25
25
 
26
26
  __all__ = 'function_schema', 'LazyTypeAdapter'
@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
168
168
  """
169
169
  sig = signature(function)
170
170
  try:
171
- _, first_param = next(iter(sig.parameters.items()))
171
+ first_param_name = next(iter(sig.parameters.keys()))
172
172
  except StopIteration:
173
173
  return False
174
174
  else:
175
- return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
175
+ type_hints = _typing_extra.get_function_type_hints(function)
176
+ annotation = type_hints[first_param_name]
177
+ return annotation is not sig.empty and _is_call_ctx(annotation)
176
178
 
177
179
 
178
180
  def _build_schema(
pydantic_ai/_result.py CHANGED
@@ -14,7 +14,7 @@ from . import _utils, messages
14
14
  from .exceptions import ModelRetry
15
15
  from .messages import ModelStructuredResponse, ToolCall
16
16
  from .result import ResultData
17
- from .tools import AgentDeps, ResultValidatorFunc, RunContext
17
+ from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
18
18
 
19
19
 
20
20
  @dataclass
@@ -94,10 +94,7 @@ class ResultSchema(Generic[ResultData]):
94
94
  allow_text_result = False
95
95
 
96
96
  def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
97
- return cast(
98
- ResultTool[ResultData],
99
- ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
100
- )
97
+ return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
101
98
 
102
99
  tools: dict[str, ResultTool[ResultData]] = {}
103
100
  if args := get_union_args(response_type):
@@ -121,38 +118,38 @@ class ResultSchema(Generic[ResultData]):
121
118
  """Return the names of the tools."""
122
119
  return list(self.tools.keys())
123
120
 
121
+ def tool_defs(self) -> list[ToolDefinition]:
122
+ """Get tool definitions to register with the model."""
123
+ return [t.tool_def for t in self.tools.values()]
124
+
124
125
 
125
126
  DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
126
127
 
127
128
 
128
- @dataclass
129
+ @dataclass(init=False)
129
130
  class ResultTool(Generic[ResultData]):
130
- name: str
131
- description: str
131
+ tool_def: ToolDefinition
132
132
  type_adapter: TypeAdapter[Any]
133
- json_schema: _utils.ObjectJsonSchema
134
- outer_typed_dict_key: str | None
135
133
 
136
- @classmethod
137
- def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
134
+ def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
138
135
  """Build a ResultTool dataclass from a response type."""
139
136
  assert response_type is not str, 'ResultTool does not support str as a response type'
140
137
 
141
138
  if _utils.is_model_like(response_type):
142
- type_adapter = TypeAdapter(response_type)
139
+ self.type_adapter = TypeAdapter(response_type)
143
140
  outer_typed_dict_key: str | None = None
144
141
  # noinspection PyArgumentList
145
- json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
142
+ parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
146
143
  else:
147
144
  response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
148
- type_adapter = TypeAdapter(response_data_typed_dict)
145
+ self.type_adapter = TypeAdapter(response_data_typed_dict)
149
146
  outer_typed_dict_key = 'response'
150
147
  # noinspection PyArgumentList
151
- json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
148
+ parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
152
149
  # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
153
- json_schema.pop('title')
150
+ parameters_json_schema.pop('title')
154
151
 
155
- if json_schema_description := json_schema.pop('description', None):
152
+ if json_schema_description := parameters_json_schema.pop('description', None):
156
153
  if description is None:
157
154
  tool_description = json_schema_description
158
155
  else:
@@ -162,11 +159,10 @@ class ResultTool(Generic[ResultData]):
162
159
  if multiple:
163
160
  tool_description = f'{union_arg_name(response_type)}: {tool_description}'
164
161
 
165
- return cls(
162
+ self.tool_def = ToolDefinition(
166
163
  name=name,
167
164
  description=tool_description,
168
- type_adapter=type_adapter,
169
- json_schema=json_schema,
165
+ parameters_json_schema=parameters_json_schema,
170
166
  outer_typed_dict_key=outer_typed_dict_key,
171
167
  )
172
168
 
@@ -204,7 +200,7 @@ class ResultTool(Generic[ResultData]):
204
200
  else:
205
201
  raise
206
202
  else:
207
- if k := self.outer_typed_dict_key:
203
+ if k := self.tool_def.outer_typed_dict_key:
208
204
  result = result[k]
209
205
  return result
210
206
 
@@ -21,7 +21,7 @@ class SystemPromptRunner(Generic[AgentDeps]):
21
21
 
22
22
  async def run(self, deps: AgentDeps) -> str:
23
23
  if self._takes_ctx:
24
- args = (RunContext(deps, 0, None),)
24
+ args = (RunContext(deps, 0),)
25
25
  else:
26
26
  args = ()
27
27
 
pydantic_ai/_utils.py CHANGED
@@ -8,12 +8,15 @@ from dataclasses import dataclass, is_dataclass
8
8
  from datetime import datetime, timezone
9
9
  from functools import partial
10
10
  from types import GenericAlias
11
- from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
11
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
12
12
 
13
13
  from pydantic import BaseModel
14
14
  from pydantic.json_schema import JsonSchemaValue
15
15
  from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
16
16
 
17
+ if TYPE_CHECKING:
18
+ from .tools import ObjectJsonSchema
19
+
17
20
  _P = ParamSpec('_P')
18
21
  _R = TypeVar('_R')
19
22
 
@@ -39,10 +42,6 @@ def is_model_like(type_: Any) -> bool:
39
42
  )
40
43
 
41
44
 
42
- # With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
43
- ObjectJsonSchema: TypeAlias = dict[str, Any]
44
-
45
-
46
45
  def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
47
46
  from .exceptions import UserError
48
47
 
@@ -127,6 +126,12 @@ class Either(Generic[Left, Right]):
127
126
  def whichever(self) -> Left | Right:
128
127
  return self._left.value if self._left is not None else self.right
129
128
 
129
+ def __repr__(self):
130
+ if left := self._left:
131
+ return f'Either(left={left.value!r})'
132
+ else:
133
+ return f'Either(right={self.right!r})'
134
+
130
135
 
131
136
  @asynccontextmanager
132
137
  async def group_by_temporal(
@@ -218,7 +223,7 @@ async def group_by_temporal(
218
223
 
219
224
  try:
220
225
  yield async_iter_groups()
221
- finally:
226
+ finally: # pragma: no cover
222
227
  # after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
223
228
  if task:
224
229
  task.cancel('Cancelling due to error in iterator')
pydantic_ai/agent.py CHANGED
@@ -22,7 +22,17 @@ from . import (
22
22
  result,
23
23
  )
24
24
  from .result import ResultData
25
- from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
25
+ from .tools import (
26
+ AgentDeps,
27
+ RunContext,
28
+ Tool,
29
+ ToolDefinition,
30
+ ToolFuncContext,
31
+ ToolFuncEither,
32
+ ToolFuncPlain,
33
+ ToolParams,
34
+ ToolPrepareFunc,
35
+ )
26
36
 
27
37
  __all__ = ('Agent',)
28
38
 
@@ -136,7 +146,10 @@ class Agent(Generic[AgentDeps, ResultData]):
136
146
  self._function_tools = {}
137
147
  self._default_retries = retries
138
148
  for tool in tools:
139
- self._register_tool(Tool.infer(tool))
149
+ if isinstance(tool, Tool):
150
+ self._register_tool(tool)
151
+ else:
152
+ self._register_tool(Tool(tool))
140
153
  self._deps_type = deps_type
141
154
  self._system_prompt_functions = []
142
155
  self._max_result_retries = result_retries if result_retries is not None else retries
@@ -166,7 +179,7 @@ class Agent(Generic[AgentDeps, ResultData]):
166
179
  """
167
180
  if infer_name and self.name is None:
168
181
  self._infer_name(inspect.currentframe())
169
- model_used, custom_model, agent_model = await self._get_agent_model(model)
182
+ model_used, mode_selection = await self._get_model(model)
170
183
 
171
184
  deps = self._get_deps(deps)
172
185
 
@@ -174,7 +187,7 @@ class Agent(Generic[AgentDeps, ResultData]):
174
187
  '{agent_name} run {prompt=}',
175
188
  prompt=user_prompt,
176
189
  agent=self,
177
- custom_model=custom_model,
190
+ mode_selection=mode_selection,
178
191
  model_name=model_used.name(),
179
192
  agent_name=self.name or 'agent',
180
193
  ) as run_span:
@@ -182,14 +195,17 @@ class Agent(Generic[AgentDeps, ResultData]):
182
195
  self.last_run_messages = messages
183
196
 
184
197
  for tool in self._function_tools.values():
185
- tool.reset()
198
+ tool.current_retry = 0
186
199
 
187
200
  cost = result.Cost()
188
201
 
189
202
  run_step = 0
190
203
  while True:
191
204
  run_step += 1
192
- with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
205
+ with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
206
+ agent_model = await self._prepare_model(model_used, deps)
207
+
208
+ with _logfire.span('model request', run_step=run_step) as model_req_span:
193
209
  model_response, request_cost = await agent_model.request(messages)
194
210
  model_req_span.set_attribute('response', model_response)
195
211
  model_req_span.set_attribute('cost', request_cost)
@@ -198,12 +214,15 @@ class Agent(Generic[AgentDeps, ResultData]):
198
214
  messages.append(model_response)
199
215
  cost += request_cost
200
216
 
201
- with _logfire.span('handle model response') as handle_span:
202
- either = await self._handle_model_response(model_response, deps)
217
+ with _logfire.span('handle model response', run_step=run_step) as handle_span:
218
+ final_result, response_messages = await self._handle_model_response(model_response, deps)
203
219
 
204
- if isinstance(either, _MarkFinalResult):
205
- # we have a final result, end the conversation
206
- result_data = either.data
220
+ # Add all messages to the conversation
221
+ messages.extend(response_messages)
222
+
223
+ # Check if we got a final result
224
+ if final_result is not None:
225
+ result_data = final_result.data
207
226
  run_span.set_attribute('all_messages', messages)
208
227
  run_span.set_attribute('cost', cost)
209
228
  handle_span.set_attribute('result', result_data)
@@ -211,11 +230,9 @@ class Agent(Generic[AgentDeps, ResultData]):
211
230
  return result.RunResult(messages, new_message_index, result_data, cost)
212
231
  else:
213
232
  # continue the conversation
214
- tool_responses = either
215
- handle_span.set_attribute('tool_responses', tool_responses)
216
- response_msgs = ' '.join(m.role for m in tool_responses)
233
+ handle_span.set_attribute('tool_responses', response_messages)
234
+ response_msgs = ' '.join(r.role for r in response_messages)
217
235
  handle_span.message = f'handle model response -> {response_msgs}'
218
- messages.extend(tool_responses)
219
236
 
220
237
  def run_sync(
221
238
  self,
@@ -273,7 +290,7 @@ class Agent(Generic[AgentDeps, ResultData]):
273
290
  # f_back because `asynccontextmanager` adds one frame
274
291
  if frame := inspect.currentframe(): # pragma: no branch
275
292
  self._infer_name(frame.f_back)
276
- model_used, custom_model, agent_model = await self._get_agent_model(model)
293
+ model_used, mode_selection = await self._get_model(model)
277
294
 
278
295
  deps = self._get_deps(deps)
279
296
 
@@ -281,7 +298,7 @@ class Agent(Generic[AgentDeps, ResultData]):
281
298
  '{agent_name} run stream {prompt=}',
282
299
  prompt=user_prompt,
283
300
  agent=self,
284
- custom_model=custom_model,
301
+ mode_selection=mode_selection,
285
302
  model_name=model_used.name(),
286
303
  agent_name=self.name or 'agent',
287
304
  ) as run_span:
@@ -289,13 +306,17 @@ class Agent(Generic[AgentDeps, ResultData]):
289
306
  self.last_run_messages = messages
290
307
 
291
308
  for tool in self._function_tools.values():
292
- tool.reset()
309
+ tool.current_retry = 0
293
310
 
294
311
  cost = result.Cost()
295
312
 
296
313
  run_step = 0
297
314
  while True:
298
315
  run_step += 1
316
+
317
+ with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
318
+ agent_model = await self._prepare_model(model_used, deps)
319
+
299
320
  with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
300
321
  async with agent_model.request_stream(messages) as model_response:
301
322
  model_req_span.set_attribute('response_type', model_response.__class__.__name__)
@@ -304,10 +325,16 @@ class Agent(Generic[AgentDeps, ResultData]):
304
325
  model_req_span.__exit__(None, None, None)
305
326
 
306
327
  with _logfire.span('handle model response') as handle_span:
307
- either = await self._handle_streamed_model_response(model_response, deps)
328
+ final_result, response_messages = await self._handle_streamed_model_response(
329
+ model_response, deps
330
+ )
331
+
332
+ # Add all messages to the conversation
333
+ messages.extend(response_messages)
308
334
 
309
- if isinstance(either, _MarkFinalResult):
310
- result_stream = either.data
335
+ # Check if we got a final result
336
+ if final_result is not None:
337
+ result_stream = final_result.data
311
338
  run_span.set_attribute('all_messages', messages)
312
339
  handle_span.set_attribute('result_type', result_stream.__class__.__name__)
313
340
  handle_span.message = 'handle model response -> final result'
@@ -323,11 +350,10 @@ class Agent(Generic[AgentDeps, ResultData]):
323
350
  )
324
351
  return
325
352
  else:
326
- tool_responses = either
327
- handle_span.set_attribute('tool_responses', tool_responses)
328
- response_msgs = ' '.join(m.role for m in tool_responses)
353
+ # continue the conversation
354
+ handle_span.set_attribute('tool_responses', response_messages)
355
+ response_msgs = ' '.join(r.role for r in response_messages)
329
356
  handle_span.message = f'handle model response -> {response_msgs}'
330
- messages.extend(tool_responses)
331
357
  # the model_response should have been fully streamed by now, we can add it's cost
332
358
  cost += model_response.cost()
333
359
 
@@ -477,7 +503,11 @@ class Agent(Generic[AgentDeps, ResultData]):
477
503
 
478
504
  @overload
479
505
  def tool(
480
- self, /, *, retries: int | None = None
506
+ self,
507
+ /,
508
+ *,
509
+ retries: int | None = None,
510
+ prepare: ToolPrepareFunc[AgentDeps] | None = None,
481
511
  ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
482
512
 
483
513
  def tool(
@@ -486,9 +516,9 @@ class Agent(Generic[AgentDeps, ResultData]):
486
516
  /,
487
517
  *,
488
518
  retries: int | None = None,
519
+ prepare: ToolPrepareFunc[AgentDeps] | None = None,
489
520
  ) -> Any:
490
- """Decorator to register a tool function which takes
491
- [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
521
+ """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
492
522
 
493
523
  Can decorate a sync or async functions.
494
524
 
@@ -521,20 +551,23 @@ class Agent(Generic[AgentDeps, ResultData]):
521
551
  func: The tool function to register.
522
552
  retries: The number of retries to allow for this tool, defaults to the agent's default retries,
523
553
  which defaults to 1.
524
- """ # noqa: D205
554
+ prepare: custom method to prepare the tool definition for each step, return `None` to omit this
555
+ tool from a given step. This is useful if you want to customise a tool at call time,
556
+ or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
557
+ """
525
558
  if func is None:
526
559
 
527
560
  def tool_decorator(
528
561
  func_: ToolFuncContext[AgentDeps, ToolParams],
529
562
  ) -> ToolFuncContext[AgentDeps, ToolParams]:
530
563
  # noinspection PyTypeChecker
531
- self._register_function(func_, True, retries)
564
+ self._register_function(func_, True, retries, prepare)
532
565
  return func_
533
566
 
534
567
  return tool_decorator
535
568
  else:
536
569
  # noinspection PyTypeChecker
537
- self._register_function(func, True, retries)
570
+ self._register_function(func, True, retries, prepare)
538
571
  return func
539
572
 
540
573
  @overload
@@ -542,10 +575,21 @@ class Agent(Generic[AgentDeps, ResultData]):
542
575
 
543
576
  @overload
544
577
  def tool_plain(
545
- self, /, *, retries: int | None = None
578
+ self,
579
+ /,
580
+ *,
581
+ retries: int | None = None,
582
+ prepare: ToolPrepareFunc[AgentDeps] | None = None,
546
583
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
547
584
 
548
- def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
585
+ def tool_plain(
586
+ self,
587
+ func: ToolFuncPlain[ToolParams] | None = None,
588
+ /,
589
+ *,
590
+ retries: int | None = None,
591
+ prepare: ToolPrepareFunc[AgentDeps] | None = None,
592
+ ) -> Any:
549
593
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
550
594
 
551
595
  Can decorate a sync or async functions.
@@ -579,30 +623,38 @@ class Agent(Generic[AgentDeps, ResultData]):
579
623
  func: The tool function to register.
580
624
  retries: The number of retries to allow for this tool, defaults to the agent's default retries,
581
625
  which defaults to 1.
626
+ prepare: custom method to prepare the tool definition for each step, return `None` to omit this
627
+ tool from a given step. This is useful if you want to customise a tool at call time,
628
+ or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
582
629
  """
583
630
  if func is None:
584
631
 
585
632
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
586
633
  # noinspection PyTypeChecker
587
- self._register_function(func_, False, retries)
634
+ self._register_function(func_, False, retries, prepare)
588
635
  return func_
589
636
 
590
637
  return tool_decorator
591
638
  else:
592
- self._register_function(func, False, retries)
639
+ self._register_function(func, False, retries, prepare)
593
640
  return func
594
641
 
595
642
  def _register_function(
596
- self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
643
+ self,
644
+ func: ToolFuncEither[AgentDeps, ToolParams],
645
+ takes_ctx: bool,
646
+ retries: int | None,
647
+ prepare: ToolPrepareFunc[AgentDeps] | None,
597
648
  ) -> None:
598
649
  """Private utility to register a function as a tool."""
599
650
  retries_ = retries if retries is not None else self._default_retries
600
- tool = Tool(func, takes_ctx, max_retries=retries_)
651
+ tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
601
652
  self._register_tool(tool)
602
653
 
603
654
  def _register_tool(self, tool: Tool[AgentDeps]) -> None:
604
655
  """Private utility to register a tool instance."""
605
656
  if tool.max_retries is None:
657
+ # noinspection PyTypeChecker
606
658
  tool = dataclasses.replace(tool, max_retries=self._default_retries)
607
659
 
608
660
  if tool.name in self._function_tools:
@@ -613,16 +665,14 @@ class Agent(Generic[AgentDeps, ResultData]):
613
665
 
614
666
  self._function_tools[tool.name] = tool
615
667
 
616
- async def _get_agent_model(
617
- self, model: models.Model | models.KnownModelName | None
618
- ) -> tuple[models.Model, models.Model | None, models.AgentModel]:
668
+ async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
619
669
  """Create a model configured for this agent.
620
670
 
621
671
  Args:
622
672
  model: model to use for this run, required if `model` was not set when creating the agent.
623
673
 
624
674
  Returns:
625
- a tuple of `(model used, custom_model if any, agent_model)`
675
+ a tuple of `(model used, how the model was selected)`
626
676
  """
627
677
  model_: models.Model
628
678
  if some_model := self._override_model:
@@ -633,19 +683,35 @@ class Agent(Generic[AgentDeps, ResultData]):
633
683
  '(Even when `override(model=...)` is customizing the model that will actually be called)'
634
684
  )
635
685
  model_ = some_model.value
636
- custom_model = None
686
+ mode_selection = 'override-model'
637
687
  elif model is not None:
638
- custom_model = model_ = models.infer_model(model)
688
+ model_ = models.infer_model(model)
689
+ mode_selection = 'custom'
639
690
  elif self.model is not None:
640
691
  # noinspection PyTypeChecker
641
692
  model_ = self.model = models.infer_model(self.model)
642
- custom_model = None
693
+ mode_selection = 'from-agent'
643
694
  else:
644
695
  raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
645
696
 
646
- result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
647
- agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
648
- return model_, custom_model, agent_model
697
+ return model_, mode_selection
698
+
699
+ async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
700
+ """Create building tools and create an agent model."""
701
+ function_tools: list[ToolDefinition] = []
702
+
703
+ async def add_tool(tool: Tool[AgentDeps]) -> None:
704
+ ctx = RunContext(deps, tool.current_retry, tool.name)
705
+ if tool_def := await tool.prepare_tool_def(ctx):
706
+ function_tools.append(tool_def)
707
+
708
+ await asyncio.gather(*map(add_tool, self._function_tools.values()))
709
+
710
+ return await model.agent_model(
711
+ function_tools=function_tools,
712
+ allow_text_result=self._allow_text_result,
713
+ result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
714
+ )
649
715
 
650
716
  async def _prepare_messages(
651
717
  self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
@@ -665,11 +731,11 @@ class Agent(Generic[AgentDeps, ResultData]):
665
731
 
666
732
  async def _handle_model_response(
667
733
  self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
668
- ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
734
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
669
735
  """Process a non-streamed response from the model.
670
736
 
671
737
  Returns:
672
- Return `Either` left: final result data, right: list of messages to send back to the model.
738
+ A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
673
739
  """
674
740
  if model_response.role == 'model-text-response':
675
741
  # plain string response
@@ -679,15 +745,15 @@ class Agent(Generic[AgentDeps, ResultData]):
679
745
  result_data = await self._validate_result(result_data_input, deps, None)
680
746
  except _result.ToolRetryError as e:
681
747
  self._incr_result_retry()
682
- return [e.tool_retry]
748
+ return None, [e.tool_retry]
683
749
  else:
684
- return _MarkFinalResult(result_data)
750
+ return _MarkFinalResult(result_data), []
685
751
  else:
686
752
  self._incr_result_retry()
687
753
  response = _messages.RetryPrompt(
688
754
  content='Plain text responses are not permitted, please call one of the functions instead.',
689
755
  )
690
- return [response]
756
+ return None, [response]
691
757
  elif model_response.role == 'model-structured-response':
692
758
  if self._result_schema is not None:
693
759
  # if there's a result schema, and any of the calls match one of its tools, return the result
@@ -699,9 +765,15 @@ class Agent(Generic[AgentDeps, ResultData]):
699
765
  result_data = await self._validate_result(result_data, deps, call)
700
766
  except _result.ToolRetryError as e:
701
767
  self._incr_result_retry()
702
- return [e.tool_retry]
768
+ return None, [e.tool_retry]
703
769
  else:
704
- return _MarkFinalResult(result_data)
770
+ # Add a ToolReturn message for the schema tool call
771
+ tool_return = _messages.ToolReturn(
772
+ tool_name=call.tool_name,
773
+ content='Final result processed.',
774
+ tool_id=call.tool_id,
775
+ )
776
+ return _MarkFinalResult(result_data), [tool_return]
705
777
 
706
778
  if not model_response.calls:
707
779
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
@@ -716,26 +788,24 @@ class Agent(Generic[AgentDeps, ResultData]):
716
788
  messages.append(self._unknown_tool(call.tool_name))
717
789
 
718
790
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
719
- messages += await asyncio.gather(*tasks)
720
- return messages
791
+ task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
792
+ messages.extend(task_results)
793
+ return None, messages
721
794
  else:
722
795
  assert_never(model_response)
723
796
 
724
797
  async def _handle_streamed_model_response(
725
798
  self, model_response: models.EitherStreamedResponse, deps: AgentDeps
726
- ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
799
+ ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
727
800
  """Process a streamed response from the model.
728
801
 
729
- TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
730
- (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
731
-
732
802
  Returns:
733
- Return `Either` left: final result data, right: list of messages to send back to the model.
803
+ A tuple of (final_result, messages). If final_result is not None, the conversation should end.
734
804
  """
735
805
  if isinstance(model_response, models.StreamTextResponse):
736
806
  # plain string response
737
807
  if self._allow_text_result:
738
- return _MarkFinalResult(model_response)
808
+ return _MarkFinalResult(model_response), []
739
809
  else:
740
810
  self._incr_result_retry()
741
811
  response = _messages.RetryPrompt(
@@ -745,7 +815,7 @@ class Agent(Generic[AgentDeps, ResultData]):
745
815
  async for _ in model_response:
746
816
  pass
747
817
 
748
- return [response]
818
+ return None, [response]
749
819
  else:
750
820
  assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
751
821
  if self._result_schema is not None:
@@ -759,8 +829,14 @@ class Agent(Generic[AgentDeps, ResultData]):
759
829
  break
760
830
  structured_msg = model_response.get()
761
831
 
762
- if self._result_schema.find_tool(structured_msg):
763
- return _MarkFinalResult(model_response)
832
+ if match := self._result_schema.find_tool(structured_msg):
833
+ call, _ = match
834
+ tool_return = _messages.ToolReturn(
835
+ tool_name=call.tool_name,
836
+ content='Final result processed.',
837
+ tool_id=call.tool_id,
838
+ )
839
+ return _MarkFinalResult(model_response), [tool_return]
764
840
 
765
841
  # the model is calling a tool function, consume the response to get the next message
766
842
  async for _ in model_response:
@@ -779,8 +855,9 @@ class Agent(Generic[AgentDeps, ResultData]):
779
855
  messages.append(self._unknown_tool(call.tool_name))
780
856
 
781
857
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
782
- messages += await asyncio.gather(*tasks)
783
- return messages
858
+ task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
859
+ messages.extend(task_results)
860
+ return None, messages
784
861
 
785
862
  async def _validate_result(
786
863
  self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -852,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
852
929
  """Marker class to indicate that the result is the final result.
853
930
 
854
931
  This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
932
+
933
+ It also avoids problems in the case where the result type is itself `None`, but is set.
855
934
  """
856
935
 
857
936
  data: ResultData