pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (45)
  1. pydantic_ai/_agent_graph.py +220 -319
  2. pydantic_ai/_cli.py +9 -7
  3. pydantic_ai/_output.py +295 -331
  4. pydantic_ai/_parts_manager.py +2 -2
  5. pydantic_ai/_run_context.py +8 -14
  6. pydantic_ai/_tool_manager.py +190 -0
  7. pydantic_ai/_utils.py +18 -1
  8. pydantic_ai/ag_ui.py +675 -0
  9. pydantic_ai/agent.py +369 -156
  10. pydantic_ai/exceptions.py +12 -0
  11. pydantic_ai/ext/aci.py +12 -3
  12. pydantic_ai/ext/langchain.py +9 -1
  13. pydantic_ai/mcp.py +147 -84
  14. pydantic_ai/messages.py +13 -5
  15. pydantic_ai/models/__init__.py +30 -18
  16. pydantic_ai/models/anthropic.py +1 -1
  17. pydantic_ai/models/function.py +50 -24
  18. pydantic_ai/models/gemini.py +1 -9
  19. pydantic_ai/models/google.py +2 -11
  20. pydantic_ai/models/groq.py +1 -0
  21. pydantic_ai/models/mistral.py +1 -1
  22. pydantic_ai/models/openai.py +3 -3
  23. pydantic_ai/output.py +21 -7
  24. pydantic_ai/profiles/google.py +1 -1
  25. pydantic_ai/profiles/moonshotai.py +8 -0
  26. pydantic_ai/providers/grok.py +13 -1
  27. pydantic_ai/providers/groq.py +2 -0
  28. pydantic_ai/result.py +58 -45
  29. pydantic_ai/tools.py +26 -119
  30. pydantic_ai/toolsets/__init__.py +22 -0
  31. pydantic_ai/toolsets/abstract.py +155 -0
  32. pydantic_ai/toolsets/combined.py +88 -0
  33. pydantic_ai/toolsets/deferred.py +38 -0
  34. pydantic_ai/toolsets/filtered.py +24 -0
  35. pydantic_ai/toolsets/function.py +238 -0
  36. pydantic_ai/toolsets/prefixed.py +37 -0
  37. pydantic_ai/toolsets/prepared.py +36 -0
  38. pydantic_ai/toolsets/renamed.py +42 -0
  39. pydantic_ai/toolsets/wrapper.py +37 -0
  40. pydantic_ai/usage.py +14 -8
  41. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
  42. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
  43. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/agent.py CHANGED
@@ -4,7 +4,8 @@ import dataclasses
  import inspect
  import json
  import warnings
- from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
+ from asyncio import Lock
+ from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
  from contextvars import ContextVar
  from copy import deepcopy
@@ -15,7 +16,6 @@ from opentelemetry.trace import NoOpTracer, use_span
  from pydantic.json_schema import GenerateJsonSchema
  from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated

- from pydantic_ai.profiles import ModelProfile
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
  from pydantic_graph._utils import get_event_loop

@@ -31,8 +31,11 @@ from . import (
  usage as _usage,
  )
  from ._agent_graph import HistoryProcessor
+ from ._output import OutputToolset
+ from ._tool_manager import ToolManager
  from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
  from .output import OutputDataT, OutputSpec
+ from .profiles import ModelProfile
  from .result import FinalResult, StreamedRunResult
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -48,6 +51,11 @@ from .tools import (
  ToolPrepareFunc,
  ToolsPrepareFunc,
  )
+ from .toolsets import AbstractToolset
+ from .toolsets.combined import CombinedToolset
+ from .toolsets.function import FunctionToolset
+ from .toolsets.prepared import PreparedToolset
+ from .usage import Usage, UsageLimits

  # Re-exporting like this improves auto-import behavior in PyCharm
  capture_run_messages = _agent_graph.capture_run_messages
@@ -62,11 +70,12 @@ if TYPE_CHECKING:
  from fasta2a.schema import AgentProvider, Skill
  from fasta2a.storage import Storage
  from starlette.middleware import Middleware
- from starlette.routing import Route
+ from starlette.routing import BaseRoute, Route
  from starlette.types import ExceptionHandler, Lifespan

  from pydantic_ai.mcp import MCPServer

+ from .ag_ui import AGUIApp

  __all__ = (
  'Agent',
@@ -153,12 +162,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
  repr=False
  )
+ _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
+ _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
+ _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
  _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
- _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
- _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
- _default_retries: int = dataclasses.field(repr=False)
+ _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
  _max_result_retries: int = dataclasses.field(repr=False)

+ _enter_lock: Lock = dataclasses.field(repr=False)
+ _entered_count: int = dataclasses.field(repr=False)
+ _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
+
  @overload
  def __init__(
  self,
@@ -177,7 +191,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  output_retries: int | None = None,
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
  prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
- mcp_servers: Sequence[MCPServer] = (),
+ prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
  instrument: InstrumentationSettings | bool | None = None,
@@ -186,7 +201,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):

  @overload
  @deprecated(
- '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
+ '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
  )
  def __init__(
  self,
@@ -207,6 +222,36 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  result_retries: int | None = None,
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
  prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+ prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+ defer_model_check: bool = False,
+ end_strategy: EndStrategy = 'early',
+ instrument: InstrumentationSettings | bool | None = None,
+ history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
+ ) -> None: ...
+
+ @overload
+ @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.')
+ def __init__(
+ self,
+ model: models.Model | models.KnownModelName | str | None = None,
+ *,
+ result_type: type[OutputDataT] = str,
+ instructions: str
+ | _system_prompt.SystemPromptFunc[AgentDepsT]
+ | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
+ | None = None,
+ system_prompt: str | Sequence[str] = (),
+ deps_type: type[AgentDepsT] = NoneType,
+ name: str | None = None,
+ model_settings: ModelSettings | None = None,
+ retries: int = 1,
+ result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
+ result_tool_description: str | None = None,
+ result_retries: int | None = None,
+ tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
+ prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+ prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
  mcp_servers: Sequence[MCPServer] = (),
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
@@ -232,7 +277,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  output_retries: int | None = None,
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
  prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
- mcp_servers: Sequence[MCPServer] = (),
+ prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  defer_model_check: bool = False,
  end_strategy: EndStrategy = 'early',
  instrument: InstrumentationSettings | bool | None = None,
@@ -258,14 +304,16 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  when the agent is first run.
  model_settings: Optional model request settings to use for this agent's runs, by default.
  retries: The default number of retries to allow before raising an error.
- output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
+ output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
  tools: Tools to register with the agent, you can also register tools via the decorators
  [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
- prepare_tools: custom method to prepare the tool definition of all tools for each step.
+ prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools.
  This is useful if you want to customize the definition of multiple tools or you want to register
  a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
- mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
- for each server you want the agent to connect to.
+ prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
+ This is useful if you want to customize the definition of multiple output tools or you want to register
+ a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
+ toolsets: Toolsets to register with the agent, including MCP servers.
  defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
  which checks for the necessary environment variables. Set this to `false`
@@ -329,10 +377,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  )
  output_retries = result_retries

+ if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None):
+ if toolsets is not None: # pragma: no cover
+ raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.')
+ warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning)
+ toolsets = mcp_servers
+
+ _utils.validate_empty_kwargs(_deprecated_kwargs)
+
  default_output_mode = (
  self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
  )
- _utils.validate_empty_kwargs(_deprecated_kwargs)

  self._output_schema = _output.OutputSchema[OutputDataT].build(
  output_type,
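Note on the hunk above: the `mcp_servers` constructor argument is now caught via `_deprecated_kwargs`, folded into `toolsets`, and emits a `DeprecationWarning`. A minimal migration sketch, assuming a stdio MCP server (the server command below is illustrative, not taken from this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # hypothetical server command

# 0.4.3 style, still accepted but now warns:
# agent = Agent('openai:gpt-4o', mcp_servers=[server])

# 0.4.4 style: MCP servers are just another kind of toolset
agent = Agent('openai:gpt-4o', toolsets=[server])
```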
@@ -357,21 +412,28 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  self._system_prompt_functions = []
  self._system_prompt_dynamic_functions = {}

- self._function_tools = {}
-
- self._default_retries = retries
  self._max_result_retries = output_retries if output_retries is not None else retries
- self._mcp_servers = mcp_servers
  self._prepare_tools = prepare_tools
+ self._prepare_output_tools = prepare_output_tools
+
+ self._output_toolset = self._output_schema.toolset
+ if self._output_toolset:
+ self._output_toolset.max_retries = self._max_result_retries
+
+ self._function_toolset = FunctionToolset(tools, max_retries=retries)
+ self._user_toolsets = toolsets or ()
+
  self.history_processors = history_processors or []
- for tool in tools:
- if isinstance(tool, Tool):
- self._register_tool(tool)
- else:
- self._register_tool(Tool(tool))

  self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
  self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
+ self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar(
+ '_override_toolsets', default=None
+ )
+
+ self._enter_lock = _utils.get_async_lock()
+ self._entered_count = 0
+ self._exit_stack = None

  @staticmethod
  def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
@@ -391,6 +453,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[OutputDataT]: ...

  @overload
@@ -406,6 +469,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[RunOutputDataT]: ...

  @overload
@@ -422,6 +486,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[RunOutputDataT]: ...

  async def run(
@@ -436,6 +501,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AgentRunResult[Any]:
  """Run the agent with a user prompt in async mode.
@@ -466,6 +532,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: Optional limits on model request count or token usage.
  usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+ toolsets: Optional additional toolsets for this run.

  Returns:
  The result of the run.
@@ -490,6 +557,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  model_settings=model_settings,
  usage_limits=usage_limits,
  usage=usage,
+ toolsets=toolsets,
  ) as agent_run:
  async for _ in agent_run:
  pass
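The `toolsets` keyword threaded through `run` and its overloads above lets extra toolsets be attached for a single run. A hedged sketch of what that might look like; `FunctionToolset` is constructed here the same way the `__init__` hunk does (`FunctionToolset(tools, ...)`), and the tool itself is a toy:

```python
from pydantic_ai import Agent
from pydantic_ai.toolsets.function import FunctionToolset


def roll_die() -> int:
    """Roll a six-sided die."""
    return 4  # fixed value, just for illustration


extra_toolset = FunctionToolset([roll_die])
agent = Agent('openai:gpt-4o')

# The extra toolset is only available for this particular run.
result = agent.run_sync('Roll the die for me', toolsets=[extra_toolset])
print(result.output)
```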
@@ -510,6 +578,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...

@@ -526,6 +595,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

@@ -543,6 +613,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...

  @asynccontextmanager
@@ -558,6 +629,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
  """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
@@ -632,6 +704,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: Optional limits on model request count or token usage.
  usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+ toolsets: Optional additional toolsets for this run.

  Returns:
  The result of the run.
@@ -655,6 +728,18 @@ class Agent(Generic[AgentDepsT, OutputDataT]):

  output_type_ = output_type or self.output_type

+ # We consider it a user error if a user tries to restrict the result type while having an output validator that
+ # may change the result type from the restricted type to something else. Therefore, we consider the following
+ # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+ output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
+
+ output_toolset = self._output_toolset
+ if output_schema != self._output_schema or output_validators:
+ output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset)
+ if output_toolset:
+ output_toolset.max_retries = self._max_result_retries
+ output_toolset.output_validators = output_validators
+
  # Build the graph
  graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
  _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
@@ -669,22 +754,32 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  run_step=0,
  )

- # We consider it a user error if a user tries to restrict the result type while having an output validator that
- # may change the result type from the restricted type to something else. Therefore, we consider the following
- # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
- output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
-
- # Merge model settings in order of precedence: run > agent > model
- merged_settings = merge_model_settings(model_used.settings, self.model_settings)
- model_settings = merge_model_settings(merged_settings, model_settings)
- usage_limits = usage_limits or _usage.UsageLimits()
-
  if isinstance(model_used, InstrumentedModel):
  instrumentation_settings = model_used.instrumentation_settings
  tracer = model_used.instrumentation_settings.tracer
  else:
  instrumentation_settings = None
  tracer = NoOpTracer()
+
+ run_context = RunContext[AgentDepsT](
+ deps=deps,
+ model=model_used,
+ usage=usage,
+ prompt=user_prompt,
+ messages=state.message_history,
+ tracer=tracer,
+ trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
+ run_step=state.run_step,
+ )
+
+ toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
+ # This will raise errors for any name conflicts
+ run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
+
+ # Merge model settings in order of precedence: run > agent > model
+ merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+ model_settings = merge_model_settings(merged_settings, model_settings)
+ usage_limits = usage_limits or _usage.UsageLimits()
  agent_name = self.name or 'agent'
  run_span = tracer.start_span(
  'agent run',
@@ -711,10 +806,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  return None
  return '\n\n'.join(parts).strip()

- # Copy the function tools so that retry state is agent-run-specific
- # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
- run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
-
  graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
  user_deps=deps,
  prompt=user_prompt,
@@ -727,11 +818,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  output_schema=output_schema,
  output_validators=output_validators,
  history_processors=self.history_processors,
- function_tools=run_function_tools,
- mcp_servers=self._mcp_servers,
- default_retries=self._default_retries,
+ tool_manager=run_toolset,
  tracer=tracer,
- prepare_tools=self._prepare_tools,
  get_instructions=get_instructions,
  instrumentation_settings=instrumentation_settings,
  )
@@ -801,6 +889,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[OutputDataT]: ...

  @overload
@@ -816,6 +905,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[RunOutputDataT]: ...

  @overload
@@ -832,6 +922,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AgentRunResult[RunOutputDataT]: ...

  def run_sync(
@@ -846,6 +937,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AgentRunResult[Any]:
  """Synchronously run the agent with a user prompt.
@@ -875,6 +967,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: Optional limits on model request count or token usage.
  usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+ toolsets: Optional additional toolsets for this run.

  Returns:
  The result of the run.
@@ -901,6 +994,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits=usage_limits,
  usage=usage,
  infer_name=False,
+ toolsets=toolsets,
  )
  )

@@ -916,6 +1010,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...

  @overload
@@ -931,6 +1026,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...

  @overload
@@ -947,6 +1043,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...

  @asynccontextmanager
@@ -962,6 +1059,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: _usage.UsageLimits | None = None,
  usage: _usage.Usage | None = None,
  infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
  **_deprecated_kwargs: Never,
  ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
  """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -989,6 +1087,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits: Optional limits on model request count or token usage.
  usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+ toolsets: Optional additional toolsets for this run.

  Returns:
  The result of the run.
@@ -1019,6 +1118,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  usage_limits=usage_limits,
  usage=usage,
  infer_name=False,
+ toolsets=toolsets,
  ) as agent_run:
  first_node = agent_run.next_node # start with the first node
  assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
@@ -1039,15 +1139,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  output_schema, _output.TextOutputSchema
  ):
  return FinalResult(s, None, None)
- elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
- output_schema, _output.ToolOutputSchema
- ): # pragma: no branch
- for call, _ in output_schema.find_tool([new_part]):
- return FinalResult(s, call.tool_name, call.tool_call_id)
+ elif isinstance(new_part, _messages.ToolCallPart) and (
+ tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
+ ):
+ if tool_def.kind == 'output':
+ return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
+ elif tool_def.kind == 'deferred':
+ return FinalResult(s, None, None)
  return None

- final_result_details = await stream_to_final(streamed_response)
- if final_result_details is not None:
+ final_result = await stream_to_final(streamed_response)
+ if final_result is not None:
  if yielded:
  raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
  yielded = True
@@ -1068,17 +1170,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):

  parts: list[_messages.ModelRequestPart] = []
  async for _event in _agent_graph.process_function_tools(
+ graph_ctx.deps.tool_manager,
  tool_calls,
- final_result_details.tool_name,
- final_result_details.tool_call_id,
+ final_result,
  graph_ctx,
  parts,
  ):
  pass
- # TODO: Should we do something here related to the retry count?
- # Maybe we should move the incrementing of the retry count to where we actually make a request?
- # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
- # ctx.state.increment_retries(ctx.deps.max_result_retries)
  if parts:
  messages.append(_messages.ModelRequest(parts))

@@ -1089,10 +1187,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  streamed_response,
  graph_ctx.deps.output_schema,
  _agent_graph.build_run_context(graph_ctx),
- _output.build_trace_context(graph_ctx),
  graph_ctx.deps.output_validators,
- final_result_details.tool_name,
+ final_result.tool_name,
  on_complete,
+ graph_ctx.deps.tool_manager,
  )
  break
  next_node = await agent_run.next(node)
@@ -1111,8 +1209,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  *,
  deps: AgentDepsT | _utils.Unset = _utils.UNSET,
  model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
  ) -> Iterator[None]:
- """Context manager to temporarily override agent dependencies and model.
+ """Context manager to temporarily override agent dependencies, model, or toolsets.

  This is particularly useful when testing.
  You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -1120,6 +1219,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  Args:
  deps: The dependencies to use instead of the dependencies passed to the agent run.
  model: The model to use instead of the model passed to the agent run.
+ toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
  """
  if _utils.is_set(deps):
  deps_token = self._override_deps.set(_utils.Some(deps))
@@ -1131,6 +1231,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  else:
  model_token = None

+ if _utils.is_set(toolsets):
+ toolsets_token = self._override_toolsets.set(_utils.Some(toolsets))
+ else:
+ toolsets_token = None
+
  try:
  yield
  finally:
@@ -1138,6 +1243,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  self._override_deps.reset(deps_token)
  if model_token is not None:
  self._override_model.reset(model_token)
+ if toolsets_token is not None:
+ self._override_toolsets.reset(toolsets_token)

  @overload
  def instructions(
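`override` gaining a `toolsets` parameter in the hunks above is aimed at swapping real toolsets out in tests, the same way the existing `deps` and `model` overrides work. A sketch under that assumption, using the library's `TestModel` and a stubbed function toolset:

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel
from pydantic_ai.toolsets.function import FunctionToolset

agent = Agent('openai:gpt-4o')


def fake_lookup(query: str) -> str:
    """Deterministic stand-in for a real lookup tool, used only in tests."""
    return 'stubbed result'


def test_agent_with_stubbed_toolset():
    stub_toolset = FunctionToolset([fake_lookup])
    # Both the model and the toolsets are replaced for the duration of the block.
    with agent.override(model=TestModel(), toolsets=[stub_toolset]):
        result = agent.run_sync('look something up')
        assert result.output is not None
```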
@@ -1423,30 +1530,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
  """
- if func is None:
-
- def tool_decorator(
- func_: ToolFuncContext[AgentDepsT, ToolParams],
- ) -> ToolFuncContext[AgentDepsT, ToolParams]:
- # noinspection PyTypeChecker
- self._register_function(
- func_,
- True,
- name,
- retries,
- prepare,
- docstring_format,
- require_parameter_descriptions,
- schema_generator,
- strict,
- )
- return func_

- return tool_decorator
- else:
+ def tool_decorator(
+ func_: ToolFuncContext[AgentDepsT, ToolParams],
+ ) -> ToolFuncContext[AgentDepsT, ToolParams]:
  # noinspection PyTypeChecker
- self._register_function(
- func,
+ self._function_toolset.add_function(
+ func_,
  True,
  name,
  retries,
@@ -1456,7 +1546,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  schema_generator,
  strict,
  )
- return func
+ return func_
+
+ return tool_decorator if func is None else tool_decorator(func)

  @overload
  def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
@@ -1532,27 +1624,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  strict: Whether to enforce JSON schema compliance (only affects OpenAI).
  See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
  """
- if func is None:

- def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
- # noinspection PyTypeChecker
- self._register_function(
- func_,
- False,
- name,
- retries,
- prepare,
- docstring_format,
- require_parameter_descriptions,
- schema_generator,
- strict,
- )
- return func_
-
- return tool_decorator
- else:
- self._register_function(
- func,
+ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
+ # noinspection PyTypeChecker
+ self._function_toolset.add_function(
+ func_,
  False,
  name,
  retries,
@@ -1562,48 +1638,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  schema_generator,
  strict,
  )
- return func
-
- def _register_function(
- self,
- func: ToolFuncEither[AgentDepsT, ToolParams],
- takes_ctx: bool,
- name: str | None,
- retries: int | None,
- prepare: ToolPrepareFunc[AgentDepsT] | None,
- docstring_format: DocstringFormat,
- require_parameter_descriptions: bool,
- schema_generator: type[GenerateJsonSchema],
- strict: bool | None,
- ) -> None:
- """Private utility to register a function as a tool."""
- retries_ = retries if retries is not None else self._default_retries
- tool = Tool[AgentDepsT](
- func,
- takes_ctx=takes_ctx,
- name=name,
- max_retries=retries_,
- prepare=prepare,
- docstring_format=docstring_format,
- require_parameter_descriptions=require_parameter_descriptions,
- schema_generator=schema_generator,
- strict=strict,
- )
- self._register_tool(tool)
-
- def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
- """Private utility to register a tool instance."""
- if tool.max_retries is None:
- # noinspection PyTypeChecker
- tool = dataclasses.replace(tool, max_retries=self._default_retries)
-
- if tool.name in self._function_tools:
- raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
+ return func_

- if tool.name in self._output_schema.tools:
- raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
-
- self._function_tools[tool.name] = tool
+ return tool_decorator if func is None else tool_decorator(func)

  def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
  """Create a model configured for this agent.
@@ -1649,6 +1686,37 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  else:
  return deps

+ def _get_toolset(
+ self,
+ output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
+ additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+ ) -> AbstractToolset[AgentDepsT]:
+ """Get the complete toolset.
+
+ Args:
+ output_toolset: The output toolset to use instead of the one built at agent construction time.
+ additional_toolsets: Additional toolsets to add.
+ """
+ if some_user_toolsets := self._override_toolsets.get():
+ user_toolsets = some_user_toolsets.value
+ elif additional_toolsets is not None:
+ user_toolsets = [*self._user_toolsets, *additional_toolsets]
+ else:
+ user_toolsets = self._user_toolsets
+
+ all_toolsets = [self._function_toolset, *user_toolsets]
+
+ if self._prepare_tools:
+ all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]
+
+ output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
+ if output_toolset is not None:
+ if self._prepare_output_tools:
+ output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
+ all_toolsets = [output_toolset, *all_toolsets]
+
+ return CombinedToolset(all_toolsets)
+
  def _infer_name(self, function_frame: FrameType | None) -> None:
  """Infer the agent name from the call frame.

@@ -1734,28 +1802,167 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  """
  return isinstance(node, End)

+ async def __aenter__(self) -> Self:
+ """Enter the agent context.
+
+ This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
+
+ This is a no-op if the agent has already been entered.
+ """
+ async with self._enter_lock:
+ if self._entered_count == 0:
+ self._exit_stack = AsyncExitStack()
+ toolset = self._get_toolset()
+ await self._exit_stack.enter_async_context(toolset)
+ self._entered_count += 1
+ return self
+
+ async def __aexit__(self, *args: Any) -> bool | None:
+ async with self._enter_lock:
+ self._entered_count -= 1
+ if self._entered_count == 0 and self._exit_stack is not None:
+ await self._exit_stack.aclose()
+ self._exit_stack = None
+
+ def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
+ """Set the sampling model on all MCP servers registered with the agent.
+
+ If no sampling model is provided, the agent's model will be used.
+ """
+ try:
+ sampling_model = models.infer_model(model) if model else self._get_model(None)
+ except exceptions.UserError as e:
+ raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e
+
+ from .mcp import MCPServer
+
+ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
+ if isinstance(toolset, MCPServer):
+ toolset.sampling_model = sampling_model
+
+ self._get_toolset().apply(_set_sampling_model)
+
  @asynccontextmanager
+ @deprecated(
+ '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
+ )
  async def run_mcp_servers(
  self, model: models.Model | models.KnownModelName | str | None = None
  ) -> AsyncIterator[None]:
  """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.

+ Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead.
+ If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model].
+
  Returns: a context manager to start and shutdown the servers.
  """
  try:
- sampling_model: models.Model | None = self._get_model(model)
- except exceptions.UserError: # pragma: no cover
- sampling_model = None
+ self.set_mcp_sampling_model(model)
+ except exceptions.UserError:
+ if model is not None:
+ raise

- exit_stack = AsyncExitStack()
- try:
- for mcp_server in self._mcp_servers:
- if sampling_model is not None: # pragma: no branch
- mcp_server.sampling_model = sampling_model
- await exit_stack.enter_async_context(mcp_server)
+ async with self:
  yield
- finally:
- await exit_stack.aclose()
+
+ def to_ag_ui(
+ self,
+ *,
+ # Agent.iter parameters
+ output_type: OutputSpec[OutputDataT] | None = None,
+ model: models.Model | models.KnownModelName | str | None = None,
+ deps: AgentDepsT = None,
+ model_settings: ModelSettings | None = None,
+ usage_limits: UsageLimits | None = None,
+ usage: Usage | None = None,
+ infer_name: bool = True,
+ toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+ # Starlette
+ debug: bool = False,
+ routes: Sequence[BaseRoute] | None = None,
+ middleware: Sequence[Middleware] | None = None,
+ exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
+ on_startup: Sequence[Callable[[], Any]] | None = None,
+ on_shutdown: Sequence[Callable[[], Any]] | None = None,
+ lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
+ ) -> AGUIApp[AgentDepsT, OutputDataT]:
+ """Convert the agent to an AG-UI application.
+
+ This allows you to use the agent with a compatible AG-UI frontend.
+
+ Example:
+ ```python
+ from pydantic_ai import Agent
+
+ agent = Agent('openai:gpt-4o')
+ app = agent.to_ag_ui()
+ ```
+
+ The `app` is an ASGI application that can be used with any ASGI server.
+
+ To run the application, you can use the following command:
+
+ ```bash
+ uvicorn app:app --host 0.0.0.0 --port 8000
+ ```
+
+ See [AG-UI docs](../ag-ui.md) for more information.
+
+ Args:
+ output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
+ no output validators since output validators would expect an argument that matches the agent's
+ output type.
+ model: Optional model to use for this run, required if `model` was not set when creating the agent.
+ deps: Optional dependencies to use for this run.
+ model_settings: Optional settings to use for this model's request.
+ usage_limits: Optional limits on model request count or token usage.
+ usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+ toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
+
+ debug: Boolean indicating if debug tracebacks should be returned on errors.
+ routes: A list of routes to serve incoming HTTP and WebSocket requests.
+ middleware: A list of middleware to run for every request. A starlette application will always
+ automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
+ outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
+ `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
+ exception cases occurring in the routing or endpoints.
+ exception_handlers: A mapping of either integer status codes, or exception class types onto
+ callables which handle the exceptions. Exception handler callables should be of the form
+ `handler(request, exc) -> response` and may be either standard functions, or async functions.
+ on_startup: A list of callables to run on application startup. Startup handler callables do not
+ take any arguments, and may be either standard functions, or async functions.
+ on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
+ not take any arguments, and may be either standard functions, or async functions.
+ lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
+ This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
+ the other, not both.
+
+ Returns:
+ An ASGI application for running Pydantic AI agents with AG-UI protocol support.
+ """
+ from .ag_ui import AGUIApp
+
+ return AGUIApp(
+ agent=self,
+ # Agent.iter parameters
+ output_type=output_type,
+ model=model,
+ deps=deps,
+ model_settings=model_settings,
+ usage_limits=usage_limits,
+ usage=usage,
+ infer_name=infer_name,
+ toolsets=toolsets,
+ # Starlette
+ debug=debug,
+ routes=routes,
+ middleware=middleware,
+ exception_handlers=exception_handlers,
+ on_startup=on_startup,
+ on_shutdown=on_shutdown,
+ lifespan=lifespan,
+ )

  def to_a2a(
  self,
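The hunk above folds `run_mcp_servers` into the new agent context manager: entering the agent starts any MCP server toolsets, and sampling-model wiring moves to `set_mcp_sampling_model`. A migration sketch; the server command and prompt are illustrative only:

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])  # hypothetical command
agent = Agent('openai:gpt-4o', toolsets=[server])


async def main():
    # 0.4.3 style (now deprecated):
    #   async with agent.run_mcp_servers():
    #       ...
    # 0.4.4 style: entering the agent starts MCP servers registered as toolsets.
    async with agent:
        agent.set_mcp_sampling_model()  # optional; defaults to the agent's own model
        result = await agent.run('What tools do you have available?')
        print(result.output)


asyncio.run(main())
```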
@@ -2112,12 +2319,18 @@ class AgentRunResult(Generic[OutputDataT]):
  """
  if not self._output_tool_name:
  raise ValueError('Cannot set output tool return content when the return type is `str`.')
- messages = deepcopy(self._state.message_history)
+
+ messages = self._state.message_history
  last_message = messages[-1]
- for part in last_message.parts:
+ for idx, part in enumerate(last_message.parts):
  if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
- part.content = return_content
- return messages
+ # Only do deepcopy when we have to modify
+ copied_messages = list(messages)
+ copied_last = deepcopy(last_message)
+ copied_last.parts[idx].content = return_content # type: ignore[misc]
+ copied_messages[-1] = copied_last
+ return copied_messages
+
  raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')

  @overload