pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (48)
  1. pydantic_ai/_a2a.py +3 -3
  2. pydantic_ai/_agent_graph.py +220 -319
  3. pydantic_ai/_cli.py +9 -7
  4. pydantic_ai/_output.py +295 -331
  5. pydantic_ai/_parts_manager.py +2 -2
  6. pydantic_ai/_run_context.py +8 -14
  7. pydantic_ai/_tool_manager.py +190 -0
  8. pydantic_ai/_utils.py +18 -1
  9. pydantic_ai/ag_ui.py +675 -0
  10. pydantic_ai/agent.py +378 -164
  11. pydantic_ai/exceptions.py +12 -0
  12. pydantic_ai/ext/aci.py +12 -3
  13. pydantic_ai/ext/langchain.py +9 -1
  14. pydantic_ai/format_prompt.py +3 -6
  15. pydantic_ai/mcp.py +147 -84
  16. pydantic_ai/messages.py +13 -5
  17. pydantic_ai/models/__init__.py +30 -18
  18. pydantic_ai/models/anthropic.py +1 -1
  19. pydantic_ai/models/function.py +50 -24
  20. pydantic_ai/models/gemini.py +1 -18
  21. pydantic_ai/models/google.py +2 -11
  22. pydantic_ai/models/groq.py +1 -0
  23. pydantic_ai/models/instrumented.py +6 -1
  24. pydantic_ai/models/mistral.py +1 -1
  25. pydantic_ai/models/openai.py +16 -4
  26. pydantic_ai/output.py +21 -7
  27. pydantic_ai/profiles/google.py +1 -1
  28. pydantic_ai/profiles/moonshotai.py +8 -0
  29. pydantic_ai/providers/grok.py +13 -1
  30. pydantic_ai/providers/groq.py +2 -0
  31. pydantic_ai/result.py +58 -45
  32. pydantic_ai/tools.py +26 -119
  33. pydantic_ai/toolsets/__init__.py +22 -0
  34. pydantic_ai/toolsets/abstract.py +155 -0
  35. pydantic_ai/toolsets/combined.py +88 -0
  36. pydantic_ai/toolsets/deferred.py +38 -0
  37. pydantic_ai/toolsets/filtered.py +24 -0
  38. pydantic_ai/toolsets/function.py +238 -0
  39. pydantic_ai/toolsets/prefixed.py +37 -0
  40. pydantic_ai/toolsets/prepared.py +36 -0
  41. pydantic_ai/toolsets/renamed.py +42 -0
  42. pydantic_ai/toolsets/wrapper.py +37 -0
  43. pydantic_ai/usage.py +14 -8
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
  46. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
  47. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
  48. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/agent.py CHANGED
@@ -4,7 +4,8 @@ import dataclasses
 import inspect
 import json
 import warnings
-from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+from asyncio import Lock
+from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from copy import deepcopy
@@ -15,7 +16,6 @@ from opentelemetry.trace import NoOpTracer, use_span
 from pydantic.json_schema import GenerateJsonSchema
 from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
 
-from pydantic_ai.profiles import ModelProfile
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
 from pydantic_graph._utils import get_event_loop
 
@@ -31,8 +31,11 @@ from . import (
     usage as _usage,
 )
 from ._agent_graph import HistoryProcessor
+from ._output import OutputToolset
+from ._tool_manager import ToolManager
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
+from .profiles import ModelProfile
 from .result import FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -48,6 +51,11 @@ from .tools import (
     ToolPrepareFunc,
     ToolsPrepareFunc,
 )
+from .toolsets import AbstractToolset
+from .toolsets.combined import CombinedToolset
+from .toolsets.function import FunctionToolset
+from .toolsets.prepared import PreparedToolset
+from .usage import Usage, UsageLimits
 
 # Re-exporting like this improves auto-import behavior in PyCharm
 capture_run_messages = _agent_graph.capture_run_messages
@@ -62,11 +70,12 @@ if TYPE_CHECKING:
     from fasta2a.schema import AgentProvider, Skill
     from fasta2a.storage import Storage
     from starlette.middleware import Middleware
-    from starlette.routing import Route
+    from starlette.routing import BaseRoute, Route
     from starlette.types import ExceptionHandler, Lifespan
 
     from pydantic_ai.mcp import MCPServer
 
+    from .ag_ui import AGUIApp
 
 __all__ = (
     'Agent',
@@ -153,12 +162,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
         repr=False
     )
+    _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
+    _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
+    _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
     _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
-    _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
-    _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
-    _default_retries: int = dataclasses.field(repr=False)
+    _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
 
+    _enter_lock: Lock = dataclasses.field(repr=False)
+    _entered_count: int = dataclasses.field(repr=False)
+    _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
+
     @overload
     def __init__(
         self,
@@ -177,7 +191,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
-        mcp_servers: Sequence[MCPServer] = (),
+        prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
@@ -186,7 +201,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
     @overload
     @deprecated(
-        '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
+        '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
    )
     def __init__(
         self,
@@ -207,6 +222,36 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+        prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        defer_model_check: bool = False,
+        end_strategy: EndStrategy = 'early',
+        instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
+    ) -> None: ...
+
+    @overload
+    @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.')
+    def __init__(
+        self,
+        model: models.Model | models.KnownModelName | str | None = None,
+        *,
+        result_type: type[OutputDataT] = str,
+        instructions: str
+        | _system_prompt.SystemPromptFunc[AgentDepsT]
+        | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
+        | None = None,
+        system_prompt: str | Sequence[str] = (),
+        deps_type: type[AgentDepsT] = NoneType,
+        name: str | None = None,
+        model_settings: ModelSettings | None = None,
+        retries: int = 1,
+        result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
+        result_tool_description: str | None = None,
+        result_retries: int | None = None,
+        tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
+        prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+        prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
         mcp_servers: Sequence[MCPServer] = (),
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
@@ -232,7 +277,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
-        mcp_servers: Sequence[MCPServer] = (),
+        prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
@@ -258,14 +304,16 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 when the agent is first run.
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow before raising an error.
-            output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
+            output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
-            prepare_tools: custom method to prepare the tool definition of all tools for each step.
+            prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools.
                 This is useful if you want to customize the definition of multiple tools or you want to register
                 a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
-            mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
-                for each server you want the agent to connect to.
+            prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
+                This is useful if you want to customize the definition of multiple output tools or you want to register
+                a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
+            toolsets: Toolsets to register with the agent, including MCP servers.
             defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
                 it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
                 which checks for the necessary environment variables. Set this to `false`
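The hunk above documents the new `prepare_output_tools` hook, which mirrors `prepare_tools` but only sees the output tools. A minimal sketch of what such a hook could look like, assuming the `ToolsPrepareFunc` shape referenced in the docstring (an async callable taking a `RunContext` and a list of `ToolDefinition`s); the filtering rule and tool names are purely illustrative:

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.tools import ToolDefinition


async def only_final_output_tools(
    ctx: RunContext[None], tool_defs: list[ToolDefinition]
) -> list[ToolDefinition] | None:
    # Illustrative rule: after the first step, only expose output tools
    # whose name starts with 'final_'.
    if ctx.run_step > 1:
        return [td for td in tool_defs if td.name.startswith('final_')]
    return tool_defs


agent = Agent('openai:gpt-4o', prepare_output_tools=only_final_output_tools)
```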
@@ -329,10 +377,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             )
             output_retries = result_retries
 
+        if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None):
+            if toolsets is not None:  # pragma: no cover
+                raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.')
+            warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning)
+            toolsets = mcp_servers
+
+        _utils.validate_empty_kwargs(_deprecated_kwargs)
+
         default_output_mode = (
             self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
         )
-        _utils.validate_empty_kwargs(_deprecated_kwargs)
 
         self._output_schema = _output.OutputSchema[OutputDataT].build(
             output_type,
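Because `mcp_servers` is now funneled into `toolsets` with a `DeprecationWarning` (see the hunk above), migration is essentially a keyword rename. A hedged before/after sketch; the stdio command is illustrative only:

```python
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # illustrative command

# Before 0.4.5 (still accepted, but now emits a DeprecationWarning):
# agent = Agent('openai:gpt-4o', mcp_servers=[server])

# After: MCP servers are registered like any other toolset.
agent = Agent('openai:gpt-4o', toolsets=[server])
```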
@@ -357,21 +412,28 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self._system_prompt_functions = []
         self._system_prompt_dynamic_functions = {}
 
-        self._function_tools = {}
-
-        self._default_retries = retries
         self._max_result_retries = output_retries if output_retries is not None else retries
-        self._mcp_servers = mcp_servers
         self._prepare_tools = prepare_tools
+        self._prepare_output_tools = prepare_output_tools
+
+        self._output_toolset = self._output_schema.toolset
+        if self._output_toolset:
+            self._output_toolset.max_retries = self._max_result_retries
+
+        self._function_toolset = FunctionToolset(tools, max_retries=retries)
+        self._user_toolsets = toolsets or ()
+
         self.history_processors = history_processors or []
-        for tool in tools:
-            if isinstance(tool, Tool):
-                self._register_tool(tool)
-            else:
-                self._register_tool(Tool(tool))
 
         self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
         self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
+        self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar(
+            '_override_toolsets', default=None
+        )
+
+        self._enter_lock = _utils.get_async_lock()
+        self._entered_count = 0
+        self._exit_stack = None
 
     @staticmethod
     def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
@@ -391,6 +453,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[OutputDataT]: ...
 
     @overload
@@ -406,6 +469,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[RunOutputDataT]: ...
 
     @overload
@@ -422,6 +486,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[RunOutputDataT]: ...
 
     async def run(
@@ -436,6 +501,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AgentRunResult[Any]:
         """Run the agent with a user prompt in async mode.
@@ -466,6 +532,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
 
         Returns:
             The result of the run.
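All of the run methods gain the same optional `toolsets` argument documented above, which attaches extra toolsets for a single run on top of those registered on the agent. A minimal sketch, assuming `FunctionToolset` accepts plain functions the way `Agent(tools=...)` does; the die-rolling tool is hypothetical:

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.toolsets.function import FunctionToolset


def roll_die() -> int:
    """Roll a six-sided die (hypothetical example tool)."""
    return 4


# A standalone toolset that is only attached for this one run.
extra_tools = FunctionToolset([roll_die])
agent = Agent('openai:gpt-4o')


async def main() -> None:
    result = await agent.run('Roll the die for me', toolsets=[extra_tools])
    print(result.output)


if __name__ == '__main__':
    asyncio.run(main())
```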
@@ -490,6 +557,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             model_settings=model_settings,
             usage_limits=usage_limits,
             usage=usage,
+            toolsets=toolsets,
         ) as agent_run:
             async for _ in agent_run:
                 pass
@@ -510,6 +578,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...
 
@@ -526,6 +595,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
 
@@ -543,6 +613,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
 
     @asynccontextmanager
@@ -558,6 +629,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
         """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
@@ -632,6 +704,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
 
         Returns:
             The result of the run.
@@ -655,6 +728,18 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
         output_type_ = output_type or self.output_type
 
+        # We consider it a user error if a user tries to restrict the result type while having an output validator that
+        # may change the result type from the restricted type to something else. Therefore, we consider the following
+        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+        output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
+
+        output_toolset = self._output_toolset
+        if output_schema != self._output_schema or output_validators:
+            output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset)
+            if output_toolset:
+                output_toolset.max_retries = self._max_result_retries
+                output_toolset.output_validators = output_validators
+
         # Build the graph
         graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
             _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
@@ -669,22 +754,32 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             run_step=0,
         )
 
-        # We consider it a user error if a user tries to restrict the result type while having an output validator that
-        # may change the result type from the restricted type to something else. Therefore, we consider the following
-        # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
-        output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
-
-        # Merge model settings in order of precedence: run > agent > model
-        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
-        model_settings = merge_model_settings(merged_settings, model_settings)
-        usage_limits = usage_limits or _usage.UsageLimits()
-
         if isinstance(model_used, InstrumentedModel):
             instrumentation_settings = model_used.instrumentation_settings
             tracer = model_used.instrumentation_settings.tracer
         else:
             instrumentation_settings = None
             tracer = NoOpTracer()
+
+        run_context = RunContext[AgentDepsT](
+            deps=deps,
+            model=model_used,
+            usage=usage,
+            prompt=user_prompt,
+            messages=state.message_history,
+            tracer=tracer,
+            trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
+            run_step=state.run_step,
+        )
+
+        toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
+        # This will raise errors for any name conflicts
+        run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
+
+        # Merge model settings in order of precedence: run > agent > model
+        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+        model_settings = merge_model_settings(merged_settings, model_settings)
+        usage_limits = usage_limits or _usage.UsageLimits()
         agent_name = self.name or 'agent'
         run_span = tracer.start_span(
             'agent run',
@@ -711,10 +806,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 return None
             return '\n\n'.join(parts).strip()
 
-        # Copy the function tools so that retry state is agent-run-specific
-        # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
-        run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
-
         graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
             user_deps=deps,
             prompt=user_prompt,
@@ -727,11 +818,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             output_schema=output_schema,
             output_validators=output_validators,
             history_processors=self.history_processors,
-            function_tools=run_function_tools,
-            mcp_servers=self._mcp_servers,
-            default_retries=self._default_retries,
+            tool_manager=run_toolset,
             tracer=tracer,
-            prepare_tools=self._prepare_tools,
             get_instructions=get_instructions,
             instrumentation_settings=instrumentation_settings,
         )
@@ -755,14 +843,15 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 agent_run = AgentRun(graph_run)
                 yield agent_run
                 if (final_result := agent_run.result) is not None and run_span.is_recording():
-                    run_span.set_attribute(
-                        'final_result',
-                        (
-                            final_result.output
-                            if isinstance(final_result.output, str)
-                            else json.dumps(InstrumentedModel.serialize_any(final_result.output))
-                        ),
-                    )
+                    if instrumentation_settings and instrumentation_settings.include_content:
+                        run_span.set_attribute(
+                            'final_result',
+                            (
+                                final_result.output
+                                if isinstance(final_result.output, str)
+                                else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                            ),
+                        )
         finally:
             try:
                 if instrumentation_settings and run_span.is_recording():
@@ -801,6 +890,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[OutputDataT]: ...
 
     @overload
@@ -816,6 +906,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[RunOutputDataT]: ...
 
     @overload
@@ -832,6 +923,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AgentRunResult[RunOutputDataT]: ...
 
     def run_sync(
@@ -846,6 +938,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AgentRunResult[Any]:
         """Synchronously run the agent with a user prompt.
@@ -875,6 +968,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
 
         Returns:
             The result of the run.
@@ -901,6 +995,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 usage_limits=usage_limits,
                 usage=usage,
                 infer_name=False,
+                toolsets=toolsets,
             )
         )
 
@@ -916,6 +1011,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
 
     @overload
@@ -931,6 +1027,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
 
     @overload
@@ -947,6 +1044,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
     ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
 
     @asynccontextmanager
@@ -962,6 +1060,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Never,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -989,6 +1088,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
 
         Returns:
             The result of the run.
@@ -1019,6 +1119,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             usage_limits=usage_limits,
             usage=usage,
             infer_name=False,
+            toolsets=toolsets,
         ) as agent_run:
             first_node = agent_run.next_node  # start with the first node
             assert isinstance(first_node, _agent_graph.UserPromptNode)  # the first node should be a user prompt node
@@ -1039,15 +1140,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                                        output_schema, _output.TextOutputSchema
                                    ):
                                        return FinalResult(s, None, None)
-                                    elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
-                                        output_schema, _output.ToolOutputSchema
-                                    ):  # pragma: no branch
-                                        for call, _ in output_schema.find_tool([new_part]):
-                                            return FinalResult(s, call.tool_name, call.tool_call_id)
+                                    elif isinstance(new_part, _messages.ToolCallPart) and (
+                                        tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
+                                    ):
+                                        if tool_def.kind == 'output':
+                                            return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
+                                        elif tool_def.kind == 'deferred':
+                                            return FinalResult(s, None, None)
                            return None
 
-                        final_result_details = await stream_to_final(streamed_response)
-                        if final_result_details is not None:
+                        final_result = await stream_to_final(streamed_response)
+                        if final_result is not None:
                            if yielded:
                                raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
                            yielded = True
@@ -1068,17 +1171,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
                            parts: list[_messages.ModelRequestPart] = []
                            async for _event in _agent_graph.process_function_tools(
+                                graph_ctx.deps.tool_manager,
                                tool_calls,
-                                final_result_details.tool_name,
-                                final_result_details.tool_call_id,
+                                final_result,
                                graph_ctx,
                                parts,
                            ):
                                pass
-                            # TODO: Should we do something here related to the retry count?
-                            # Maybe we should move the incrementing of the retry count to where we actually make a request?
-                            # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
-                            #     ctx.state.increment_retries(ctx.deps.max_result_retries)
                            if parts:
                                messages.append(_messages.ModelRequest(parts))
 
@@ -1089,10 +1188,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                                streamed_response,
                                graph_ctx.deps.output_schema,
                                _agent_graph.build_run_context(graph_ctx),
-                                _output.build_trace_context(graph_ctx),
                                graph_ctx.deps.output_validators,
-                                final_result_details.tool_name,
+                                final_result.tool_name,
                                on_complete,
+                                graph_ctx.deps.tool_manager,
                            )
                            break
                next_node = await agent_run.next(node)
@@ -1111,8 +1210,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         *,
         deps: AgentDepsT | _utils.Unset = _utils.UNSET,
         model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
     ) -> Iterator[None]:
-        """Context manager to temporarily override agent dependencies and model.
+        """Context manager to temporarily override agent dependencies, model, or toolsets.
 
         This is particularly useful when testing.
         You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -1120,6 +1220,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         Args:
             deps: The dependencies to use instead of the dependencies passed to the agent run.
             model: The model to use instead of the model passed to the agent run.
+            toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
         """
         if _utils.is_set(deps):
             deps_token = self._override_deps.set(_utils.Some(deps))
@@ -1131,6 +1232,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         else:
             model_token = None
 
+        if _utils.is_set(toolsets):
+            toolsets_token = self._override_toolsets.set(_utils.Some(toolsets))
+        else:
+            toolsets_token = None
+
         try:
             yield
         finally:
@@ -1138,6 +1244,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 self._override_deps.reset(deps_token)
             if model_token is not None:
                 self._override_model.reset(model_token)
+            if toolsets_token is not None:
+                self._override_toolsets.reset(toolsets_token)
 
     @overload
     def instructions(
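`override()` can now swap toolsets as well as `deps` and `model`, which is aimed at testing. A hedged sketch using `TestModel` and a `FunctionToolset`; the weather tool is hypothetical:

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel
from pydantic_ai.toolsets.function import FunctionToolset

agent = Agent('openai:gpt-4o')


def fake_weather(city: str) -> str:
    """Hypothetical stand-in tool returning canned data."""
    return f'Sunny in {city}'


test_toolset = FunctionToolset([fake_weather])

# Inside the `with` block, every agent run uses the overridden model and toolsets.
with agent.override(model=TestModel(), toolsets=[test_toolset]):
    result = agent.run_sync('What is the weather in Paris?')
    print(result.output)
```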
@@ -1423,30 +1531,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
         """
-        if func is None:
-
-            def tool_decorator(
-                func_: ToolFuncContext[AgentDepsT, ToolParams],
-            ) -> ToolFuncContext[AgentDepsT, ToolParams]:
-                # noinspection PyTypeChecker
-                self._register_function(
-                    func_,
-                    True,
-                    name,
-                    retries,
-                    prepare,
-                    docstring_format,
-                    require_parameter_descriptions,
-                    schema_generator,
-                    strict,
-                )
-                return func_
 
-            return tool_decorator
-        else:
+        def tool_decorator(
+            func_: ToolFuncContext[AgentDepsT, ToolParams],
+        ) -> ToolFuncContext[AgentDepsT, ToolParams]:
             # noinspection PyTypeChecker
-            self._register_function(
-                func,
+            self._function_toolset.add_function(
+                func_,
                 True,
                 name,
                 retries,
@@ -1456,7 +1547,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 schema_generator,
                 strict,
             )
-            return func
+            return func_
+
+        return tool_decorator if func is None else tool_decorator(func)
 
     @overload
     def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
@@ -1532,27 +1625,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             strict: Whether to enforce JSON schema compliance (only affects OpenAI).
                 See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
         """
-        if func is None:
 
-            def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
-                # noinspection PyTypeChecker
-                self._register_function(
-                    func_,
-                    False,
-                    name,
-                    retries,
-                    prepare,
-                    docstring_format,
-                    require_parameter_descriptions,
-                    schema_generator,
-                    strict,
-                )
-                return func_
-
-            return tool_decorator
-        else:
-            self._register_function(
-                func,
+        def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
+            # noinspection PyTypeChecker
+            self._function_toolset.add_function(
+                func_,
                 False,
                 name,
                 retries,
@@ -1562,48 +1639,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 schema_generator,
                 strict,
             )
-            return func
-
-    def _register_function(
-        self,
-        func: ToolFuncEither[AgentDepsT, ToolParams],
-        takes_ctx: bool,
-        name: str | None,
-        retries: int | None,
-        prepare: ToolPrepareFunc[AgentDepsT] | None,
-        docstring_format: DocstringFormat,
-        require_parameter_descriptions: bool,
-        schema_generator: type[GenerateJsonSchema],
-        strict: bool | None,
-    ) -> None:
-        """Private utility to register a function as a tool."""
-        retries_ = retries if retries is not None else self._default_retries
-        tool = Tool[AgentDepsT](
-            func,
-            takes_ctx=takes_ctx,
-            name=name,
-            max_retries=retries_,
-            prepare=prepare,
-            docstring_format=docstring_format,
-            require_parameter_descriptions=require_parameter_descriptions,
-            schema_generator=schema_generator,
-            strict=strict,
-        )
-        self._register_tool(tool)
-
-    def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
-        """Private utility to register a tool instance."""
-        if tool.max_retries is None:
-            # noinspection PyTypeChecker
-            tool = dataclasses.replace(tool, max_retries=self._default_retries)
-
-        if tool.name in self._function_tools:
-            raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
-
-        if tool.name in self._output_schema.tools:
-            raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
+            return func_
 
-        self._function_tools[tool.name] = tool
+        return tool_decorator if func is None else tool_decorator(func)
 
     def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
         """Create a model configured for this agent.
@@ -1649,6 +1687,37 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         else:
             return deps
 
+    def _get_toolset(
+        self,
+        output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
+        additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+    ) -> AbstractToolset[AgentDepsT]:
+        """Get the complete toolset.
+
+        Args:
+            output_toolset: The output toolset to use instead of the one built at agent construction time.
+            additional_toolsets: Additional toolsets to add.
+        """
+        if some_user_toolsets := self._override_toolsets.get():
+            user_toolsets = some_user_toolsets.value
+        elif additional_toolsets is not None:
+            user_toolsets = [*self._user_toolsets, *additional_toolsets]
+        else:
+            user_toolsets = self._user_toolsets
+
+        all_toolsets = [self._function_toolset, *user_toolsets]
+
+        if self._prepare_tools:
+            all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]
+
+        output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
+        if output_toolset is not None:
+            if self._prepare_output_tools:
+                output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
+            all_toolsets = [output_toolset, *all_toolsets]
+
+        return CombinedToolset(all_toolsets)
+
     def _infer_name(self, function_frame: FrameType | None) -> None:
         """Infer the agent name from the call frame.
 
@@ -1734,28 +1803,167 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         """
         return isinstance(node, End)
 
+    async def __aenter__(self) -> Self:
+        """Enter the agent context.
+
+        This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
+
+        This is a no-op if the agent has already been entered.
+        """
+        async with self._enter_lock:
+            if self._entered_count == 0:
+                self._exit_stack = AsyncExitStack()
+                toolset = self._get_toolset()
+                await self._exit_stack.enter_async_context(toolset)
+            self._entered_count += 1
+        return self
+
+    async def __aexit__(self, *args: Any) -> bool | None:
+        async with self._enter_lock:
+            self._entered_count -= 1
+            if self._entered_count == 0 and self._exit_stack is not None:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
+
+    def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
+        """Set the sampling model on all MCP servers registered with the agent.
+
+        If no sampling model is provided, the agent's model will be used.
+        """
+        try:
+            sampling_model = models.infer_model(model) if model else self._get_model(None)
+        except exceptions.UserError as e:
+            raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e
+
+        from .mcp import MCPServer
+
+        def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
+            if isinstance(toolset, MCPServer):
+                toolset.sampling_model = sampling_model
+
+        self._get_toolset().apply(_set_sampling_model)
+
     @asynccontextmanager
+    @deprecated(
+        '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
+    )
     async def run_mcp_servers(
         self, model: models.Model | models.KnownModelName | str | None = None
     ) -> AsyncIterator[None]:
         """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
 
+        Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead.
+        If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model].
+
         Returns: a context manager to start and shutdown the servers.
         """
         try:
-            sampling_model: models.Model | None = self._get_model(model)
-        except exceptions.UserError:  # pragma: no cover
-            sampling_model = None
+            self.set_mcp_sampling_model(model)
+        except exceptions.UserError:
+            if model is not None:
+                raise
 
-        exit_stack = AsyncExitStack()
-        try:
-            for mcp_server in self._mcp_servers:
-                if sampling_model is not None:  # pragma: no branch
-                    mcp_server.sampling_model = sampling_model
-                await exit_stack.enter_async_context(mcp_server)
+        async with self:
             yield
-        finally:
-            await exit_stack.aclose()
+
+    def to_ag_ui(
+        self,
+        *,
+        # Agent.iter parameters
+        output_type: OutputSpec[OutputDataT] | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
+        usage: Usage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        # Starlette
+        debug: bool = False,
+        routes: Sequence[BaseRoute] | None = None,
+        middleware: Sequence[Middleware] | None = None,
+        exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
+        on_startup: Sequence[Callable[[], Any]] | None = None,
+        on_shutdown: Sequence[Callable[[], Any]] | None = None,
+        lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
+    ) -> AGUIApp[AgentDepsT, OutputDataT]:
+        """Convert the agent to an AG-UI application.
+
+        This allows you to use the agent with a compatible AG-UI frontend.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+        app = agent.to_ag_ui()
+        ```
+
+        The `app` is an ASGI application that can be used with any ASGI server.
+
+        To run the application, you can use the following command:
+
+        ```bash
+        uvicorn app:app --host 0.0.0.0 --port 8000
+        ```
+
+        See [AG-UI docs](../ag-ui.md) for more information.
+
+        Args:
+            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
+                no output validators since output validators would expect an argument that matches the agent's
+                output type.
+            model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
+
+            debug: Boolean indicating if debug tracebacks should be returned on errors.
+            routes: A list of routes to serve incoming HTTP and WebSocket requests.
+            middleware: A list of middleware to run for every request. A starlette application will always
+                automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
+                outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
+                `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
+                exception cases occurring in the routing or endpoints.
+            exception_handlers: A mapping of either integer status codes, or exception class types onto
+                callables which handle the exceptions. Exception handler callables should be of the form
+                `handler(request, exc) -> response` and may be either standard functions, or async functions.
+            on_startup: A list of callables to run on application startup. Startup handler callables do not
+                take any arguments, and may be either standard functions, or async functions.
+            on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
+                not take any arguments, and may be either standard functions, or async functions.
+            lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
+                This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
+                the other, not both.
+
+        Returns:
+            An ASGI application for running Pydantic AI agents with AG-UI protocol support.
+        """
+        from .ag_ui import AGUIApp
+
+        return AGUIApp(
+            agent=self,
+            # Agent.iter parameters
+            output_type=output_type,
+            model=model,
+            deps=deps,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            usage=usage,
+            infer_name=infer_name,
+            toolsets=toolsets,
+            # Starlette
+            debug=debug,
+            routes=routes,
+            middleware=middleware,
+            exception_handlers=exception_handlers,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
+            lifespan=lifespan,
+        )
 
     def to_a2a(
         self,
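The hunk above replaces `run_mcp_servers()` with entering the agent itself (`async with agent:`), and `set_mcp_sampling_model()` takes over the sampling-model plumbing. A minimal migration sketch, reusing the illustrative stdio server from earlier:

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # illustrative command
agent = Agent('openai:gpt-4o', toolsets=[server])


async def main() -> None:
    # Previously: `async with agent.run_mcp_servers(): ...` (now deprecated).
    agent.set_mcp_sampling_model()  # defaults to the agent's own model
    async with agent:  # starts the stdio server; re-entering is a no-op
        result = await agent.run('Use the MCP tools to answer: what is 2 + 2?')
        print(result.output)


asyncio.run(main())
```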
@@ -2112,12 +2320,18 @@ class AgentRunResult(Generic[OutputDataT]):
         """
         if not self._output_tool_name:
             raise ValueError('Cannot set output tool return content when the return type is `str`.')
-        messages = deepcopy(self._state.message_history)
+
+        messages = self._state.message_history
         last_message = messages[-1]
-        for part in last_message.parts:
+        for idx, part in enumerate(last_message.parts):
             if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
-                part.content = return_content
-                return messages
+                # Only do deepcopy when we have to modify
+                copied_messages = list(messages)
+                copied_last = deepcopy(last_message)
+                copied_last.parts[idx].content = return_content  # type: ignore[misc]
+                copied_messages[-1] = copied_last
+                return copied_messages
+
         raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
 
     @overload