pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +3 -3
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +378 -164
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/format_prompt.py +3 -6
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -18
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/instrumented.py +6 -1
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +16 -4
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/RECORD +48 -35
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.5.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -4,7 +4,8 @@ import dataclasses
|
|
|
4
4
|
import inspect
|
|
5
5
|
import json
|
|
6
6
|
import warnings
|
|
7
|
-
from
|
|
7
|
+
from asyncio import Lock
|
|
8
|
+
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
|
|
8
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
|
|
9
10
|
from contextvars import ContextVar
|
|
10
11
|
from copy import deepcopy
|
|
@@ -15,7 +16,6 @@ from opentelemetry.trace import NoOpTracer, use_span
|
|
|
15
16
|
from pydantic.json_schema import GenerateJsonSchema
|
|
16
17
|
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
|
|
17
18
|
|
|
18
|
-
from pydantic_ai.profiles import ModelProfile
|
|
19
19
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
20
20
|
from pydantic_graph._utils import get_event_loop
|
|
21
21
|
|
|
@@ -31,8 +31,11 @@ from . import (
|
|
|
31
31
|
usage as _usage,
|
|
32
32
|
)
|
|
33
33
|
from ._agent_graph import HistoryProcessor
|
|
34
|
+
from ._output import OutputToolset
|
|
35
|
+
from ._tool_manager import ToolManager
|
|
34
36
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
35
37
|
from .output import OutputDataT, OutputSpec
|
|
38
|
+
from .profiles import ModelProfile
|
|
36
39
|
from .result import FinalResult, StreamedRunResult
|
|
37
40
|
from .settings import ModelSettings, merge_model_settings
|
|
38
41
|
from .tools import (
|
|
@@ -48,6 +51,11 @@ from .tools import (
|
|
|
48
51
|
ToolPrepareFunc,
|
|
49
52
|
ToolsPrepareFunc,
|
|
50
53
|
)
|
|
54
|
+
from .toolsets import AbstractToolset
|
|
55
|
+
from .toolsets.combined import CombinedToolset
|
|
56
|
+
from .toolsets.function import FunctionToolset
|
|
57
|
+
from .toolsets.prepared import PreparedToolset
|
|
58
|
+
from .usage import Usage, UsageLimits
|
|
51
59
|
|
|
52
60
|
# Re-exporting like this improves auto-import behavior in PyCharm
|
|
53
61
|
capture_run_messages = _agent_graph.capture_run_messages
|
|
@@ -62,11 +70,12 @@ if TYPE_CHECKING:
|
|
|
62
70
|
from fasta2a.schema import AgentProvider, Skill
|
|
63
71
|
from fasta2a.storage import Storage
|
|
64
72
|
from starlette.middleware import Middleware
|
|
65
|
-
from starlette.routing import Route
|
|
73
|
+
from starlette.routing import BaseRoute, Route
|
|
66
74
|
from starlette.types import ExceptionHandler, Lifespan
|
|
67
75
|
|
|
68
76
|
from pydantic_ai.mcp import MCPServer
|
|
69
77
|
|
|
78
|
+
from .ag_ui import AGUIApp
|
|
70
79
|
|
|
71
80
|
__all__ = (
|
|
72
81
|
'Agent',
|
|
@@ -153,12 +162,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
153
162
|
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
154
163
|
repr=False
|
|
155
164
|
)
|
|
165
|
+
_function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
|
|
166
|
+
_output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
167
|
+
_user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
|
|
156
168
|
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
157
|
-
|
|
158
|
-
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
159
|
-
_default_retries: int = dataclasses.field(repr=False)
|
|
169
|
+
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
160
170
|
_max_result_retries: int = dataclasses.field(repr=False)
|
|
161
171
|
|
|
172
|
+
_enter_lock: Lock = dataclasses.field(repr=False)
|
|
173
|
+
_entered_count: int = dataclasses.field(repr=False)
|
|
174
|
+
_exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
|
|
175
|
+
|
|
162
176
|
@overload
|
|
163
177
|
def __init__(
|
|
164
178
|
self,
|
|
@@ -177,7 +191,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
177
191
|
output_retries: int | None = None,
|
|
178
192
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
179
193
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
180
|
-
|
|
194
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
195
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
181
196
|
defer_model_check: bool = False,
|
|
182
197
|
end_strategy: EndStrategy = 'early',
|
|
183
198
|
instrument: InstrumentationSettings | bool | None = None,
|
|
@@ -186,7 +201,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
186
201
|
|
|
187
202
|
@overload
|
|
188
203
|
@deprecated(
|
|
189
|
-
'`result_type`, `result_tool_name
|
|
204
|
+
'`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
|
|
190
205
|
)
|
|
191
206
|
def __init__(
|
|
192
207
|
self,
|
|
@@ -207,6 +222,36 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
207
222
|
result_retries: int | None = None,
|
|
208
223
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
209
224
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
225
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
226
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
227
|
+
defer_model_check: bool = False,
|
|
228
|
+
end_strategy: EndStrategy = 'early',
|
|
229
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
230
|
+
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
|
|
231
|
+
) -> None: ...
|
|
232
|
+
|
|
233
|
+
@overload
|
|
234
|
+
@deprecated('`mcp_servers` is deprecated, use `toolsets` instead.')
|
|
235
|
+
def __init__(
|
|
236
|
+
self,
|
|
237
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
238
|
+
*,
|
|
239
|
+
result_type: type[OutputDataT] = str,
|
|
240
|
+
instructions: str
|
|
241
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
242
|
+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
243
|
+
| None = None,
|
|
244
|
+
system_prompt: str | Sequence[str] = (),
|
|
245
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
246
|
+
name: str | None = None,
|
|
247
|
+
model_settings: ModelSettings | None = None,
|
|
248
|
+
retries: int = 1,
|
|
249
|
+
result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
|
|
250
|
+
result_tool_description: str | None = None,
|
|
251
|
+
result_retries: int | None = None,
|
|
252
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
253
|
+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
254
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
210
255
|
mcp_servers: Sequence[MCPServer] = (),
|
|
211
256
|
defer_model_check: bool = False,
|
|
212
257
|
end_strategy: EndStrategy = 'early',
|
|
@@ -232,7 +277,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
232
277
|
output_retries: int | None = None,
|
|
233
278
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
234
279
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
235
|
-
|
|
280
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
281
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
236
282
|
defer_model_check: bool = False,
|
|
237
283
|
end_strategy: EndStrategy = 'early',
|
|
238
284
|
instrument: InstrumentationSettings | bool | None = None,
|
|
@@ -258,14 +304,16 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
258
304
|
when the agent is first run.
|
|
259
305
|
model_settings: Optional model request settings to use for this agent's runs, by default.
|
|
260
306
|
retries: The default number of retries to allow before raising an error.
|
|
261
|
-
output_retries: The maximum number of retries to allow for
|
|
307
|
+
output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
|
|
262
308
|
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
263
309
|
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
264
|
-
prepare_tools:
|
|
310
|
+
prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools.
|
|
265
311
|
This is useful if you want to customize the definition of multiple tools or you want to register
|
|
266
312
|
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
|
|
267
|
-
|
|
268
|
-
|
|
313
|
+
prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
|
|
314
|
+
This is useful if you want to customize the definition of multiple output tools or you want to register
|
|
315
|
+
a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
|
|
316
|
+
toolsets: Toolsets to register with the agent, including MCP servers.
|
|
269
317
|
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
|
|
270
318
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
271
319
|
which checks for the necessary environment variables. Set this to `false`
|
|
@@ -329,10 +377,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
329
377
|
)
|
|
330
378
|
output_retries = result_retries
|
|
331
379
|
|
|
380
|
+
if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None):
|
|
381
|
+
if toolsets is not None: # pragma: no cover
|
|
382
|
+
raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.')
|
|
383
|
+
warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning)
|
|
384
|
+
toolsets = mcp_servers
|
|
385
|
+
|
|
386
|
+
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
387
|
+
|
|
332
388
|
default_output_mode = (
|
|
333
389
|
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
|
|
334
390
|
)
|
|
335
|
-
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
336
391
|
|
|
337
392
|
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
338
393
|
output_type,
|
|
@@ -357,21 +412,28 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
357
412
|
self._system_prompt_functions = []
|
|
358
413
|
self._system_prompt_dynamic_functions = {}
|
|
359
414
|
|
|
360
|
-
self._function_tools = {}
|
|
361
|
-
|
|
362
|
-
self._default_retries = retries
|
|
363
415
|
self._max_result_retries = output_retries if output_retries is not None else retries
|
|
364
|
-
self._mcp_servers = mcp_servers
|
|
365
416
|
self._prepare_tools = prepare_tools
|
|
417
|
+
self._prepare_output_tools = prepare_output_tools
|
|
418
|
+
|
|
419
|
+
self._output_toolset = self._output_schema.toolset
|
|
420
|
+
if self._output_toolset:
|
|
421
|
+
self._output_toolset.max_retries = self._max_result_retries
|
|
422
|
+
|
|
423
|
+
self._function_toolset = FunctionToolset(tools, max_retries=retries)
|
|
424
|
+
self._user_toolsets = toolsets or ()
|
|
425
|
+
|
|
366
426
|
self.history_processors = history_processors or []
|
|
367
|
-
for tool in tools:
|
|
368
|
-
if isinstance(tool, Tool):
|
|
369
|
-
self._register_tool(tool)
|
|
370
|
-
else:
|
|
371
|
-
self._register_tool(Tool(tool))
|
|
372
427
|
|
|
373
428
|
self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
|
|
374
429
|
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
|
|
430
|
+
self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar(
|
|
431
|
+
'_override_toolsets', default=None
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
self._enter_lock = _utils.get_async_lock()
|
|
435
|
+
self._entered_count = 0
|
|
436
|
+
self._exit_stack = None
|
|
375
437
|
|
|
376
438
|
@staticmethod
|
|
377
439
|
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
|
|
@@ -391,6 +453,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
391
453
|
usage_limits: _usage.UsageLimits | None = None,
|
|
392
454
|
usage: _usage.Usage | None = None,
|
|
393
455
|
infer_name: bool = True,
|
|
456
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
394
457
|
) -> AgentRunResult[OutputDataT]: ...
|
|
395
458
|
|
|
396
459
|
@overload
|
|
@@ -406,6 +469,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
406
469
|
usage_limits: _usage.UsageLimits | None = None,
|
|
407
470
|
usage: _usage.Usage | None = None,
|
|
408
471
|
infer_name: bool = True,
|
|
472
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
409
473
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
410
474
|
|
|
411
475
|
@overload
|
|
@@ -422,6 +486,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
422
486
|
usage_limits: _usage.UsageLimits | None = None,
|
|
423
487
|
usage: _usage.Usage | None = None,
|
|
424
488
|
infer_name: bool = True,
|
|
489
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
425
490
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
426
491
|
|
|
427
492
|
async def run(
|
|
@@ -436,6 +501,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
436
501
|
usage_limits: _usage.UsageLimits | None = None,
|
|
437
502
|
usage: _usage.Usage | None = None,
|
|
438
503
|
infer_name: bool = True,
|
|
504
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
439
505
|
**_deprecated_kwargs: Never,
|
|
440
506
|
) -> AgentRunResult[Any]:
|
|
441
507
|
"""Run the agent with a user prompt in async mode.
|
|
@@ -466,6 +532,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
466
532
|
usage_limits: Optional limits on model request count or token usage.
|
|
467
533
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
468
534
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
535
|
+
toolsets: Optional additional toolsets for this run.
|
|
469
536
|
|
|
470
537
|
Returns:
|
|
471
538
|
The result of the run.
|
|
@@ -490,6 +557,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
490
557
|
model_settings=model_settings,
|
|
491
558
|
usage_limits=usage_limits,
|
|
492
559
|
usage=usage,
|
|
560
|
+
toolsets=toolsets,
|
|
493
561
|
) as agent_run:
|
|
494
562
|
async for _ in agent_run:
|
|
495
563
|
pass
|
|
@@ -510,6 +578,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
510
578
|
usage_limits: _usage.UsageLimits | None = None,
|
|
511
579
|
usage: _usage.Usage | None = None,
|
|
512
580
|
infer_name: bool = True,
|
|
581
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
513
582
|
**_deprecated_kwargs: Never,
|
|
514
583
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...
|
|
515
584
|
|
|
@@ -526,6 +595,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
526
595
|
usage_limits: _usage.UsageLimits | None = None,
|
|
527
596
|
usage: _usage.Usage | None = None,
|
|
528
597
|
infer_name: bool = True,
|
|
598
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
529
599
|
**_deprecated_kwargs: Never,
|
|
530
600
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
|
|
531
601
|
|
|
@@ -543,6 +613,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
543
613
|
usage_limits: _usage.UsageLimits | None = None,
|
|
544
614
|
usage: _usage.Usage | None = None,
|
|
545
615
|
infer_name: bool = True,
|
|
616
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
546
617
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
547
618
|
|
|
548
619
|
@asynccontextmanager
|
|
@@ -558,6 +629,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
558
629
|
usage_limits: _usage.UsageLimits | None = None,
|
|
559
630
|
usage: _usage.Usage | None = None,
|
|
560
631
|
infer_name: bool = True,
|
|
632
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
561
633
|
**_deprecated_kwargs: Never,
|
|
562
634
|
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
|
|
563
635
|
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
|
|
@@ -632,6 +704,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
632
704
|
usage_limits: Optional limits on model request count or token usage.
|
|
633
705
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
634
706
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
707
|
+
toolsets: Optional additional toolsets for this run.
|
|
635
708
|
|
|
636
709
|
Returns:
|
|
637
710
|
The result of the run.
|
|
@@ -655,6 +728,18 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
655
728
|
|
|
656
729
|
output_type_ = output_type or self.output_type
|
|
657
730
|
|
|
731
|
+
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
732
|
+
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
733
|
+
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
734
|
+
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
735
|
+
|
|
736
|
+
output_toolset = self._output_toolset
|
|
737
|
+
if output_schema != self._output_schema or output_validators:
|
|
738
|
+
output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset)
|
|
739
|
+
if output_toolset:
|
|
740
|
+
output_toolset.max_retries = self._max_result_retries
|
|
741
|
+
output_toolset.output_validators = output_validators
|
|
742
|
+
|
|
658
743
|
# Build the graph
|
|
659
744
|
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
|
|
660
745
|
_agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
|
|
@@ -669,22 +754,32 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
669
754
|
run_step=0,
|
|
670
755
|
)
|
|
671
756
|
|
|
672
|
-
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
673
|
-
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
674
|
-
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
675
|
-
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
676
|
-
|
|
677
|
-
# Merge model settings in order of precedence: run > agent > model
|
|
678
|
-
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
|
|
679
|
-
model_settings = merge_model_settings(merged_settings, model_settings)
|
|
680
|
-
usage_limits = usage_limits or _usage.UsageLimits()
|
|
681
|
-
|
|
682
757
|
if isinstance(model_used, InstrumentedModel):
|
|
683
758
|
instrumentation_settings = model_used.instrumentation_settings
|
|
684
759
|
tracer = model_used.instrumentation_settings.tracer
|
|
685
760
|
else:
|
|
686
761
|
instrumentation_settings = None
|
|
687
762
|
tracer = NoOpTracer()
|
|
763
|
+
|
|
764
|
+
run_context = RunContext[AgentDepsT](
|
|
765
|
+
deps=deps,
|
|
766
|
+
model=model_used,
|
|
767
|
+
usage=usage,
|
|
768
|
+
prompt=user_prompt,
|
|
769
|
+
messages=state.message_history,
|
|
770
|
+
tracer=tracer,
|
|
771
|
+
trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
|
|
772
|
+
run_step=state.run_step,
|
|
773
|
+
)
|
|
774
|
+
|
|
775
|
+
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
|
|
776
|
+
# This will raise errors for any name conflicts
|
|
777
|
+
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
|
|
778
|
+
|
|
779
|
+
# Merge model settings in order of precedence: run > agent > model
|
|
780
|
+
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
|
|
781
|
+
model_settings = merge_model_settings(merged_settings, model_settings)
|
|
782
|
+
usage_limits = usage_limits or _usage.UsageLimits()
|
|
688
783
|
agent_name = self.name or 'agent'
|
|
689
784
|
run_span = tracer.start_span(
|
|
690
785
|
'agent run',
|
|
@@ -711,10 +806,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
711
806
|
return None
|
|
712
807
|
return '\n\n'.join(parts).strip()
|
|
713
808
|
|
|
714
|
-
# Copy the function tools so that retry state is agent-run-specific
|
|
715
|
-
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
|
|
716
|
-
run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
|
|
717
|
-
|
|
718
809
|
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
|
|
719
810
|
user_deps=deps,
|
|
720
811
|
prompt=user_prompt,
|
|
@@ -727,11 +818,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
727
818
|
output_schema=output_schema,
|
|
728
819
|
output_validators=output_validators,
|
|
729
820
|
history_processors=self.history_processors,
|
|
730
|
-
|
|
731
|
-
mcp_servers=self._mcp_servers,
|
|
732
|
-
default_retries=self._default_retries,
|
|
821
|
+
tool_manager=run_toolset,
|
|
733
822
|
tracer=tracer,
|
|
734
|
-
prepare_tools=self._prepare_tools,
|
|
735
823
|
get_instructions=get_instructions,
|
|
736
824
|
instrumentation_settings=instrumentation_settings,
|
|
737
825
|
)
|
|
@@ -755,14 +843,15 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
755
843
|
agent_run = AgentRun(graph_run)
|
|
756
844
|
yield agent_run
|
|
757
845
|
if (final_result := agent_run.result) is not None and run_span.is_recording():
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
846
|
+
if instrumentation_settings and instrumentation_settings.include_content:
|
|
847
|
+
run_span.set_attribute(
|
|
848
|
+
'final_result',
|
|
849
|
+
(
|
|
850
|
+
final_result.output
|
|
851
|
+
if isinstance(final_result.output, str)
|
|
852
|
+
else json.dumps(InstrumentedModel.serialize_any(final_result.output))
|
|
853
|
+
),
|
|
854
|
+
)
|
|
766
855
|
finally:
|
|
767
856
|
try:
|
|
768
857
|
if instrumentation_settings and run_span.is_recording():
|
|
@@ -801,6 +890,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
801
890
|
usage_limits: _usage.UsageLimits | None = None,
|
|
802
891
|
usage: _usage.Usage | None = None,
|
|
803
892
|
infer_name: bool = True,
|
|
893
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
804
894
|
) -> AgentRunResult[OutputDataT]: ...
|
|
805
895
|
|
|
806
896
|
@overload
|
|
@@ -816,6 +906,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
816
906
|
usage_limits: _usage.UsageLimits | None = None,
|
|
817
907
|
usage: _usage.Usage | None = None,
|
|
818
908
|
infer_name: bool = True,
|
|
909
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
819
910
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
820
911
|
|
|
821
912
|
@overload
|
|
@@ -832,6 +923,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
832
923
|
usage_limits: _usage.UsageLimits | None = None,
|
|
833
924
|
usage: _usage.Usage | None = None,
|
|
834
925
|
infer_name: bool = True,
|
|
926
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
835
927
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
836
928
|
|
|
837
929
|
def run_sync(
|
|
@@ -846,6 +938,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
846
938
|
usage_limits: _usage.UsageLimits | None = None,
|
|
847
939
|
usage: _usage.Usage | None = None,
|
|
848
940
|
infer_name: bool = True,
|
|
941
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
849
942
|
**_deprecated_kwargs: Never,
|
|
850
943
|
) -> AgentRunResult[Any]:
|
|
851
944
|
"""Synchronously run the agent with a user prompt.
|
|
@@ -875,6 +968,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
875
968
|
usage_limits: Optional limits on model request count or token usage.
|
|
876
969
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
877
970
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
971
|
+
toolsets: Optional additional toolsets for this run.
|
|
878
972
|
|
|
879
973
|
Returns:
|
|
880
974
|
The result of the run.
|
|
@@ -901,6 +995,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
901
995
|
usage_limits=usage_limits,
|
|
902
996
|
usage=usage,
|
|
903
997
|
infer_name=False,
|
|
998
|
+
toolsets=toolsets,
|
|
904
999
|
)
|
|
905
1000
|
)
|
|
906
1001
|
|
|
@@ -916,6 +1011,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
916
1011
|
usage_limits: _usage.UsageLimits | None = None,
|
|
917
1012
|
usage: _usage.Usage | None = None,
|
|
918
1013
|
infer_name: bool = True,
|
|
1014
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
919
1015
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
|
|
920
1016
|
|
|
921
1017
|
@overload
|
|
@@ -931,6 +1027,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
931
1027
|
usage_limits: _usage.UsageLimits | None = None,
|
|
932
1028
|
usage: _usage.Usage | None = None,
|
|
933
1029
|
infer_name: bool = True,
|
|
1030
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
934
1031
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
935
1032
|
|
|
936
1033
|
@overload
|
|
@@ -947,6 +1044,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
947
1044
|
usage_limits: _usage.UsageLimits | None = None,
|
|
948
1045
|
usage: _usage.Usage | None = None,
|
|
949
1046
|
infer_name: bool = True,
|
|
1047
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
950
1048
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
951
1049
|
|
|
952
1050
|
@asynccontextmanager
|
|
@@ -962,6 +1060,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
962
1060
|
usage_limits: _usage.UsageLimits | None = None,
|
|
963
1061
|
usage: _usage.Usage | None = None,
|
|
964
1062
|
infer_name: bool = True,
|
|
1063
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
965
1064
|
**_deprecated_kwargs: Never,
|
|
966
1065
|
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
967
1066
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
@@ -989,6 +1088,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
989
1088
|
usage_limits: Optional limits on model request count or token usage.
|
|
990
1089
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
991
1090
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
1091
|
+
toolsets: Optional additional toolsets for this run.
|
|
992
1092
|
|
|
993
1093
|
Returns:
|
|
994
1094
|
The result of the run.
|
|
@@ -1019,6 +1119,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1019
1119
|
usage_limits=usage_limits,
|
|
1020
1120
|
usage=usage,
|
|
1021
1121
|
infer_name=False,
|
|
1122
|
+
toolsets=toolsets,
|
|
1022
1123
|
) as agent_run:
|
|
1023
1124
|
first_node = agent_run.next_node # start with the first node
|
|
1024
1125
|
assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
|
|
@@ -1039,15 +1140,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1039
1140
|
output_schema, _output.TextOutputSchema
|
|
1040
1141
|
):
|
|
1041
1142
|
return FinalResult(s, None, None)
|
|
1042
|
-
elif isinstance(new_part, _messages.ToolCallPart) and
|
|
1043
|
-
|
|
1044
|
-
):
|
|
1045
|
-
|
|
1046
|
-
return FinalResult(s,
|
|
1143
|
+
elif isinstance(new_part, _messages.ToolCallPart) and (
|
|
1144
|
+
tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
|
|
1145
|
+
):
|
|
1146
|
+
if tool_def.kind == 'output':
|
|
1147
|
+
return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
|
|
1148
|
+
elif tool_def.kind == 'deferred':
|
|
1149
|
+
return FinalResult(s, None, None)
|
|
1047
1150
|
return None
|
|
1048
1151
|
|
|
1049
|
-
|
|
1050
|
-
if
|
|
1152
|
+
final_result = await stream_to_final(streamed_response)
|
|
1153
|
+
if final_result is not None:
|
|
1051
1154
|
if yielded:
|
|
1052
1155
|
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
|
|
1053
1156
|
yielded = True
|
|
@@ -1068,17 +1171,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1068
1171
|
|
|
1069
1172
|
parts: list[_messages.ModelRequestPart] = []
|
|
1070
1173
|
async for _event in _agent_graph.process_function_tools(
|
|
1174
|
+
graph_ctx.deps.tool_manager,
|
|
1071
1175
|
tool_calls,
|
|
1072
|
-
|
|
1073
|
-
final_result_details.tool_call_id,
|
|
1176
|
+
final_result,
|
|
1074
1177
|
graph_ctx,
|
|
1075
1178
|
parts,
|
|
1076
1179
|
):
|
|
1077
1180
|
pass
|
|
1078
|
-
# TODO: Should we do something here related to the retry count?
|
|
1079
|
-
# Maybe we should move the incrementing of the retry count to where we actually make a request?
|
|
1080
|
-
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1081
|
-
# ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
1082
1181
|
if parts:
|
|
1083
1182
|
messages.append(_messages.ModelRequest(parts))
|
|
1084
1183
|
|
|
@@ -1089,10 +1188,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1089
1188
|
streamed_response,
|
|
1090
1189
|
graph_ctx.deps.output_schema,
|
|
1091
1190
|
_agent_graph.build_run_context(graph_ctx),
|
|
1092
|
-
_output.build_trace_context(graph_ctx),
|
|
1093
1191
|
graph_ctx.deps.output_validators,
|
|
1094
|
-
|
|
1192
|
+
final_result.tool_name,
|
|
1095
1193
|
on_complete,
|
|
1194
|
+
graph_ctx.deps.tool_manager,
|
|
1096
1195
|
)
|
|
1097
1196
|
break
|
|
1098
1197
|
next_node = await agent_run.next(node)
|
|
@@ -1111,8 +1210,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1111
1210
|
*,
|
|
1112
1211
|
deps: AgentDepsT | _utils.Unset = _utils.UNSET,
|
|
1113
1212
|
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
|
|
1213
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
|
|
1114
1214
|
) -> Iterator[None]:
|
|
1115
|
-
"""Context manager to temporarily override agent dependencies
|
|
1215
|
+
"""Context manager to temporarily override agent dependencies, model, or toolsets.
|
|
1116
1216
|
|
|
1117
1217
|
This is particularly useful when testing.
|
|
1118
1218
|
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
|
|
@@ -1120,6 +1220,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1120
1220
|
Args:
|
|
1121
1221
|
deps: The dependencies to use instead of the dependencies passed to the agent run.
|
|
1122
1222
|
model: The model to use instead of the model passed to the agent run.
|
|
1223
|
+
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
|
|
1123
1224
|
"""
|
|
1124
1225
|
if _utils.is_set(deps):
|
|
1125
1226
|
deps_token = self._override_deps.set(_utils.Some(deps))
|
|
@@ -1131,6 +1232,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1131
1232
|
else:
|
|
1132
1233
|
model_token = None
|
|
1133
1234
|
|
|
1235
|
+
if _utils.is_set(toolsets):
|
|
1236
|
+
toolsets_token = self._override_toolsets.set(_utils.Some(toolsets))
|
|
1237
|
+
else:
|
|
1238
|
+
toolsets_token = None
|
|
1239
|
+
|
|
1134
1240
|
try:
|
|
1135
1241
|
yield
|
|
1136
1242
|
finally:
|
|
@@ -1138,6 +1244,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1138
1244
|
self._override_deps.reset(deps_token)
|
|
1139
1245
|
if model_token is not None:
|
|
1140
1246
|
self._override_model.reset(model_token)
|
|
1247
|
+
if toolsets_token is not None:
|
|
1248
|
+
self._override_toolsets.reset(toolsets_token)
|
|
1141
1249
|
|
|
1142
1250
|
@overload
|
|
1143
1251
|
def instructions(
|
|
@@ -1423,30 +1531,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1423
1531
|
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1424
1532
|
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1425
1533
|
"""
|
|
1426
|
-
if func is None:
|
|
1427
|
-
|
|
1428
|
-
def tool_decorator(
|
|
1429
|
-
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
1430
|
-
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
1431
|
-
# noinspection PyTypeChecker
|
|
1432
|
-
self._register_function(
|
|
1433
|
-
func_,
|
|
1434
|
-
True,
|
|
1435
|
-
name,
|
|
1436
|
-
retries,
|
|
1437
|
-
prepare,
|
|
1438
|
-
docstring_format,
|
|
1439
|
-
require_parameter_descriptions,
|
|
1440
|
-
schema_generator,
|
|
1441
|
-
strict,
|
|
1442
|
-
)
|
|
1443
|
-
return func_
|
|
1444
1534
|
|
|
1445
|
-
|
|
1446
|
-
|
|
1535
|
+
def tool_decorator(
|
|
1536
|
+
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
1537
|
+
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
1447
1538
|
# noinspection PyTypeChecker
|
|
1448
|
-
self.
|
|
1449
|
-
|
|
1539
|
+
self._function_toolset.add_function(
|
|
1540
|
+
func_,
|
|
1450
1541
|
True,
|
|
1451
1542
|
name,
|
|
1452
1543
|
retries,
|
|
@@ -1456,7 +1547,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1456
1547
|
schema_generator,
|
|
1457
1548
|
strict,
|
|
1458
1549
|
)
|
|
1459
|
-
return
|
|
1550
|
+
return func_
|
|
1551
|
+
|
|
1552
|
+
return tool_decorator if func is None else tool_decorator(func)
|
|
1460
1553
|
|
|
1461
1554
|
@overload
|
|
1462
1555
|
def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
|
|
@@ -1532,27 +1625,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1532
1625
|
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1533
1626
|
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1534
1627
|
"""
|
|
1535
|
-
if func is None:
|
|
1536
1628
|
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
False,
|
|
1542
|
-
name,
|
|
1543
|
-
retries,
|
|
1544
|
-
prepare,
|
|
1545
|
-
docstring_format,
|
|
1546
|
-
require_parameter_descriptions,
|
|
1547
|
-
schema_generator,
|
|
1548
|
-
strict,
|
|
1549
|
-
)
|
|
1550
|
-
return func_
|
|
1551
|
-
|
|
1552
|
-
return tool_decorator
|
|
1553
|
-
else:
|
|
1554
|
-
self._register_function(
|
|
1555
|
-
func,
|
|
1629
|
+
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
1630
|
+
# noinspection PyTypeChecker
|
|
1631
|
+
self._function_toolset.add_function(
|
|
1632
|
+
func_,
|
|
1556
1633
|
False,
|
|
1557
1634
|
name,
|
|
1558
1635
|
retries,
|
|
@@ -1562,48 +1639,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1562
1639
|
schema_generator,
|
|
1563
1640
|
strict,
|
|
1564
1641
|
)
|
|
1565
|
-
return
|
|
1566
|
-
|
|
1567
|
-
def _register_function(
|
|
1568
|
-
self,
|
|
1569
|
-
func: ToolFuncEither[AgentDepsT, ToolParams],
|
|
1570
|
-
takes_ctx: bool,
|
|
1571
|
-
name: str | None,
|
|
1572
|
-
retries: int | None,
|
|
1573
|
-
prepare: ToolPrepareFunc[AgentDepsT] | None,
|
|
1574
|
-
docstring_format: DocstringFormat,
|
|
1575
|
-
require_parameter_descriptions: bool,
|
|
1576
|
-
schema_generator: type[GenerateJsonSchema],
|
|
1577
|
-
strict: bool | None,
|
|
1578
|
-
) -> None:
|
|
1579
|
-
"""Private utility to register a function as a tool."""
|
|
1580
|
-
retries_ = retries if retries is not None else self._default_retries
|
|
1581
|
-
tool = Tool[AgentDepsT](
|
|
1582
|
-
func,
|
|
1583
|
-
takes_ctx=takes_ctx,
|
|
1584
|
-
name=name,
|
|
1585
|
-
max_retries=retries_,
|
|
1586
|
-
prepare=prepare,
|
|
1587
|
-
docstring_format=docstring_format,
|
|
1588
|
-
require_parameter_descriptions=require_parameter_descriptions,
|
|
1589
|
-
schema_generator=schema_generator,
|
|
1590
|
-
strict=strict,
|
|
1591
|
-
)
|
|
1592
|
-
self._register_tool(tool)
|
|
1593
|
-
|
|
1594
|
-
def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
|
|
1595
|
-
"""Private utility to register a tool instance."""
|
|
1596
|
-
if tool.max_retries is None:
|
|
1597
|
-
# noinspection PyTypeChecker
|
|
1598
|
-
tool = dataclasses.replace(tool, max_retries=self._default_retries)
|
|
1599
|
-
|
|
1600
|
-
if tool.name in self._function_tools:
|
|
1601
|
-
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1602
|
-
|
|
1603
|
-
if tool.name in self._output_schema.tools:
|
|
1604
|
-
raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
|
|
1642
|
+
return func_
|
|
1605
1643
|
|
|
1606
|
-
|
|
1644
|
+
return tool_decorator if func is None else tool_decorator(func)
|
|
1607
1645
|
|
|
1608
1646
|
def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
|
|
1609
1647
|
"""Create a model configured for this agent.
|
|
@@ -1649,6 +1687,37 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1649
1687
|
else:
|
|
1650
1688
|
return deps
|
|
1651
1689
|
|
|
1690
|
+
def _get_toolset(
|
|
1691
|
+
self,
|
|
1692
|
+
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
|
|
1693
|
+
additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
1694
|
+
) -> AbstractToolset[AgentDepsT]:
|
|
1695
|
+
"""Get the complete toolset.
|
|
1696
|
+
|
|
1697
|
+
Args:
|
|
1698
|
+
output_toolset: The output toolset to use instead of the one built at agent construction time.
|
|
1699
|
+
additional_toolsets: Additional toolsets to add.
|
|
1700
|
+
"""
|
|
1701
|
+
if some_user_toolsets := self._override_toolsets.get():
|
|
1702
|
+
user_toolsets = some_user_toolsets.value
|
|
1703
|
+
elif additional_toolsets is not None:
|
|
1704
|
+
user_toolsets = [*self._user_toolsets, *additional_toolsets]
|
|
1705
|
+
else:
|
|
1706
|
+
user_toolsets = self._user_toolsets
|
|
1707
|
+
|
|
1708
|
+
all_toolsets = [self._function_toolset, *user_toolsets]
|
|
1709
|
+
|
|
1710
|
+
if self._prepare_tools:
|
|
1711
|
+
all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]
|
|
1712
|
+
|
|
1713
|
+
output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
|
|
1714
|
+
if output_toolset is not None:
|
|
1715
|
+
if self._prepare_output_tools:
|
|
1716
|
+
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
|
|
1717
|
+
all_toolsets = [output_toolset, *all_toolsets]
|
|
1718
|
+
|
|
1719
|
+
return CombinedToolset(all_toolsets)
|
|
1720
|
+
|
|
1652
1721
|
def _infer_name(self, function_frame: FrameType | None) -> None:
|
|
1653
1722
|
"""Infer the agent name from the call frame.
|
|
1654
1723
|
|
|
@@ -1734,28 +1803,167 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1734
1803
|
"""
|
|
1735
1804
|
return isinstance(node, End)
|
|
1736
1805
|
|
|
1806
|
+
async def __aenter__(self) -> Self:
|
|
1807
|
+
"""Enter the agent context.
|
|
1808
|
+
|
|
1809
|
+
This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
|
|
1810
|
+
|
|
1811
|
+
This is a no-op if the agent has already been entered.
|
|
1812
|
+
"""
|
|
1813
|
+
async with self._enter_lock:
|
|
1814
|
+
if self._entered_count == 0:
|
|
1815
|
+
self._exit_stack = AsyncExitStack()
|
|
1816
|
+
toolset = self._get_toolset()
|
|
1817
|
+
await self._exit_stack.enter_async_context(toolset)
|
|
1818
|
+
self._entered_count += 1
|
|
1819
|
+
return self
|
|
1820
|
+
|
|
1821
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
1822
|
+
async with self._enter_lock:
|
|
1823
|
+
self._entered_count -= 1
|
|
1824
|
+
if self._entered_count == 0 and self._exit_stack is not None:
|
|
1825
|
+
await self._exit_stack.aclose()
|
|
1826
|
+
self._exit_stack = None
|
|
1827
|
+
|
|
1828
|
+
def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
|
|
1829
|
+
"""Set the sampling model on all MCP servers registered with the agent.
|
|
1830
|
+
|
|
1831
|
+
If no sampling model is provided, the agent's model will be used.
|
|
1832
|
+
"""
|
|
1833
|
+
try:
|
|
1834
|
+
sampling_model = models.infer_model(model) if model else self._get_model(None)
|
|
1835
|
+
except exceptions.UserError as e:
|
|
1836
|
+
raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e
|
|
1837
|
+
|
|
1838
|
+
from .mcp import MCPServer
|
|
1839
|
+
|
|
1840
|
+
def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
|
|
1841
|
+
if isinstance(toolset, MCPServer):
|
|
1842
|
+
toolset.sampling_model = sampling_model
|
|
1843
|
+
|
|
1844
|
+
self._get_toolset().apply(_set_sampling_model)
|
|
1845
|
+
|
|
1737
1846
|
@asynccontextmanager
|
|
1847
|
+
@deprecated(
|
|
1848
|
+
'`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
|
|
1849
|
+
)
|
|
1738
1850
|
async def run_mcp_servers(
|
|
1739
1851
|
self, model: models.Model | models.KnownModelName | str | None = None
|
|
1740
1852
|
) -> AsyncIterator[None]:
|
|
1741
1853
|
"""Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
|
|
1742
1854
|
|
|
1855
|
+
Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead.
|
|
1856
|
+
If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model].
|
|
1857
|
+
|
|
1743
1858
|
Returns: a context manager to start and shutdown the servers.
|
|
1744
1859
|
"""
|
|
1745
1860
|
try:
|
|
1746
|
-
|
|
1747
|
-
except exceptions.UserError:
|
|
1748
|
-
|
|
1861
|
+
self.set_mcp_sampling_model(model)
|
|
1862
|
+
except exceptions.UserError:
|
|
1863
|
+
if model is not None:
|
|
1864
|
+
raise
|
|
1749
1865
|
|
|
1750
|
-
|
|
1751
|
-
try:
|
|
1752
|
-
for mcp_server in self._mcp_servers:
|
|
1753
|
-
if sampling_model is not None: # pragma: no branch
|
|
1754
|
-
mcp_server.sampling_model = sampling_model
|
|
1755
|
-
await exit_stack.enter_async_context(mcp_server)
|
|
1866
|
+
async with self:
|
|
1756
1867
|
yield
|
|
1757
|
-
|
|
1758
|
-
|
|
1868
|
+
|
|
1869
|
+
def to_ag_ui(
|
|
1870
|
+
self,
|
|
1871
|
+
*,
|
|
1872
|
+
# Agent.iter parameters
|
|
1873
|
+
output_type: OutputSpec[OutputDataT] | None = None,
|
|
1874
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
1875
|
+
deps: AgentDepsT = None,
|
|
1876
|
+
model_settings: ModelSettings | None = None,
|
|
1877
|
+
usage_limits: UsageLimits | None = None,
|
|
1878
|
+
usage: Usage | None = None,
|
|
1879
|
+
infer_name: bool = True,
|
|
1880
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
1881
|
+
# Starlette
|
|
1882
|
+
debug: bool = False,
|
|
1883
|
+
routes: Sequence[BaseRoute] | None = None,
|
|
1884
|
+
middleware: Sequence[Middleware] | None = None,
|
|
1885
|
+
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
|
|
1886
|
+
on_startup: Sequence[Callable[[], Any]] | None = None,
|
|
1887
|
+
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
|
1888
|
+
lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
|
|
1889
|
+
) -> AGUIApp[AgentDepsT, OutputDataT]:
|
|
1890
|
+
"""Convert the agent to an AG-UI application.
|
|
1891
|
+
|
|
1892
|
+
This allows you to use the agent with a compatible AG-UI frontend.
|
|
1893
|
+
|
|
1894
|
+
Example:
|
|
1895
|
+
```python
|
|
1896
|
+
from pydantic_ai import Agent
|
|
1897
|
+
|
|
1898
|
+
agent = Agent('openai:gpt-4o')
|
|
1899
|
+
app = agent.to_ag_ui()
|
|
1900
|
+
```
|
|
1901
|
+
|
|
1902
|
+
The `app` is an ASGI application that can be used with any ASGI server.
|
|
1903
|
+
|
|
1904
|
+
To run the application, you can use the following command:
|
|
1905
|
+
|
|
1906
|
+
```bash
|
|
1907
|
+
uvicorn app:app --host 0.0.0.0 --port 8000
|
|
1908
|
+
```
|
|
1909
|
+
|
|
1910
|
+
See [AG-UI docs](../ag-ui.md) for more information.
|
|
1911
|
+
|
|
1912
|
+
Args:
|
|
1913
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
|
|
1914
|
+
no output validators since output validators would expect an argument that matches the agent's
|
|
1915
|
+
output type.
|
|
1916
|
+
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
1917
|
+
deps: Optional dependencies to use for this run.
|
|
1918
|
+
model_settings: Optional settings to use for this model's request.
|
|
1919
|
+
usage_limits: Optional limits on model request count or token usage.
|
|
1920
|
+
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
1921
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
1922
|
+
toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
|
|
1923
|
+
|
|
1924
|
+
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
|
1925
|
+
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
|
1926
|
+
middleware: A list of middleware to run for every request. A starlette application will always
|
|
1927
|
+
automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
|
|
1928
|
+
outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
|
|
1929
|
+
`ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
|
|
1930
|
+
exception cases occurring in the routing or endpoints.
|
|
1931
|
+
exception_handlers: A mapping of either integer status codes, or exception class types onto
|
|
1932
|
+
callables which handle the exceptions. Exception handler callables should be of the form
|
|
1933
|
+
`handler(request, exc) -> response` and may be either standard functions, or async functions.
|
|
1934
|
+
on_startup: A list of callables to run on application startup. Startup handler callables do not
|
|
1935
|
+
take any arguments, and may be either standard functions, or async functions.
|
|
1936
|
+
on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
|
|
1937
|
+
not take any arguments, and may be either standard functions, or async functions.
|
|
1938
|
+
lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
|
|
1939
|
+
This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
|
|
1940
|
+
the other, not both.
|
|
1941
|
+
|
|
1942
|
+
Returns:
|
|
1943
|
+
An ASGI application for running Pydantic AI agents with AG-UI protocol support.
|
|
1944
|
+
"""
|
|
1945
|
+
from .ag_ui import AGUIApp
|
|
1946
|
+
|
|
1947
|
+
return AGUIApp(
|
|
1948
|
+
agent=self,
|
|
1949
|
+
# Agent.iter parameters
|
|
1950
|
+
output_type=output_type,
|
|
1951
|
+
model=model,
|
|
1952
|
+
deps=deps,
|
|
1953
|
+
model_settings=model_settings,
|
|
1954
|
+
usage_limits=usage_limits,
|
|
1955
|
+
usage=usage,
|
|
1956
|
+
infer_name=infer_name,
|
|
1957
|
+
toolsets=toolsets,
|
|
1958
|
+
# Starlette
|
|
1959
|
+
debug=debug,
|
|
1960
|
+
routes=routes,
|
|
1961
|
+
middleware=middleware,
|
|
1962
|
+
exception_handlers=exception_handlers,
|
|
1963
|
+
on_startup=on_startup,
|
|
1964
|
+
on_shutdown=on_shutdown,
|
|
1965
|
+
lifespan=lifespan,
|
|
1966
|
+
)
|
|
1759
1967
|
|
|
1760
1968
|
def to_a2a(
|
|
1761
1969
|
self,
|
|
@@ -2112,12 +2320,18 @@ class AgentRunResult(Generic[OutputDataT]):
|
|
|
2112
2320
|
"""
|
|
2113
2321
|
if not self._output_tool_name:
|
|
2114
2322
|
raise ValueError('Cannot set output tool return content when the return type is `str`.')
|
|
2115
|
-
|
|
2323
|
+
|
|
2324
|
+
messages = self._state.message_history
|
|
2116
2325
|
last_message = messages[-1]
|
|
2117
|
-
for part in last_message.parts:
|
|
2326
|
+
for idx, part in enumerate(last_message.parts):
|
|
2118
2327
|
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
|
|
2119
|
-
|
|
2120
|
-
|
|
2328
|
+
# Only do deepcopy when we have to modify
|
|
2329
|
+
copied_messages = list(messages)
|
|
2330
|
+
copied_last = deepcopy(last_message)
|
|
2331
|
+
copied_last.parts[idx].content = return_content # type: ignore[misc]
|
|
2332
|
+
copied_messages[-1] = copied_last
|
|
2333
|
+
return copied_messages
|
|
2334
|
+
|
|
2121
2335
|
raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
|
|
2122
2336
|
|
|
2123
2337
|
@overload
|