pydantic-ai-slim 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +5 -2
- pydantic_ai/_agent_graph.py +33 -15
- pydantic_ai/_cli.py +7 -3
- pydantic_ai/_function_schema.py +1 -4
- pydantic_ai/_mcp.py +123 -0
- pydantic_ai/_output.py +654 -159
- pydantic_ai/_run_context.py +56 -0
- pydantic_ai/_system_prompt.py +2 -1
- pydantic_ai/_utils.py +111 -1
- pydantic_ai/agent.py +66 -35
- pydantic_ai/mcp.py +144 -115
- pydantic_ai/models/__init__.py +21 -2
- pydantic_ai/models/function.py +21 -3
- pydantic_ai/models/gemini.py +27 -4
- pydantic_ai/models/google.py +29 -4
- pydantic_ai/models/mcp_sampling.py +95 -0
- pydantic_ai/models/mistral.py +5 -1
- pydantic_ai/models/openai.py +70 -9
- pydantic_ai/models/test.py +1 -1
- pydantic_ai/models/wrapper.py +6 -0
- pydantic_ai/output.py +288 -0
- pydantic_ai/profiles/__init__.py +21 -0
- pydantic_ai/profiles/_json_schema.py +1 -1
- pydantic_ai/profiles/google.py +6 -2
- pydantic_ai/profiles/openai.py +5 -0
- pydantic_ai/result.py +52 -26
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +2 -47
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/RECORD +33 -29
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.3.1.dist-info → pydantic_ai_slim-0.3.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import field
|
|
6
|
+
from typing import TYPE_CHECKING, Generic
|
|
7
|
+
|
|
8
|
+
from typing_extensions import TypeVar
|
|
9
|
+
|
|
10
|
+
from . import _utils, messages as _messages
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .models import Model
|
|
14
|
+
from .result import Usage
|
|
15
|
+
|
|
16
|
+
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies."""


@dataclasses.dataclass(repr=False)
class RunContext(Generic[AgentDepsT]):
    """Information about the current call.

    Use `replace_with` to derive a copy with a different `retry` and/or `tool_name`;
    the original instance is left untouched.
    """

    deps: AgentDepsT
    """Dependencies for the agent."""
    model: Model
    """The model used in this run."""
    usage: Usage
    """LLM usage associated with the run."""
    prompt: str | Sequence[_messages.UserContent] | None
    """The original user prompt passed to the run."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """Messages exchanged in the conversation so far."""
    tool_call_id: str | None = None
    """The ID of the tool call."""
    tool_name: str | None = None
    """Name of the tool being called."""
    retry: int = 0
    """Number of retries so far."""
    run_step: int = 0
    """The current step in the run."""

    def replace_with(
        self,
        retry: int | None = None,
        tool_name: str | None | _utils.Unset = _utils.UNSET,
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with `retry` and/or `tool_name` overridden.

        Args:
            retry: New retry count; `None` keeps the current value.
            tool_name: New tool name (which may itself be `None`); the `_utils.UNSET`
                sentinel keeps the current value.

        Returns:
            A new `RunContext` produced by `dataclasses.replace`.
        """
        # `None` is a legitimate tool name, hence the separate UNSET sentinel for "not given".
        overrides = {} if retry is None else {'retry': retry}
        if tool_name is not _utils.UNSET:  # pragma: no branch
            overrides['tool_name'] = tool_name
        return dataclasses.replace(self, **overrides)

    __repr__ = _utils.dataclasses_no_defaults_repr
|
pydantic_ai/_system_prompt.py
CHANGED
|
@@ -6,7 +6,8 @@ from dataclasses import dataclass, field
|
|
|
6
6
|
from typing import Any, Callable, Generic, cast
|
|
7
7
|
|
|
8
8
|
from . import _utils
|
|
9
|
-
from .
|
|
9
|
+
from ._run_context import AgentDepsT, RunContext
|
|
10
|
+
from .tools import SystemPromptFunc
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
@dataclass
|
pydantic_ai/_utils.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import functools
|
|
5
5
|
import inspect
|
|
6
|
+
import re
|
|
6
7
|
import time
|
|
7
8
|
import uuid
|
|
8
9
|
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
|
|
@@ -16,7 +17,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overlo
|
|
|
16
17
|
from anyio.to_thread import run_sync
|
|
17
18
|
from pydantic import BaseModel, TypeAdapter
|
|
18
19
|
from pydantic.json_schema import JsonSchemaValue
|
|
19
|
-
from typing_extensions import
|
|
20
|
+
from typing_extensions import (
|
|
21
|
+
ParamSpec,
|
|
22
|
+
TypeAlias,
|
|
23
|
+
TypeGuard,
|
|
24
|
+
TypeIs,
|
|
25
|
+
get_args,
|
|
26
|
+
get_origin,
|
|
27
|
+
is_typeddict,
|
|
28
|
+
)
|
|
29
|
+
from typing_inspection import typing_objects
|
|
30
|
+
from typing_inspection.introspection import is_union_origin
|
|
20
31
|
|
|
21
32
|
from pydantic_graph._utils import AbstractSpan
|
|
22
33
|
|
|
@@ -327,3 +338,102 @@ def is_async_callable(obj: Any) -> Any:
|
|
|
327
338
|
obj = obj.func
|
|
328
339
|
|
|
329
340
|
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
_JSON_SCHEMA_DEFS_REF_PREFIX = '#/$defs/'
# JSON Schema keywords whose value is a single subschema.
_SUBSCHEMA_KEYWORDS = frozenset(
    {
        'additionalProperties',
        'items',
        'additionalItems',
        'unevaluatedProperties',
        'unevaluatedItems',
        'contains',
        'propertyNames',
        'not',
        'if',
        'then',
        'else',
    }
)
# JSON Schema keywords whose value is a list of subschemas (`items` is a list in the pre-2020-12 tuple form).
_SUBSCHEMA_LIST_KEYWORDS = frozenset({'anyOf', 'oneOf', 'allOf', 'prefixItems', 'items'})
# JSON Schema keywords whose value maps names/patterns to subschemas.
_SUBSCHEMA_MAP_KEYWORDS = frozenset({'properties', 'patternProperties', 'dependentSchemas'})


def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
    """Update $refs in a schema to use the new names from name_mapping.

    Rewrites `s` in place. Every local `#/$defs/<name>` reference whose `<name>` is a key of
    `name_mapping` is pointed at the mapped name; other refs are left untouched. The walk recurses
    through all subschema-bearing keywords (`properties`, `additionalProperties`, `items`,
    `prefixItems`, `anyOf`/`oneOf`/`allOf`, `not`, ...), but not into data keywords such as
    `default`, `const` or `examples`, whose contents are instance values rather than schemas.
    """
    ref = s.get('$ref')
    if isinstance(ref, str) and ref.startswith(_JSON_SCHEMA_DEFS_REF_PREFIX):
        original_name = ref[len(_JSON_SCHEMA_DEFS_REF_PREFIX) :]
        s['$ref'] = _JSON_SCHEMA_DEFS_REF_PREFIX + name_mapping.get(original_name, original_name)

    for keyword, value in s.items():
        if keyword in _SUBSCHEMA_MAP_KEYWORDS and isinstance(value, dict):
            children = value.values()
        elif keyword in _SUBSCHEMA_LIST_KEYWORDS and isinstance(value, list):
            children = value
        elif keyword in _SUBSCHEMA_KEYWORDS and isinstance(value, dict):
            children = (value,)
        else:
            # Boolean schemas (e.g. `additionalProperties: false`) and non-schema keywords.
            continue
        for child in children:
            if isinstance(child, dict):
                _update_mapped_json_schema_refs(child, name_mapping)


def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
    """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`.

    A def whose name is already taken by an identical def is stored once. A def whose name is taken
    by a *different* def is renamed to `<title>_<name>_<n>` (or `<name>_<n>` when the schema has no
    `title`), using the smallest `n >= 1` not yet in use. Local `$ref`s are rewritten both in the
    schema itself and inside the defs it contributes, so recursive and mutually-referencing defs keep
    pointing at their own definitions after a rename.

    Schemas that carry `$defs` are deep-copied before rewriting, so the caller's input is never
    mutated; schemas without `$defs` are returned as-is (same object).

    Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`.
    """
    import copy  # local: only this function needs it

    all_defs: dict[str, dict[str, Any]] = {}
    rewritten_schemas: list[dict[str, Any]] = []

    for schema in schemas:
        if '$defs' not in schema:
            rewritten_schemas.append(schema)
            continue

        # Deep copy: ref rewriting below mutates nested dicts, which a shallow copy would share with the caller.
        schema = copy.deepcopy(schema)
        defs: dict[str, dict[str, Any]] = schema.pop('$defs') or {}
        title = schema.get('title')
        schema_name_mapping: dict[str, str] = {}
        added_names: list[str] = []

        # Process definitions and build mapping
        for name, def_schema in defs.items():
            if name not in all_defs:
                new_name = name
            elif def_schema == all_defs[name]:
                # Identical definition already present: reuse it.
                # NOTE: this match is by name + content only; if this def refers to a sibling that this
                # schema renames, the previously stored copy still points at the original sibling.
                continue
            else:
                base_name = f'{title}_{name}' if title else name
                i = 1
                new_name = f'{base_name}_{i}'
                while new_name in all_defs:
                    i += 1
                    new_name = f'{base_name}_{i}'

            all_defs[new_name] = def_schema
            schema_name_mapping[name] = new_name
            added_names.append(new_name)

        # Refs inside the contributed defs use this schema's names too (e.g. a recursive `Node` -> `Node`).
        for added_name in added_names:
            _update_mapped_json_schema_refs(all_defs[added_name], schema_name_mapping)
        _update_mapped_json_schema_refs(schema, schema_name_mapping)
        rewritten_schemas.append(schema)

    return rewritten_schemas, all_defs
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# A fenced code block (optional language tag such as ```json) whose body is a JSON object.
# Non-greedy so that, with several fenced blocks, only the first JSON object is captured rather than
# everything from the first `{` to the last `}`; `\r?\n` also accepts CRLF line endings.
_MARKDOWN_FENCED_JSON_RE = re.compile(r'```(?:\w+)?\r?\n(\{.*?\})\r?\n```', re.DOTALL)


def strip_markdown_fences(text: str) -> str:
    """Return the JSON object wrapped in a markdown code fence, or `text` unchanged.

    Text that already starts with `{` is returned as-is. Otherwise the first fenced block whose body
    is a `{...}` object is extracted; if there is none, `text` is returned unchanged.
    """
    if text.startswith('{'):
        return text

    match = _MARKDOWN_FENCED_JSON_RE.search(text)
    return match.group(1) if match else text
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple.

    Type aliases recognised by `typing_objects.is_typealiastype` are unwrapped via `__value__` first,
    including chains of aliases (an alias of an alias of a union still yields the union's arguments).
    Unwrapping stops if an alias is revisited, so a self-referential alias cannot loop forever.
    """
    seen_aliases: set[int] = set()
    while typing_objects.is_typealiastype(tp) and id(tp) not in seen_aliases:
        seen_aliases.add(id(tp))
        tp = tp.__value__

    origin = get_origin(tp)
    if is_union_origin(origin):
        return get_args(tp)
    else:
        return ()
|
pydantic_ai/agent.py
CHANGED
|
@@ -14,6 +14,7 @@ from opentelemetry.trace import NoOpTracer, use_span
|
|
|
14
14
|
from pydantic.json_schema import GenerateJsonSchema
|
|
15
15
|
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
|
|
16
16
|
|
|
17
|
+
from pydantic_ai.profiles import ModelProfile
|
|
17
18
|
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
|
|
18
19
|
from pydantic_graph._utils import get_event_loop
|
|
19
20
|
|
|
@@ -30,7 +31,8 @@ from . import (
|
|
|
30
31
|
)
|
|
31
32
|
from ._agent_graph import HistoryProcessor
|
|
32
33
|
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
|
|
33
|
-
from .
|
|
34
|
+
from .output import OutputDataT, OutputSpec
|
|
35
|
+
from .result import FinalResult, StreamedRunResult
|
|
34
36
|
from .settings import ModelSettings, merge_model_settings
|
|
35
37
|
from .tools import (
|
|
36
38
|
AgentDepsT,
|
|
@@ -91,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
91
93
|
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
|
|
92
94
|
|
|
93
95
|
Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
|
|
94
|
-
and the
|
|
96
|
+
and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT].
|
|
95
97
|
|
|
96
98
|
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
|
|
97
99
|
|
|
@@ -128,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
128
130
|
be merged with this value, with the runtime argument taking priority.
|
|
129
131
|
"""
|
|
130
132
|
|
|
131
|
-
output_type:
|
|
133
|
+
output_type: OutputSpec[OutputDataT]
|
|
132
134
|
"""
|
|
133
135
|
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
|
|
134
136
|
"""
|
|
@@ -141,7 +143,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
141
143
|
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
|
|
142
144
|
_deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
|
|
143
145
|
_deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
|
|
144
|
-
_output_schema: _output.
|
|
146
|
+
_output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
|
|
145
147
|
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
|
|
146
148
|
_instructions: str | None = dataclasses.field(repr=False)
|
|
147
149
|
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
|
|
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
163
165
|
self,
|
|
164
166
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
165
167
|
*,
|
|
166
|
-
output_type:
|
|
168
|
+
output_type: OutputSpec[OutputDataT] = str,
|
|
167
169
|
instructions: str
|
|
168
170
|
| _system_prompt.SystemPromptFunc[AgentDepsT]
|
|
169
171
|
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
|
|
@@ -324,8 +326,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
324
326
|
warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
|
|
325
327
|
output_retries = result_retries
|
|
326
328
|
|
|
329
|
+
default_output_mode = (
|
|
330
|
+
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
|
|
331
|
+
)
|
|
327
332
|
self._output_schema = _output.OutputSchema[OutputDataT].build(
|
|
328
|
-
output_type,
|
|
333
|
+
output_type,
|
|
334
|
+
default_mode=default_output_mode,
|
|
335
|
+
name=self._deprecated_result_tool_name,
|
|
336
|
+
description=self._deprecated_result_tool_description,
|
|
329
337
|
)
|
|
330
338
|
self._output_validators = []
|
|
331
339
|
|
|
@@ -382,7 +390,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
382
390
|
self,
|
|
383
391
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
384
392
|
*,
|
|
385
|
-
output_type:
|
|
393
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
386
394
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
387
395
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
388
396
|
deps: AgentDepsT = None,
|
|
@@ -412,7 +420,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
412
420
|
self,
|
|
413
421
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
414
422
|
*,
|
|
415
|
-
output_type:
|
|
423
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
416
424
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
417
425
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
418
426
|
deps: AgentDepsT = None,
|
|
@@ -500,7 +508,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
500
508
|
self,
|
|
501
509
|
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
502
510
|
*,
|
|
503
|
-
output_type:
|
|
511
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
504
512
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
505
513
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
506
514
|
deps: AgentDepsT = None,
|
|
@@ -532,7 +540,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
532
540
|
self,
|
|
533
541
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
534
542
|
*,
|
|
535
|
-
output_type:
|
|
543
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
536
544
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
537
545
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
538
546
|
deps: AgentDepsT = None,
|
|
@@ -631,7 +639,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
631
639
|
|
|
632
640
|
deps = self._get_deps(deps)
|
|
633
641
|
new_message_index = len(message_history) if message_history else 0
|
|
634
|
-
output_schema = self._prepare_output_schema(output_type)
|
|
642
|
+
output_schema = self._prepare_output_schema(output_type, model_used.profile)
|
|
635
643
|
|
|
636
644
|
output_type_ = output_type or self.output_type
|
|
637
645
|
|
|
@@ -674,14 +682,20 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
674
682
|
)
|
|
675
683
|
|
|
676
684
|
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
|
|
677
|
-
|
|
678
|
-
|
|
685
|
+
parts = [
|
|
686
|
+
self._instructions,
|
|
687
|
+
*[await func.run(run_context) for func in self._instructions_functions],
|
|
688
|
+
]
|
|
679
689
|
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
instructions.
|
|
683
|
-
|
|
684
|
-
|
|
690
|
+
model_profile = model_used.profile
|
|
691
|
+
if isinstance(output_schema, _output.PromptedOutputSchema):
|
|
692
|
+
instructions = output_schema.instructions(model_profile.prompted_output_template)
|
|
693
|
+
parts.append(instructions)
|
|
694
|
+
|
|
695
|
+
parts = [p for p in parts if p]
|
|
696
|
+
if not parts:
|
|
697
|
+
return None
|
|
698
|
+
return '\n\n'.join(parts).strip()
|
|
685
699
|
|
|
686
700
|
# Copy the function tools so that retry state is agent-run-specific
|
|
687
701
|
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
|
|
@@ -779,7 +793,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
779
793
|
self,
|
|
780
794
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
781
795
|
*,
|
|
782
|
-
output_type:
|
|
796
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
783
797
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
784
798
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
785
799
|
deps: AgentDepsT = None,
|
|
@@ -809,7 +823,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
809
823
|
self,
|
|
810
824
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
811
825
|
*,
|
|
812
|
-
output_type:
|
|
826
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
813
827
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
814
828
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
815
829
|
deps: AgentDepsT = None,
|
|
@@ -892,7 +906,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
892
906
|
self,
|
|
893
907
|
user_prompt: str | Sequence[_messages.UserContent],
|
|
894
908
|
*,
|
|
895
|
-
output_type:
|
|
909
|
+
output_type: OutputSpec[RunOutputDataT],
|
|
896
910
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
897
911
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
898
912
|
deps: AgentDepsT = None,
|
|
@@ -923,7 +937,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
923
937
|
self,
|
|
924
938
|
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
925
939
|
*,
|
|
926
|
-
output_type:
|
|
940
|
+
output_type: OutputSpec[RunOutputDataT] | None = None,
|
|
927
941
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
928
942
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
929
943
|
deps: AgentDepsT = None,
|
|
@@ -1002,10 +1016,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1002
1016
|
async for maybe_part_event in streamed_response:
|
|
1003
1017
|
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
1004
1018
|
new_part = maybe_part_event.part
|
|
1005
|
-
if isinstance(new_part, _messages.TextPart)
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1019
|
+
if isinstance(new_part, _messages.TextPart) and isinstance(
|
|
1020
|
+
output_schema, _output.TextOutputSchema
|
|
1021
|
+
):
|
|
1022
|
+
return FinalResult(s, None, None)
|
|
1023
|
+
elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
|
|
1024
|
+
output_schema, _output.ToolOutputSchema
|
|
1025
|
+
): # pragma: no branch
|
|
1009
1026
|
for call, _ in output_schema.find_tool([new_part]):
|
|
1010
1027
|
return FinalResult(s, call.tool_name, call.tool_call_id)
|
|
1011
1028
|
return None
|
|
@@ -1561,8 +1578,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1561
1578
|
if tool.name in self._function_tools:
|
|
1562
1579
|
raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
1563
1580
|
|
|
1564
|
-
if
|
|
1565
|
-
raise exceptions.UserError(f'Tool name conflicts with
|
|
1581
|
+
if tool.name in self._output_schema.tools:
|
|
1582
|
+
raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
|
|
1566
1583
|
|
|
1567
1584
|
self._function_tools[tool.name] = tool
|
|
1568
1585
|
|
|
@@ -1637,18 +1654,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1637
1654
|
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
|
|
1638
1655
|
|
|
1639
1656
|
def _prepare_output_schema(
|
|
1640
|
-
self, output_type:
|
|
1641
|
-
) -> _output.OutputSchema[RunOutputDataT]
|
|
1657
|
+
self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
|
|
1658
|
+
) -> _output.OutputSchema[RunOutputDataT]:
|
|
1642
1659
|
if output_type is not None:
|
|
1643
1660
|
if self._output_validators:
|
|
1644
1661
|
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
|
|
1645
|
-
|
|
1662
|
+
schema = _output.OutputSchema[RunOutputDataT].build(
|
|
1646
1663
|
output_type,
|
|
1647
|
-
self._deprecated_result_tool_name,
|
|
1648
|
-
self._deprecated_result_tool_description,
|
|
1664
|
+
name=self._deprecated_result_tool_name,
|
|
1665
|
+
description=self._deprecated_result_tool_description,
|
|
1666
|
+
default_mode=model_profile.default_structured_output_mode,
|
|
1649
1667
|
)
|
|
1650
1668
|
else:
|
|
1651
|
-
|
|
1669
|
+
schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
|
|
1670
|
+
|
|
1671
|
+
schema.raise_if_unsupported(model_profile)
|
|
1672
|
+
|
|
1673
|
+
return schema # pyright: ignore[reportReturnType]
|
|
1652
1674
|
|
|
1653
1675
|
@staticmethod
|
|
1654
1676
|
def is_model_request_node(
|
|
@@ -1691,14 +1713,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
|
|
1691
1713
|
return isinstance(node, End)
|
|
1692
1714
|
|
|
1693
1715
|
@asynccontextmanager
|
|
1694
|
-
async def run_mcp_servers(
|
|
1716
|
+
async def run_mcp_servers(
|
|
1717
|
+
self, model: models.Model | models.KnownModelName | str | None = None
|
|
1718
|
+
) -> AsyncIterator[None]:
|
|
1695
1719
|
"""Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
|
|
1696
1720
|
|
|
1697
1721
|
Returns: a context manager to start and shutdown the servers.
|
|
1698
1722
|
"""
|
|
1723
|
+
try:
|
|
1724
|
+
sampling_model: models.Model | None = self._get_model(model)
|
|
1725
|
+
except exceptions.UserError: # pragma: no cover
|
|
1726
|
+
sampling_model = None
|
|
1727
|
+
|
|
1699
1728
|
exit_stack = AsyncExitStack()
|
|
1700
1729
|
try:
|
|
1701
1730
|
for mcp_server in self._mcp_servers:
|
|
1731
|
+
if sampling_model is not None: # pragma: no branch
|
|
1732
|
+
mcp_server.sampling_model = sampling_model
|
|
1702
1733
|
await exit_stack.enter_async_context(mcp_server)
|
|
1703
1734
|
yield
|
|
1704
1735
|
finally:
|