pydantic-ai-slim 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +220 -319
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +295 -331
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -156
- pydantic_ai/exceptions.py +12 -0
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +13 -5
- pydantic_ai/models/__init__.py +30 -18
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +1 -9
- pydantic_ai/models/google.py +2 -11
- pydantic_ai/models/groq.py +1 -0
- pydantic_ai/models/mistral.py +1 -1
- pydantic_ai/models/openai.py +3 -3
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/result.py +58 -45
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +10 -7
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/RECORD +45 -32
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.3.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -4,7 +4,8 @@ import dataclasses
|
|
|
4
4
|
import inspect
|
|
5
5
|
import json
|
|
6
6
|
import warnings
|
|
7
|
-
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
|
7
|
+
from asyncio import Lock
|
|
8
|
+
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
|
|
8
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
|
|
9
10
|
from contextvars import ContextVar
|
|
10
11
|
from copy import deepcopy
|
|
@@ -15,7 +16,6 @@ from opentelemetry.trace import NoOpTracer, use_span
|
|
|
15
16
|
from pydantic.json_schema import GenerateJsonSchema
|
|
16
17
|
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
|
|
17
18
|
|
|
18
|
-
from pydantic_ai.profiles import ModelProfile
|
|
19
19
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
20
20
|
from pydantic_graph._utils import get_event_loop
|
|
21
21
|
|
|
@@ -31,8 +31,11 @@ from . import (
|
|
|
31
31
|
usage as _usage,
|
|
32
32
|
)
|
|
33
33
|
from ._agent_graph import HistoryProcessor
|
|
34
|
+
from ._output import OutputToolset
|
|
35
|
+
from ._tool_manager import ToolManager
|
|
34
36
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
35
37
|
from .output import OutputDataT, OutputSpec
|
|
38
|
+
from .profiles import ModelProfile
|
|
36
39
|
from .result import FinalResult, StreamedRunResult
|
|
37
40
|
from .settings import ModelSettings, merge_model_settings
|
|
38
41
|
from .tools import (
|
|
@@ -48,6 +51,11 @@ from .tools import (
|
|
|
48
51
|
ToolPrepareFunc,
|
|
49
52
|
ToolsPrepareFunc,
|
|
50
53
|
)
|
|
54
|
+
from .toolsets import AbstractToolset
|
|
55
|
+
from .toolsets.combined import CombinedToolset
|
|
56
|
+
from .toolsets.function import FunctionToolset
|
|
57
|
+
from .toolsets.prepared import PreparedToolset
|
|
58
|
+
from .usage import Usage, UsageLimits
|
|
51
59
|
|
|
52
60
|
# Re-exporting like this improves auto-import behavior in PyCharm
|
|
53
61
|
capture_run_messages = _agent_graph.capture_run_messages
|
|
@@ -62,11 +70,12 @@ if TYPE_CHECKING:
|
|
|
62
70
|
from fasta2a.schema import AgentProvider, Skill
|
|
63
71
|
from fasta2a.storage import Storage
|
|
64
72
|
from starlette.middleware import Middleware
|
|
65
|
-
from starlette.routing import Route
|
|
73
|
+
from starlette.routing import BaseRoute, Route
|
|
66
74
|
from starlette.types import ExceptionHandler, Lifespan
|
|
67
75
|
|
|
68
76
|
from pydantic_ai.mcp import MCPServer
|
|
69
77
|
|
|
78
|
+
from .ag_ui import AGUIApp
|
|
70
79
|
|
|
71
80
|
__all__ = (
|
|
72
81
|
'Agent',
|
|
@@ -153,12 +162,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
153
162
|
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
|
|
154
163
|
repr=False
|
|
155
164
|
)
|
|
165
|
+
_function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
|
|
166
|
+
_output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
167
|
+
_user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
|
|
156
168
|
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
157
|
-
|
|
158
|
-
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
|
159
|
-
_default_retries: int = dataclasses.field(repr=False)
|
|
169
|
+
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
|
|
160
170
|
_max_result_retries: int = dataclasses.field(repr=False)
|
|
161
171
|
|
|
172
|
+
_enter_lock: Lock = dataclasses.field(repr=False)
|
|
173
|
+
_entered_count: int = dataclasses.field(repr=False)
|
|
174
|
+
_exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
|
|
175
|
+
|
|
162
176
|
@overload
|
|
163
177
|
def __init__(
|
|
164
178
|
self,
|
|
@@ -177,7 +191,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
177
191
|
output_retries: int | None = None,
|
|
178
192
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
179
193
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
180
|
-
|
|
194
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
195
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
181
196
|
defer_model_check: bool = False,
|
|
182
197
|
end_strategy: EndStrategy = 'early',
|
|
183
198
|
instrument: InstrumentationSettings | bool | None = None,
|
|
@@ -186,7 +201,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
186
201
|
|
|
187
202
|
@overload
|
|
188
203
|
@deprecated(
|
|
189
|
-
'`result_type`, `result_tool_name
|
|
204
|
+
'`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.'
|
|
190
205
|
)
|
|
191
206
|
def __init__(
|
|
192
207
|
self,
|
|
@@ -207,6 +222,36 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
207
222
|
result_retries: int | None = None,
|
|
208
223
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
209
224
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
225
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
226
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
227
|
+
defer_model_check: bool = False,
|
|
228
|
+
end_strategy: EndStrategy = 'early',
|
|
229
|
+
instrument: InstrumentationSettings | bool | None = None,
|
|
230
|
+
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
|
|
231
|
+
) -> None: ...
|
|
232
|
+
|
|
233
|
+
@overload
|
|
234
|
+
@deprecated('`mcp_servers` is deprecated, use `toolsets` instead.')
|
|
235
|
+
def __init__(
|
|
236
|
+
self,
|
|
237
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
238
|
+
*,
|
|
239
|
+
result_type: type[OutputDataT] = str,
|
|
240
|
+
instructions: str
|
|
241
|
+
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
242
|
+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
243
|
+
| None = None,
|
|
244
|
+
system_prompt: str | Sequence[str] = (),
|
|
245
|
+
deps_type: type[AgentDepsT] = NoneType,
|
|
246
|
+
name: str | None = None,
|
|
247
|
+
model_settings: ModelSettings | None = None,
|
|
248
|
+
retries: int = 1,
|
|
249
|
+
result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME,
|
|
250
|
+
result_tool_description: str | None = None,
|
|
251
|
+
result_retries: int | None = None,
|
|
252
|
+
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
253
|
+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
254
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
210
255
|
mcp_servers: Sequence[MCPServer] = (),
|
|
211
256
|
defer_model_check: bool = False,
|
|
212
257
|
end_strategy: EndStrategy = 'early',
|
|
@@ -232,7 +277,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
232
277
|
output_retries: int | None = None,
|
|
233
278
|
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
|
|
234
279
|
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
235
|
-
|
|
280
|
+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
|
|
281
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
236
282
|
defer_model_check: bool = False,
|
|
237
283
|
end_strategy: EndStrategy = 'early',
|
|
238
284
|
instrument: InstrumentationSettings | bool | None = None,
|
|
@@ -258,14 +304,16 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
258
304
|
when the agent is first run.
|
|
259
305
|
model_settings: Optional model request settings to use for this agent's runs, by default.
|
|
260
306
|
retries: The default number of retries to allow before raising an error.
|
|
261
|
-
output_retries: The maximum number of retries to allow for
|
|
307
|
+
output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
|
|
262
308
|
tools: Tools to register with the agent, you can also register tools via the decorators
|
|
263
309
|
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
|
|
264
|
-
prepare_tools:
|
|
310
|
+
prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools.
|
|
265
311
|
This is useful if you want to customize the definition of multiple tools or you want to register
|
|
266
312
|
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
|
|
267
|
-
|
|
268
|
-
|
|
313
|
+
prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
|
|
314
|
+
This is useful if you want to customize the definition of multiple output tools or you want to register
|
|
315
|
+
a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
|
|
316
|
+
toolsets: Toolsets to register with the agent, including MCP servers.
|
|
269
317
|
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
|
|
270
318
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
271
319
|
which checks for the necessary environment variables. Set this to `false`
|
|
@@ -329,10 +377,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
329
377
|
)
|
|
330
378
|
output_retries = result_retries
|
|
331
379
|
|
|
380
|
+
if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None):
|
|
381
|
+
if toolsets is not None: # pragma: no cover
|
|
382
|
+
raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.')
|
|
383
|
+
warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning)
|
|
384
|
+
toolsets = mcp_servers
|
|
385
|
+
|
|
386
|
+
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
387
|
+
|
|
332
388
|
default_output_mode = (
|
|
333
389
|
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
|
|
334
390
|
)
|
|
335
|
-
_utils.validate_empty_kwargs(_deprecated_kwargs)
|
|
336
391
|
|
|
337
392
|
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
338
393
|
output_type,
|
|
@@ -357,21 +412,28 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
357
412
|
self._system_prompt_functions = []
|
|
358
413
|
self._system_prompt_dynamic_functions = {}
|
|
359
414
|
|
|
360
|
-
self._function_tools = {}
|
|
361
|
-
|
|
362
|
-
self._default_retries = retries
|
|
363
415
|
self._max_result_retries = output_retries if output_retries is not None else retries
|
|
364
|
-
self._mcp_servers = mcp_servers
|
|
365
416
|
self._prepare_tools = prepare_tools
|
|
417
|
+
self._prepare_output_tools = prepare_output_tools
|
|
418
|
+
|
|
419
|
+
self._output_toolset = self._output_schema.toolset
|
|
420
|
+
if self._output_toolset:
|
|
421
|
+
self._output_toolset.max_retries = self._max_result_retries
|
|
422
|
+
|
|
423
|
+
self._function_toolset = FunctionToolset(tools, max_retries=retries)
|
|
424
|
+
self._user_toolsets = toolsets or ()
|
|
425
|
+
|
|
366
426
|
self.history_processors = history_processors or []
|
|
367
|
-
for tool in tools:
|
|
368
|
-
if isinstance(tool, Tool):
|
|
369
|
-
self._register_tool(tool)
|
|
370
|
-
else:
|
|
371
|
-
self._register_tool(Tool(tool))
|
|
372
427
|
|
|
373
428
|
self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
|
|
374
429
|
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
|
|
430
|
+
self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar(
|
|
431
|
+
'_override_toolsets', default=None
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
self._enter_lock = _utils.get_async_lock()
|
|
435
|
+
self._entered_count = 0
|
|
436
|
+
self._exit_stack = None
|
|
375
437
|
|
|
376
438
|
@staticmethod
|
|
377
439
|
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
|
|
@@ -391,6 +453,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
391
453
|
usage_limits: _usage.UsageLimits | None = None,
|
|
392
454
|
usage: _usage.Usage | None = None,
|
|
393
455
|
infer_name: bool = True,
|
|
456
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
394
457
|
) -> AgentRunResult[OutputDataT]: ...
|
|
395
458
|
|
|
396
459
|
@overload
|
|
@@ -406,6 +469,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
406
469
|
usage_limits: _usage.UsageLimits | None = None,
|
|
407
470
|
usage: _usage.Usage | None = None,
|
|
408
471
|
infer_name: bool = True,
|
|
472
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
409
473
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
410
474
|
|
|
411
475
|
@overload
|
|
@@ -422,6 +486,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
422
486
|
usage_limits: _usage.UsageLimits | None = None,
|
|
423
487
|
usage: _usage.Usage | None = None,
|
|
424
488
|
infer_name: bool = True,
|
|
489
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
425
490
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
426
491
|
|
|
427
492
|
async def run(
|
|
@@ -436,6 +501,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
436
501
|
usage_limits: _usage.UsageLimits | None = None,
|
|
437
502
|
usage: _usage.Usage | None = None,
|
|
438
503
|
infer_name: bool = True,
|
|
504
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
439
505
|
**_deprecated_kwargs: Never,
|
|
440
506
|
) -> AgentRunResult[Any]:
|
|
441
507
|
"""Run the agent with a user prompt in async mode.
|
|
@@ -466,6 +532,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
466
532
|
usage_limits: Optional limits on model request count or token usage.
|
|
467
533
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
468
534
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
535
|
+
toolsets: Optional additional toolsets for this run.
|
|
469
536
|
|
|
470
537
|
Returns:
|
|
471
538
|
The result of the run.
|
|
@@ -490,6 +557,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
490
557
|
model_settings=model_settings,
|
|
491
558
|
usage_limits=usage_limits,
|
|
492
559
|
usage=usage,
|
|
560
|
+
toolsets=toolsets,
|
|
493
561
|
) as agent_run:
|
|
494
562
|
async for _ in agent_run:
|
|
495
563
|
pass
|
|
@@ -510,6 +578,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
510
578
|
usage_limits: _usage.UsageLimits | None = None,
|
|
511
579
|
usage: _usage.Usage | None = None,
|
|
512
580
|
infer_name: bool = True,
|
|
581
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
513
582
|
**_deprecated_kwargs: Never,
|
|
514
583
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...
|
|
515
584
|
|
|
@@ -526,6 +595,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
526
595
|
usage_limits: _usage.UsageLimits | None = None,
|
|
527
596
|
usage: _usage.Usage | None = None,
|
|
528
597
|
infer_name: bool = True,
|
|
598
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
529
599
|
**_deprecated_kwargs: Never,
|
|
530
600
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
|
|
531
601
|
|
|
@@ -543,6 +613,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
543
613
|
usage_limits: _usage.UsageLimits | None = None,
|
|
544
614
|
usage: _usage.Usage | None = None,
|
|
545
615
|
infer_name: bool = True,
|
|
616
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
546
617
|
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
|
|
547
618
|
|
|
548
619
|
@asynccontextmanager
|
|
@@ -558,6 +629,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
558
629
|
usage_limits: _usage.UsageLimits | None = None,
|
|
559
630
|
usage: _usage.Usage | None = None,
|
|
560
631
|
infer_name: bool = True,
|
|
632
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
561
633
|
**_deprecated_kwargs: Never,
|
|
562
634
|
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
|
|
563
635
|
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
|
|
@@ -632,6 +704,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
632
704
|
usage_limits: Optional limits on model request count or token usage.
|
|
633
705
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
634
706
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
707
|
+
toolsets: Optional additional toolsets for this run.
|
|
635
708
|
|
|
636
709
|
Returns:
|
|
637
710
|
The result of the run.
|
|
@@ -655,6 +728,18 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
655
728
|
|
|
656
729
|
output_type_ = output_type or self.output_type
|
|
657
730
|
|
|
731
|
+
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
732
|
+
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
733
|
+
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
734
|
+
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
735
|
+
|
|
736
|
+
output_toolset = self._output_toolset
|
|
737
|
+
if output_schema != self._output_schema or output_validators:
|
|
738
|
+
output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset)
|
|
739
|
+
if output_toolset:
|
|
740
|
+
output_toolset.max_retries = self._max_result_retries
|
|
741
|
+
output_toolset.output_validators = output_validators
|
|
742
|
+
|
|
658
743
|
# Build the graph
|
|
659
744
|
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
|
|
660
745
|
_agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
|
|
@@ -669,22 +754,32 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
669
754
|
run_step=0,
|
|
670
755
|
)
|
|
671
756
|
|
|
672
|
-
# We consider it a user error if a user tries to restrict the result type while having an output validator that
|
|
673
|
-
# may change the result type from the restricted type to something else. Therefore, we consider the following
|
|
674
|
-
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
|
|
675
|
-
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
|
|
676
|
-
|
|
677
|
-
# Merge model settings in order of precedence: run > agent > model
|
|
678
|
-
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
|
|
679
|
-
model_settings = merge_model_settings(merged_settings, model_settings)
|
|
680
|
-
usage_limits = usage_limits or _usage.UsageLimits()
|
|
681
|
-
|
|
682
757
|
if isinstance(model_used, InstrumentedModel):
|
|
683
758
|
instrumentation_settings = model_used.instrumentation_settings
|
|
684
759
|
tracer = model_used.instrumentation_settings.tracer
|
|
685
760
|
else:
|
|
686
761
|
instrumentation_settings = None
|
|
687
762
|
tracer = NoOpTracer()
|
|
763
|
+
|
|
764
|
+
run_context = RunContext[AgentDepsT](
|
|
765
|
+
deps=deps,
|
|
766
|
+
model=model_used,
|
|
767
|
+
usage=usage,
|
|
768
|
+
prompt=user_prompt,
|
|
769
|
+
messages=state.message_history,
|
|
770
|
+
tracer=tracer,
|
|
771
|
+
trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
|
|
772
|
+
run_step=state.run_step,
|
|
773
|
+
)
|
|
774
|
+
|
|
775
|
+
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
|
|
776
|
+
# This will raise errors for any name conflicts
|
|
777
|
+
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
|
|
778
|
+
|
|
779
|
+
# Merge model settings in order of precedence: run > agent > model
|
|
780
|
+
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
|
|
781
|
+
model_settings = merge_model_settings(merged_settings, model_settings)
|
|
782
|
+
usage_limits = usage_limits or _usage.UsageLimits()
|
|
688
783
|
agent_name = self.name or 'agent'
|
|
689
784
|
run_span = tracer.start_span(
|
|
690
785
|
'agent run',
|
|
@@ -711,10 +806,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
711
806
|
return None
|
|
712
807
|
return '\n\n'.join(parts).strip()
|
|
713
808
|
|
|
714
|
-
# Copy the function tools so that retry state is agent-run-specific
|
|
715
|
-
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
|
|
716
|
-
run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
|
|
717
|
-
|
|
718
809
|
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
|
|
719
810
|
user_deps=deps,
|
|
720
811
|
prompt=user_prompt,
|
|
@@ -727,11 +818,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
727
818
|
output_schema=output_schema,
|
|
728
819
|
output_validators=output_validators,
|
|
729
820
|
history_processors=self.history_processors,
|
|
730
|
-
|
|
731
|
-
mcp_servers=self._mcp_servers,
|
|
732
|
-
default_retries=self._default_retries,
|
|
821
|
+
tool_manager=run_toolset,
|
|
733
822
|
tracer=tracer,
|
|
734
|
-
prepare_tools=self._prepare_tools,
|
|
735
823
|
get_instructions=get_instructions,
|
|
736
824
|
instrumentation_settings=instrumentation_settings,
|
|
737
825
|
)
|
|
@@ -801,6 +889,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
801
889
|
usage_limits: _usage.UsageLimits | None = None,
|
|
802
890
|
usage: _usage.Usage | None = None,
|
|
803
891
|
infer_name: bool = True,
|
|
892
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
804
893
|
) -> AgentRunResult[OutputDataT]: ...
|
|
805
894
|
|
|
806
895
|
@overload
|
|
@@ -816,6 +905,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
816
905
|
usage_limits: _usage.UsageLimits | None = None,
|
|
817
906
|
usage: _usage.Usage | None = None,
|
|
818
907
|
infer_name: bool = True,
|
|
908
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
819
909
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
820
910
|
|
|
821
911
|
@overload
|
|
@@ -832,6 +922,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
832
922
|
usage_limits: _usage.UsageLimits | None = None,
|
|
833
923
|
usage: _usage.Usage | None = None,
|
|
834
924
|
infer_name: bool = True,
|
|
925
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
835
926
|
) -> AgentRunResult[RunOutputDataT]: ...
|
|
836
927
|
|
|
837
928
|
def run_sync(
|
|
@@ -846,6 +937,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
846
937
|
usage_limits: _usage.UsageLimits | None = None,
|
|
847
938
|
usage: _usage.Usage | None = None,
|
|
848
939
|
infer_name: bool = True,
|
|
940
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
849
941
|
**_deprecated_kwargs: Never,
|
|
850
942
|
) -> AgentRunResult[Any]:
|
|
851
943
|
"""Synchronously run the agent with a user prompt.
|
|
@@ -875,6 +967,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
875
967
|
usage_limits: Optional limits on model request count or token usage.
|
|
876
968
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
877
969
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
970
|
+
toolsets: Optional additional toolsets for this run.
|
|
878
971
|
|
|
879
972
|
Returns:
|
|
880
973
|
The result of the run.
|
|
@@ -901,6 +994,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
901
994
|
usage_limits=usage_limits,
|
|
902
995
|
usage=usage,
|
|
903
996
|
infer_name=False,
|
|
997
|
+
toolsets=toolsets,
|
|
904
998
|
)
|
|
905
999
|
)
|
|
906
1000
|
|
|
@@ -916,6 +1010,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
916
1010
|
usage_limits: _usage.UsageLimits | None = None,
|
|
917
1011
|
usage: _usage.Usage | None = None,
|
|
918
1012
|
infer_name: bool = True,
|
|
1013
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
919
1014
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
|
|
920
1015
|
|
|
921
1016
|
@overload
|
|
@@ -931,6 +1026,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
931
1026
|
usage_limits: _usage.UsageLimits | None = None,
|
|
932
1027
|
usage: _usage.Usage | None = None,
|
|
933
1028
|
infer_name: bool = True,
|
|
1029
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
934
1030
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
935
1031
|
|
|
936
1032
|
@overload
|
|
@@ -947,6 +1043,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
947
1043
|
usage_limits: _usage.UsageLimits | None = None,
|
|
948
1044
|
usage: _usage.Usage | None = None,
|
|
949
1045
|
infer_name: bool = True,
|
|
1046
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
950
1047
|
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
|
|
951
1048
|
|
|
952
1049
|
@asynccontextmanager
|
|
@@ -962,6 +1059,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
962
1059
|
usage_limits: _usage.UsageLimits | None = None,
|
|
963
1060
|
usage: _usage.Usage | None = None,
|
|
964
1061
|
infer_name: bool = True,
|
|
1062
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
965
1063
|
**_deprecated_kwargs: Never,
|
|
966
1064
|
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
|
|
967
1065
|
"""Run the agent with a user prompt in async mode, returning a streamed response.
|
|
@@ -989,6 +1087,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
989
1087
|
usage_limits: Optional limits on model request count or token usage.
|
|
990
1088
|
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
991
1089
|
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
1090
|
+
toolsets: Optional additional toolsets for this run.
|
|
992
1091
|
|
|
993
1092
|
Returns:
|
|
994
1093
|
The result of the run.
|
|
@@ -1019,6 +1118,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1019
1118
|
usage_limits=usage_limits,
|
|
1020
1119
|
usage=usage,
|
|
1021
1120
|
infer_name=False,
|
|
1121
|
+
toolsets=toolsets,
|
|
1022
1122
|
) as agent_run:
|
|
1023
1123
|
first_node = agent_run.next_node # start with the first node
|
|
1024
1124
|
assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
|
|
@@ -1039,15 +1139,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1039
1139
|
output_schema, _output.TextOutputSchema
|
|
1040
1140
|
):
|
|
1041
1141
|
return FinalResult(s, None, None)
|
|
1042
|
-
elif isinstance(new_part, _messages.ToolCallPart) and
|
|
1043
|
-
|
|
1044
|
-
):
|
|
1045
|
-
|
|
1046
|
-
return FinalResult(s,
|
|
1142
|
+
elif isinstance(new_part, _messages.ToolCallPart) and (
|
|
1143
|
+
tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
|
|
1144
|
+
):
|
|
1145
|
+
if tool_def.kind == 'output':
|
|
1146
|
+
return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
|
|
1147
|
+
elif tool_def.kind == 'deferred':
|
|
1148
|
+
return FinalResult(s, None, None)
|
|
1047
1149
|
return None
|
|
1048
1150
|
|
|
1049
|
-
|
|
1050
|
-
if final_result_details is not None:
|
|
1151
|
+
final_result = await stream_to_final(streamed_response)
|
|
1152
|
+
if final_result is not None:
|
|
1051
1153
|
if yielded:
|
|
1052
1154
|
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
|
|
1053
1155
|
yielded = True
|
|
@@ -1068,17 +1170,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1068
1170
|
|
|
1069
1171
|
parts: list[_messages.ModelRequestPart] = []
|
|
1070
1172
|
async for _event in _agent_graph.process_function_tools(
|
|
1173
|
+
graph_ctx.deps.tool_manager,
|
|
1071
1174
|
tool_calls,
|
|
1072
|
-
|
|
1073
|
-
final_result_details.tool_call_id,
|
|
1175
|
+
final_result,
|
|
1074
1176
|
graph_ctx,
|
|
1075
1177
|
parts,
|
|
1076
1178
|
):
|
|
1077
1179
|
pass
|
|
1078
|
-
# TODO: Should we do something here related to the retry count?
|
|
1079
|
-
# Maybe we should move the incrementing of the retry count to where we actually make a request?
|
|
1080
|
-
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
|
|
1081
|
-
# ctx.state.increment_retries(ctx.deps.max_result_retries)
|
|
1082
1180
|
if parts:
|
|
1083
1181
|
messages.append(_messages.ModelRequest(parts))
|
|
1084
1182
|
|
|
@@ -1089,10 +1187,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1089
1187
|
streamed_response,
|
|
1090
1188
|
graph_ctx.deps.output_schema,
|
|
1091
1189
|
_agent_graph.build_run_context(graph_ctx),
|
|
1092
|
-
_output.build_trace_context(graph_ctx),
|
|
1093
1190
|
graph_ctx.deps.output_validators,
|
|
1094
|
-
|
|
1191
|
+
final_result.tool_name,
|
|
1095
1192
|
on_complete,
|
|
1193
|
+
graph_ctx.deps.tool_manager,
|
|
1096
1194
|
)
|
|
1097
1195
|
break
|
|
1098
1196
|
next_node = await agent_run.next(node)
|
|
@@ -1111,8 +1209,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1111
1209
|
*,
|
|
1112
1210
|
deps: AgentDepsT | _utils.Unset = _utils.UNSET,
|
|
1113
1211
|
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
|
|
1212
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
|
|
1114
1213
|
) -> Iterator[None]:
|
|
1115
|
-
"""Context manager to temporarily override agent dependencies
|
|
1214
|
+
"""Context manager to temporarily override agent dependencies, model, or toolsets.
|
|
1116
1215
|
|
|
1117
1216
|
This is particularly useful when testing.
|
|
1118
1217
|
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
|
|
@@ -1120,6 +1219,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1120
1219
|
Args:
|
|
1121
1220
|
deps: The dependencies to use instead of the dependencies passed to the agent run.
|
|
1122
1221
|
model: The model to use instead of the model passed to the agent run.
|
|
1222
|
+
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
|
|
1123
1223
|
"""
|
|
1124
1224
|
if _utils.is_set(deps):
|
|
1125
1225
|
deps_token = self._override_deps.set(_utils.Some(deps))
|
|
@@ -1131,6 +1231,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1131
1231
|
else:
|
|
1132
1232
|
model_token = None
|
|
1133
1233
|
|
|
1234
|
+
if _utils.is_set(toolsets):
|
|
1235
|
+
toolsets_token = self._override_toolsets.set(_utils.Some(toolsets))
|
|
1236
|
+
else:
|
|
1237
|
+
toolsets_token = None
|
|
1238
|
+
|
|
1134
1239
|
try:
|
|
1135
1240
|
yield
|
|
1136
1241
|
finally:
|
|
@@ -1138,6 +1243,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1138
1243
|
self._override_deps.reset(deps_token)
|
|
1139
1244
|
if model_token is not None:
|
|
1140
1245
|
self._override_model.reset(model_token)
|
|
1246
|
+
if toolsets_token is not None:
|
|
1247
|
+
self._override_toolsets.reset(toolsets_token)
|
|
1141
1248
|
|
|
1142
1249
|
@overload
|
|
1143
1250
|
def instructions(
|
|
@@ -1423,30 +1530,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1423
1530
|
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1424
1531
|
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1425
1532
|
"""
|
|
1426
|
-
if func is None:
|
|
1427
|
-
|
|
1428
|
-
def tool_decorator(
|
|
1429
|
-
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
1430
|
-
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
1431
|
-
# noinspection PyTypeChecker
|
|
1432
|
-
self._register_function(
|
|
1433
|
-
func_,
|
|
1434
|
-
True,
|
|
1435
|
-
name,
|
|
1436
|
-
retries,
|
|
1437
|
-
prepare,
|
|
1438
|
-
docstring_format,
|
|
1439
|
-
require_parameter_descriptions,
|
|
1440
|
-
schema_generator,
|
|
1441
|
-
strict,
|
|
1442
|
-
)
|
|
1443
|
-
return func_
|
|
1444
1533
|
|
|
1445
|
-
|
|
1446
|
-
|
|
1534
|
+
def tool_decorator(
|
|
1535
|
+
func_: ToolFuncContext[AgentDepsT, ToolParams],
|
|
1536
|
+
) -> ToolFuncContext[AgentDepsT, ToolParams]:
|
|
1447
1537
|
# noinspection PyTypeChecker
|
|
1448
|
-
self.
|
|
1449
|
-
|
|
1538
|
+
self._function_toolset.add_function(
|
|
1539
|
+
func_,
|
|
1450
1540
|
True,
|
|
1451
1541
|
name,
|
|
1452
1542
|
retries,
|
|
@@ -1456,7 +1546,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1456
1546
|
schema_generator,
|
|
1457
1547
|
strict,
|
|
1458
1548
|
)
|
|
1459
|
-
return
|
|
1549
|
+
return func_
|
|
1550
|
+
|
|
1551
|
+
return tool_decorator if func is None else tool_decorator(func)
|
|
1460
1552
|
|
|
1461
1553
|
@overload
|
|
1462
1554
|
def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
|
|
@@ -1532,27 +1624,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1532
1624
|
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1533
1625
|
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1534
1626
|
"""
|
|
1535
|
-
if func is None:
|
|
1536
1627
|
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
False,
|
|
1542
|
-
name,
|
|
1543
|
-
retries,
|
|
1544
|
-
prepare,
|
|
1545
|
-
docstring_format,
|
|
1546
|
-
require_parameter_descriptions,
|
|
1547
|
-
schema_generator,
|
|
1548
|
-
strict,
|
|
1549
|
-
)
|
|
1550
|
-
return func_
|
|
1551
|
-
|
|
1552
|
-
return tool_decorator
|
|
1553
|
-
else:
|
|
1554
|
-
self._register_function(
|
|
1555
|
-
func,
|
|
1628
|
+
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
1629
|
+
# noinspection PyTypeChecker
|
|
1630
|
+
self._function_toolset.add_function(
|
|
1631
|
+
func_,
|
|
1556
1632
|
False,
|
|
1557
1633
|
name,
|
|
1558
1634
|
retries,
|
|
@@ -1562,48 +1638,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1562
1638
|
schema_generator,
|
|
1563
1639
|
strict,
|
|
1564
1640
|
)
|
|
1565
|
-
return
|
|
1566
|
-
|
|
1567
|
-
def _register_function(
|
|
1568
|
-
self,
|
|
1569
|
-
func: ToolFuncEither[AgentDepsT, ToolParams],
|
|
1570
|
-
takes_ctx: bool,
|
|
1571
|
-
name: str | None,
|
|
1572
|
-
retries: int | None,
|
|
1573
|
-
prepare: ToolPrepareFunc[AgentDepsT] | None,
|
|
1574
|
-
docstring_format: DocstringFormat,
|
|
1575
|
-
require_parameter_descriptions: bool,
|
|
1576
|
-
schema_generator: type[GenerateJsonSchema],
|
|
1577
|
-
strict: bool | None,
|
|
1578
|
-
) -> None:
|
|
1579
|
-
"""Private utility to register a function as a tool."""
|
|
1580
|
-
retries_ = retries if retries is not None else self._default_retries
|
|
1581
|
-
tool = Tool[AgentDepsT](
|
|
1582
|
-
func,
|
|
1583
|
-
takes_ctx=takes_ctx,
|
|
1584
|
-
name=name,
|
|
1585
|
-
max_retries=retries_,
|
|
1586
|
-
prepare=prepare,
|
|
1587
|
-
docstring_format=docstring_format,
|
|
1588
|
-
require_parameter_descriptions=require_parameter_descriptions,
|
|
1589
|
-
schema_generator=schema_generator,
|
|
1590
|
-
strict=strict,
|
|
1591
|
-
)
|
|
1592
|
-
self._register_tool(tool)
|
|
1593
|
-
|
|
1594
|
-
def _register_tool(self, tool: Tool[AgentDepsT]) -> None:
|
|
1595
|
-
"""Private utility to register a tool instance."""
|
|
1596
|
-
if tool.max_retries is None:
|
|
1597
|
-
# noinspection PyTypeChecker
|
|
1598
|
-
tool = dataclasses.replace(tool, max_retries=self._default_retries)
|
|
1599
|
-
|
|
1600
|
-
if tool.name in self._function_tools:
|
|
1601
|
-
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1641
|
+
return func_
|
|
1602
1642
|
|
|
1603
|
-
if
|
|
1604
|
-
raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
|
|
1605
|
-
|
|
1606
|
-
self._function_tools[tool.name] = tool
|
|
1643
|
+
return tool_decorator if func is None else tool_decorator(func)
|
|
1607
1644
|
|
|
1608
1645
|
def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model:
|
|
1609
1646
|
"""Create a model configured for this agent.
|
|
@@ -1649,6 +1686,37 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1649
1686
|
else:
|
|
1650
1687
|
return deps
|
|
1651
1688
|
|
|
1689
|
+
def _get_toolset(
|
|
1690
|
+
self,
|
|
1691
|
+
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
|
|
1692
|
+
additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
1693
|
+
) -> AbstractToolset[AgentDepsT]:
|
|
1694
|
+
"""Get the complete toolset.
|
|
1695
|
+
|
|
1696
|
+
Args:
|
|
1697
|
+
output_toolset: The output toolset to use instead of the one built at agent construction time.
|
|
1698
|
+
additional_toolsets: Additional toolsets to add.
|
|
1699
|
+
"""
|
|
1700
|
+
if some_user_toolsets := self._override_toolsets.get():
|
|
1701
|
+
user_toolsets = some_user_toolsets.value
|
|
1702
|
+
elif additional_toolsets is not None:
|
|
1703
|
+
user_toolsets = [*self._user_toolsets, *additional_toolsets]
|
|
1704
|
+
else:
|
|
1705
|
+
user_toolsets = self._user_toolsets
|
|
1706
|
+
|
|
1707
|
+
all_toolsets = [self._function_toolset, *user_toolsets]
|
|
1708
|
+
|
|
1709
|
+
if self._prepare_tools:
|
|
1710
|
+
all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)]
|
|
1711
|
+
|
|
1712
|
+
output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset
|
|
1713
|
+
if output_toolset is not None:
|
|
1714
|
+
if self._prepare_output_tools:
|
|
1715
|
+
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
|
|
1716
|
+
all_toolsets = [output_toolset, *all_toolsets]
|
|
1717
|
+
|
|
1718
|
+
return CombinedToolset(all_toolsets)
|
|
1719
|
+
|
|
1652
1720
|
def _infer_name(self, function_frame: FrameType | None) -> None:
|
|
1653
1721
|
"""Infer the agent name from the call frame.
|
|
1654
1722
|
|
|
@@ -1734,28 +1802,167 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1734
1802
|
"""
|
|
1735
1803
|
return isinstance(node, End)
|
|
1736
1804
|
|
|
1805
|
+
async def __aenter__(self) -> Self:
|
|
1806
|
+
"""Enter the agent context.
|
|
1807
|
+
|
|
1808
|
+
This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
|
|
1809
|
+
|
|
1810
|
+
This is a no-op if the agent has already been entered.
|
|
1811
|
+
"""
|
|
1812
|
+
async with self._enter_lock:
|
|
1813
|
+
if self._entered_count == 0:
|
|
1814
|
+
self._exit_stack = AsyncExitStack()
|
|
1815
|
+
toolset = self._get_toolset()
|
|
1816
|
+
await self._exit_stack.enter_async_context(toolset)
|
|
1817
|
+
self._entered_count += 1
|
|
1818
|
+
return self
|
|
1819
|
+
|
|
1820
|
+
async def __aexit__(self, *args: Any) -> bool | None:
|
|
1821
|
+
async with self._enter_lock:
|
|
1822
|
+
self._entered_count -= 1
|
|
1823
|
+
if self._entered_count == 0 and self._exit_stack is not None:
|
|
1824
|
+
await self._exit_stack.aclose()
|
|
1825
|
+
self._exit_stack = None
|
|
1826
|
+
|
|
1827
|
+
def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
|
|
1828
|
+
"""Set the sampling model on all MCP servers registered with the agent.
|
|
1829
|
+
|
|
1830
|
+
If no sampling model is provided, the agent's model will be used.
|
|
1831
|
+
"""
|
|
1832
|
+
try:
|
|
1833
|
+
sampling_model = models.infer_model(model) if model else self._get_model(None)
|
|
1834
|
+
except exceptions.UserError as e:
|
|
1835
|
+
raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e
|
|
1836
|
+
|
|
1837
|
+
from .mcp import MCPServer
|
|
1838
|
+
|
|
1839
|
+
def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:
|
|
1840
|
+
if isinstance(toolset, MCPServer):
|
|
1841
|
+
toolset.sampling_model = sampling_model
|
|
1842
|
+
|
|
1843
|
+
self._get_toolset().apply(_set_sampling_model)
|
|
1844
|
+
|
|
1737
1845
|
@asynccontextmanager
|
|
1846
|
+
@deprecated(
|
|
1847
|
+
'`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
|
|
1848
|
+
)
|
|
1738
1849
|
async def run_mcp_servers(
|
|
1739
1850
|
self, model: models.Model | models.KnownModelName | str | None = None
|
|
1740
1851
|
) -> AsyncIterator[None]:
|
|
1741
1852
|
"""Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
|
|
1742
1853
|
|
|
1854
|
+
Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead.
|
|
1855
|
+
If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model].
|
|
1856
|
+
|
|
1743
1857
|
Returns: a context manager to start and shutdown the servers.
|
|
1744
1858
|
"""
|
|
1745
1859
|
try:
|
|
1746
|
-
|
|
1747
|
-
except exceptions.UserError:
|
|
1748
|
-
|
|
1860
|
+
self.set_mcp_sampling_model(model)
|
|
1861
|
+
except exceptions.UserError:
|
|
1862
|
+
if model is not None:
|
|
1863
|
+
raise
|
|
1749
1864
|
|
|
1750
|
-
|
|
1751
|
-
try:
|
|
1752
|
-
for mcp_server in self._mcp_servers:
|
|
1753
|
-
if sampling_model is not None: # pragma: no branch
|
|
1754
|
-
mcp_server.sampling_model = sampling_model
|
|
1755
|
-
await exit_stack.enter_async_context(mcp_server)
|
|
1865
|
+
async with self:
|
|
1756
1866
|
yield
|
|
1757
|
-
|
|
1758
|
-
|
|
1867
|
+
|
|
1868
|
+
def to_ag_ui(
|
|
1869
|
+
self,
|
|
1870
|
+
*,
|
|
1871
|
+
# Agent.iter parameters
|
|
1872
|
+
output_type: OutputSpec[OutputDataT] | None = None,
|
|
1873
|
+
model: models.Model | models.KnownModelName | str | None = None,
|
|
1874
|
+
deps: AgentDepsT = None,
|
|
1875
|
+
model_settings: ModelSettings | None = None,
|
|
1876
|
+
usage_limits: UsageLimits | None = None,
|
|
1877
|
+
usage: Usage | None = None,
|
|
1878
|
+
infer_name: bool = True,
|
|
1879
|
+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
|
|
1880
|
+
# Starlette
|
|
1881
|
+
debug: bool = False,
|
|
1882
|
+
routes: Sequence[BaseRoute] | None = None,
|
|
1883
|
+
middleware: Sequence[Middleware] | None = None,
|
|
1884
|
+
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
|
|
1885
|
+
on_startup: Sequence[Callable[[], Any]] | None = None,
|
|
1886
|
+
on_shutdown: Sequence[Callable[[], Any]] | None = None,
|
|
1887
|
+
lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
|
|
1888
|
+
) -> AGUIApp[AgentDepsT, OutputDataT]:
|
|
1889
|
+
"""Convert the agent to an AG-UI application.
|
|
1890
|
+
|
|
1891
|
+
This allows you to use the agent with a compatible AG-UI frontend.
|
|
1892
|
+
|
|
1893
|
+
Example:
|
|
1894
|
+
```python
|
|
1895
|
+
from pydantic_ai import Agent
|
|
1896
|
+
|
|
1897
|
+
agent = Agent('openai:gpt-4o')
|
|
1898
|
+
app = agent.to_ag_ui()
|
|
1899
|
+
```
|
|
1900
|
+
|
|
1901
|
+
The `app` is an ASGI application that can be used with any ASGI server.
|
|
1902
|
+
|
|
1903
|
+
To run the application, you can use the following command:
|
|
1904
|
+
|
|
1905
|
+
```bash
|
|
1906
|
+
uvicorn app:app --host 0.0.0.0 --port 8000
|
|
1907
|
+
```
|
|
1908
|
+
|
|
1909
|
+
See [AG-UI docs](../ag-ui.md) for more information.
|
|
1910
|
+
|
|
1911
|
+
Args:
|
|
1912
|
+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
|
|
1913
|
+
no output validators since output validators would expect an argument that matches the agent's
|
|
1914
|
+
output type.
|
|
1915
|
+
model: Optional model to use for this run, required if `model` was not set when creating the agent.
|
|
1916
|
+
deps: Optional dependencies to use for this run.
|
|
1917
|
+
model_settings: Optional settings to use for this model's request.
|
|
1918
|
+
usage_limits: Optional limits on model request count or token usage.
|
|
1919
|
+
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
|
|
1920
|
+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
|
|
1921
|
+
toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
|
|
1922
|
+
|
|
1923
|
+
debug: Boolean indicating if debug tracebacks should be returned on errors.
|
|
1924
|
+
routes: A list of routes to serve incoming HTTP and WebSocket requests.
|
|
1925
|
+
middleware: A list of middleware to run for every request. A starlette application will always
|
|
1926
|
+
automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
|
|
1927
|
+
outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
|
|
1928
|
+
`ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
|
|
1929
|
+
exception cases occurring in the routing or endpoints.
|
|
1930
|
+
exception_handlers: A mapping of either integer status codes, or exception class types onto
|
|
1931
|
+
callables which handle the exceptions. Exception handler callables should be of the form
|
|
1932
|
+
`handler(request, exc) -> response` and may be either standard functions, or async functions.
|
|
1933
|
+
on_startup: A list of callables to run on application startup. Startup handler callables do not
|
|
1934
|
+
take any arguments, and may be either standard functions, or async functions.
|
|
1935
|
+
on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
|
|
1936
|
+
not take any arguments, and may be either standard functions, or async functions.
|
|
1937
|
+
lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
|
|
1938
|
+
This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
|
|
1939
|
+
the other, not both.
|
|
1940
|
+
|
|
1941
|
+
Returns:
|
|
1942
|
+
An ASGI application for running Pydantic AI agents with AG-UI protocol support.
|
|
1943
|
+
"""
|
|
1944
|
+
from .ag_ui import AGUIApp
|
|
1945
|
+
|
|
1946
|
+
return AGUIApp(
|
|
1947
|
+
agent=self,
|
|
1948
|
+
# Agent.iter parameters
|
|
1949
|
+
output_type=output_type,
|
|
1950
|
+
model=model,
|
|
1951
|
+
deps=deps,
|
|
1952
|
+
model_settings=model_settings,
|
|
1953
|
+
usage_limits=usage_limits,
|
|
1954
|
+
usage=usage,
|
|
1955
|
+
infer_name=infer_name,
|
|
1956
|
+
toolsets=toolsets,
|
|
1957
|
+
# Starlette
|
|
1958
|
+
debug=debug,
|
|
1959
|
+
routes=routes,
|
|
1960
|
+
middleware=middleware,
|
|
1961
|
+
exception_handlers=exception_handlers,
|
|
1962
|
+
on_startup=on_startup,
|
|
1963
|
+
on_shutdown=on_shutdown,
|
|
1964
|
+
lifespan=lifespan,
|
|
1965
|
+
)
|
|
1759
1966
|
|
|
1760
1967
|
def to_a2a(
|
|
1761
1968
|
self,
|
|
@@ -2112,12 +2319,18 @@ class AgentRunResult(Generic[OutputDataT]):
|
|
|
2112
2319
|
"""
|
|
2113
2320
|
if not self._output_tool_name:
|
|
2114
2321
|
raise ValueError('Cannot set output tool return content when the return type is `str`.')
|
|
2115
|
-
|
|
2322
|
+
|
|
2323
|
+
messages = self._state.message_history
|
|
2116
2324
|
last_message = messages[-1]
|
|
2117
|
-
for part in last_message.parts:
|
|
2325
|
+
for idx, part in enumerate(last_message.parts):
|
|
2118
2326
|
if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name:
|
|
2119
|
-
|
|
2120
|
-
|
|
2327
|
+
# Only do deepcopy when we have to modify
|
|
2328
|
+
copied_messages = list(messages)
|
|
2329
|
+
copied_last = deepcopy(last_message)
|
|
2330
|
+
copied_last.parts[idx].content = return_content # type: ignore[misc]
|
|
2331
|
+
copied_messages[-1] = copied_last
|
|
2332
|
+
return copied_messages
|
|
2333
|
+
|
|
2121
2334
|
raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.')
|
|
2122
2335
|
|
|
2123
2336
|
@overload
|