pydantic-ai-slim 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/_run_context.py ADDED
@@ -0,0 +1,56 @@
+ from __future__ import annotations as _annotations
+
+ import dataclasses
+ from collections.abc import Sequence
+ from dataclasses import field
+ from typing import TYPE_CHECKING, Generic
+
+ from typing_extensions import TypeVar
+
+ from . import _utils, messages as _messages
+
+ if TYPE_CHECKING:
+     from .models import Model
+     from .result import Usage
+
+ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
+ """Type variable for agent dependencies."""
+
+
+ @dataclasses.dataclass(repr=False)
+ class RunContext(Generic[AgentDepsT]):
+     """Information about the current call."""
+
+     deps: AgentDepsT
+     """Dependencies for the agent."""
+     model: Model
+     """The model used in this run."""
+     usage: Usage
+     """LLM usage associated with the run."""
+     prompt: str | Sequence[_messages.UserContent] | None
+     """The original user prompt passed to the run."""
+     messages: list[_messages.ModelMessage] = field(default_factory=list)
+     """Messages exchanged in the conversation so far."""
+     tool_call_id: str | None = None
+     """The ID of the tool call."""
+     tool_name: str | None = None
+     """Name of the tool being called."""
+     retry: int = 0
+     """Number of retries so far."""
+     run_step: int = 0
+     """The current step in the run."""
+
+     def replace_with(
+         self,
+         retry: int | None = None,
+         tool_name: str | None | _utils.Unset = _utils.UNSET,
+     ) -> RunContext[AgentDepsT]:
+         # Create a new `RunContext` with a new `retry` value and `tool_name`.
+         kwargs = {}
+         if retry is not None:
+             kwargs['retry'] = retry
+         if tool_name is not _utils.UNSET:  # pragma: no branch
+             kwargs['tool_name'] = tool_name
+         return dataclasses.replace(self, **kwargs)
+
+     __repr__ = _utils.dataclasses_no_defaults_repr
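
For orientation, here is a minimal sketch of how this new context object surfaces in user code: a tool registered on an Agent receives the RunContext as its first argument and can read fields such as `deps`, `retry` and `tool_name`. The `WeatherDeps` class and the `'test'` model are illustrative placeholders, not part of this diff.

from dataclasses import dataclass

from pydantic_ai import Agent, RunContext


@dataclass
class WeatherDeps:
    api_key: str


agent = Agent('test', deps_type=WeatherDeps)


@agent.tool
def get_temperature(ctx: RunContext[WeatherDeps], city: str) -> str:
    # ctx.deps carries the per-run dependencies; ctx.retry and ctx.tool_name
    # are fields defined on the dataclass added in this file.
    return f'Looking up {city} with key {ctx.deps.api_key} (attempt {ctx.retry})'


result = agent.run_sync('What is the temperature in Paris?', deps=WeatherDeps(api_key='dummy'))
print(result.output)
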
@@ -6,7 +6,8 @@ from dataclasses import dataclass, field
  from typing import Any, Callable, Generic, cast
 
  from . import _utils
- from .tools import AgentDepsT, RunContext, SystemPromptFunc
+ from ._run_context import AgentDepsT, RunContext
+ from .tools import SystemPromptFunc
 
 
  @dataclass
pydantic_ai/_utils.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
  import asyncio
  import functools
  import inspect
+ import re
  import time
  import uuid
  from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
@@ -16,7 +17,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
  from anyio.to_thread import run_sync
  from pydantic import BaseModel, TypeAdapter
  from pydantic.json_schema import JsonSchemaValue
- from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
+ from typing_extensions import (
+     ParamSpec,
+     TypeAlias,
+     TypeGuard,
+     TypeIs,
+     get_args,
+     get_origin,
+     is_typeddict,
+ )
+ from typing_inspection import typing_objects
+ from typing_inspection.introspection import is_union_origin
 
  from pydantic_graph._utils import AbstractSpan
 
@@ -327,3 +338,102 @@ def is_async_callable(obj: Any) -> Any:
          obj = obj.func
 
      return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))  # type: ignore
+
+
+ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
+     """Update $refs in a schema to use the new names from name_mapping."""
+     if '$ref' in s:
+         ref = s['$ref']
+         if ref.startswith('#/$defs/'):  # pragma: no branch
+             original_name = ref[8:]  # Remove '#/$defs/'
+             new_name = name_mapping.get(original_name, original_name)
+             s['$ref'] = f'#/$defs/{new_name}'
+
+     # Recursively update refs in properties
+     if 'properties' in s:
+         props: dict[str, dict[str, Any]] = s['properties']
+         for prop in props.values():
+             _update_mapped_json_schema_refs(prop, name_mapping)
+
+     # Handle arrays
+     if 'items' in s and isinstance(s['items'], dict):
+         items: dict[str, Any] = s['items']
+         _update_mapped_json_schema_refs(items, name_mapping)
+     if 'prefixItems' in s:
+         prefix_items: list[dict[str, Any]] = s['prefixItems']
+         for item in prefix_items:
+             _update_mapped_json_schema_refs(item, name_mapping)
+
+     # Handle unions
+     for union_type in ['anyOf', 'oneOf']:
+         if union_type in s:
+             union_items: list[dict[str, Any]] = s[union_type]
+             for item in union_items:
+                 _update_mapped_json_schema_refs(item, name_mapping)
+
+
+ def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
+     """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`.
+
+     Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`.
+     """
+     all_defs: dict[str, dict[str, Any]] = {}
+     rewritten_schemas: list[dict[str, Any]] = []
+
+     for schema in schemas:
+         if '$defs' not in schema:
+             rewritten_schemas.append(schema)
+             continue
+
+         schema = schema.copy()
+         defs = schema.pop('$defs', None)
+         schema_name_mapping: dict[str, str] = {}
+
+         # Process definitions and build mapping
+         for name, def_schema in defs.items():
+             if name not in all_defs:
+                 all_defs[name] = def_schema
+                 schema_name_mapping[name] = name
+             elif def_schema != all_defs[name]:
+                 new_name = name
+                 if title := schema.get('title'):
+                     new_name = f'{title}_{name}'
+
+                 i = 1
+                 original_new_name = new_name
+                 new_name = f'{new_name}_{i}'
+                 while new_name in all_defs:
+                     i += 1
+                     new_name = f'{original_new_name}_{i}'
+
+                 all_defs[new_name] = def_schema
+                 schema_name_mapping[name] = new_name
+
+         _update_mapped_json_schema_refs(schema, schema_name_mapping)
+         rewritten_schemas.append(schema)
+
+     return rewritten_schemas, all_defs
+
+
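
A small worked example of how the new helper is intended to behave, assuming pydantic-ai-slim 0.3.3 is installed. The function lives in a private module, so this import path is illustrative and not a stable API.

from pydantic_ai._utils import merge_json_schema_defs

# Two tool schemas that both define a `$defs` entry named 'Point', but with
# different bodies, so the second one must be renamed to avoid a collision.
schema_a = {
    'title': 'ToolA',
    'properties': {'p': {'$ref': '#/$defs/Point'}},
    '$defs': {'Point': {'type': 'object', 'properties': {'x': {'type': 'integer'}}}},
}
schema_b = {
    'title': 'ToolB',
    'properties': {'p': {'$ref': '#/$defs/Point'}},
    '$defs': {'Point': {'type': 'object', 'properties': {'x': {'type': 'number'}}}},
}

rewritten, defs = merge_json_schema_defs([schema_a, schema_b])
print(sorted(defs))                     # ['Point', 'ToolB_Point_1']
print(rewritten[1]['properties']['p'])  # {'$ref': '#/$defs/ToolB_Point_1'}
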
+ def strip_markdown_fences(text: str) -> str:
+     if text.startswith('{'):
+         return text
+
+     regex = r'```(?:\w+)?\n(\{.*\})\n```'
+     match = re.search(regex, text, re.DOTALL)
+     if match:
+         return match.group(1)
+
+     return text
+
+
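
For illustration, using the same private module path: a fenced JSON block inside prose is extracted, while text that already starts with `{` is returned unchanged.

from pydantic_ai._utils import strip_markdown_fences

raw = 'Here is the data:\n```json\n{"city": "Paris", "temp_c": 21}\n```\nHope that helps!'
print(strip_markdown_fences(raw))  # {"city": "Paris", "temp_c": 21}

# Plain JSON (or anything without a fenced {...} block) passes through unchanged.
print(strip_markdown_fences('{"already": "plain"}'))
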
+ def get_union_args(tp: Any) -> tuple[Any, ...]:
+     """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
+     if typing_objects.is_typealiastype(tp):
+         tp = tp.__value__
+
+     origin = get_origin(tp)
+     if is_union_origin(origin):
+         return get_args(tp)
+     else:
+         return ()
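
Roughly how the new union helper behaves, again via the private module path:

from typing import Optional, Union

from pydantic_ai._utils import get_union_args

print(get_union_args(Union[int, str]))  # (<class 'int'>, <class 'str'>)
print(get_union_args(Optional[int]))    # (<class 'int'>, <class 'NoneType'>)
print(get_union_args(int))              # ()
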
pydantic_ai/agent.py CHANGED
@@ -14,6 +14,7 @@ from opentelemetry.trace import NoOpTracer, use_span
  from pydantic.json_schema import GenerateJsonSchema
  from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
 
+ from pydantic_ai.profiles import ModelProfile
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
  from pydantic_graph._utils import get_event_loop
 
@@ -30,7 +31,8 @@ from . import (
  )
  from ._agent_graph import HistoryProcessor
  from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
- from .result import FinalResult, OutputDataT, StreamedRunResult
+ from .output import OutputDataT, OutputSpec
+ from .result import FinalResult, StreamedRunResult
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
      AgentDepsT,
@@ -91,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
      """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
 
      Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
-     and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT].
+     and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT].
 
      By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
 
@@ -128,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
      be merged with this value, with the runtime argument taking priority.
      """
 
-     output_type: _output.OutputType[OutputDataT]
+     output_type: OutputSpec[OutputDataT]
      """
      The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
      """
@@ -141,7 +143,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
      _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
      _deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
      _deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
-     _output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False)
+     _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
      _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
      _instructions: str | None = dataclasses.field(repr=False)
      _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          model: models.Model | models.KnownModelName | str | None = None,
          *,
-         output_type: _output.OutputType[OutputDataT] = str,
+         output_type: OutputSpec[OutputDataT] = str,
          instructions: str
          | _system_prompt.SystemPromptFunc[AgentDepsT]
          | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
@@ -324,8 +326,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
              warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
              output_retries = result_retries
 
+         default_output_mode = (
+             self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
+         )
          self._output_schema = _output.OutputSchema[OutputDataT].build(
-             output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description
+             output_type,
+             default_mode=default_output_mode,
+             name=self._deprecated_result_tool_name,
+             description=self._deprecated_result_tool_description,
          )
          self._output_validators = []
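
In user-facing terms, the `default_output_mode` wiring above means an Agent built with a structured `output_type` now resolves its structured-output mode from the model profile's `default_structured_output_mode`. A hedged sketch of ordinary usage; the model name and output fields are illustrative and a real API key would be required.

from pydantic import BaseModel

from pydantic_ai import Agent


class CityInfo(BaseModel):
    city: str
    country: str


# output_type is now interpreted as an OutputSpec; how the structured output is
# obtained (e.g. via an output tool) defaults to the model profile's setting.
agent = Agent('openai:gpt-4o', output_type=CityInfo)
result = agent.run_sync('Tell me about Paris.')
print(result.output)  # e.g. CityInfo(city='Paris', country='France')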
 
@@ -382,7 +390,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT],
+         output_type: OutputSpec[RunOutputDataT],
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -412,7 +420,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT] | None = None,
+         output_type: OutputSpec[RunOutputDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -500,7 +508,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None,
          *,
-         output_type: _output.OutputType[RunOutputDataT],
+         output_type: OutputSpec[RunOutputDataT],
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -532,7 +540,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT] | None = None,
+         output_type: OutputSpec[RunOutputDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -631,7 +639,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
          deps = self._get_deps(deps)
          new_message_index = len(message_history) if message_history else 0
-         output_schema = self._prepare_output_schema(output_type)
+         output_schema = self._prepare_output_schema(output_type, model_used.profile)
 
          output_type_ = output_type or self.output_type
 
@@ -674,14 +682,20 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          )
 
          async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-             if self._instructions is None and not self._instructions_functions:
-                 return None
+             parts = [
+                 self._instructions,
+                 *[await func.run(run_context) for func in self._instructions_functions],
+             ]
 
-             instructions = [self._instructions] if self._instructions else []
-             for instructions_runner in self._instructions_functions:
-                 instructions.append(await instructions_runner.run(run_context))
-             concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
-             return concatenated_instructions.strip() if concatenated_instructions else None
+             model_profile = model_used.profile
+             if isinstance(output_schema, _output.PromptedOutputSchema):
+                 instructions = output_schema.instructions(model_profile.prompted_output_template)
+                 parts.append(instructions)
+
+             parts = [p for p in parts if p]
+             if not parts:
+                 return None
+             return '\n\n'.join(parts).strip()
 
          # Copy the function tools so that retry state is agent-run-specific
          # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
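
The net effect of the rewritten `get_instructions` is that static instructions, dynamic instruction functions and, for prompted output schemas, the schema's own instructions are filtered for emptiness and joined with blank lines. A standalone sketch of just that joining logic:

static_instructions = 'You are a terse assistant.'
dynamic_results = ['Answer in French.', None]  # e.g. values returned by instruction functions
prompted_output_instructions = ''              # only non-empty for prompted output schemas

parts = [static_instructions, *dynamic_results, prompted_output_instructions]
parts = [p for p in parts if p]
print('\n\n'.join(parts).strip() if parts else None)
# You are a terse assistant.
#
# Answer in French.
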
@@ -779,7 +793,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT] | None = None,
+         output_type: OutputSpec[RunOutputDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -809,7 +823,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT] | None = None,
+         output_type: OutputSpec[RunOutputDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -892,7 +906,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent],
          *,
-         output_type: _output.OutputType[RunOutputDataT],
+         output_type: OutputSpec[RunOutputDataT],
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -923,7 +937,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          self,
          user_prompt: str | Sequence[_messages.UserContent] | None = None,
          *,
-         output_type: _output.OutputType[RunOutputDataT] | None = None,
+         output_type: OutputSpec[RunOutputDataT] | None = None,
          message_history: list[_messages.ModelMessage] | None = None,
          model: models.Model | models.KnownModelName | str | None = None,
          deps: AgentDepsT = None,
@@ -1002,10 +1016,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
              async for maybe_part_event in streamed_response:
                  if isinstance(maybe_part_event, _messages.PartStartEvent):
                      new_part = maybe_part_event.part
-                     if isinstance(new_part, _messages.TextPart):
-                         if _output.allow_text_output(output_schema):
-                             return FinalResult(s, None, None)
-                     elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
+                     if isinstance(new_part, _messages.TextPart) and isinstance(
+                         output_schema, _output.TextOutputSchema
+                     ):
+                         return FinalResult(s, None, None)
+                     elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
+                         output_schema, _output.ToolOutputSchema
+                     ):  # pragma: no branch
                          for call, _ in output_schema.find_tool([new_part]):
                              return FinalResult(s, call.tool_name, call.tool_call_id)
              return None
@@ -1561,8 +1578,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          if tool.name in self._function_tools:
              raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
 
-         if self._output_schema and tool.name in self._output_schema.tools:
-             raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
+         if tool.name in self._output_schema.tools:
+             raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
 
          self._function_tools[tool.name] = tool
 
@@ -1637,18 +1654,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
 
      def _prepare_output_schema(
-         self, output_type: _output.OutputType[RunOutputDataT] | None
-     ) -> _output.OutputSchema[RunOutputDataT] | None:
+         self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
+     ) -> _output.OutputSchema[RunOutputDataT]:
          if output_type is not None:
              if self._output_validators:
                  raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
-             return _output.OutputSchema[RunOutputDataT].build(
+             schema = _output.OutputSchema[RunOutputDataT].build(
                  output_type,
-                 self._deprecated_result_tool_name,
-                 self._deprecated_result_tool_description,
+                 name=self._deprecated_result_tool_name,
+                 description=self._deprecated_result_tool_description,
+                 default_mode=model_profile.default_structured_output_mode,
              )
          else:
-             return self._output_schema  # pyright: ignore[reportReturnType]
+             schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
+
+         schema.raise_if_unsupported(model_profile)
+
+         return schema  # pyright: ignore[reportReturnType]
 
      @staticmethod
      def is_model_request_node(
@@ -1691,14 +1713,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
          return isinstance(node, End)
 
      @asynccontextmanager
-     async def run_mcp_servers(self) -> AsyncIterator[None]:
+     async def run_mcp_servers(
+         self, model: models.Model | models.KnownModelName | str | None = None
+     ) -> AsyncIterator[None]:
          """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
 
          Returns: a context manager to start and shutdown the servers.
          """
+         try:
+             sampling_model: models.Model | None = self._get_model(model)
+         except exceptions.UserError:  # pragma: no cover
+             sampling_model = None
+
          exit_stack = AsyncExitStack()
          try:
              for mcp_server in self._mcp_servers:
+                 if sampling_model is not None:  # pragma: no branch
+                     mcp_server.sampling_model = sampling_model
                  await exit_stack.enter_async_context(mcp_server)
              yield
          finally:
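
The new optional `model` argument lets `run_mcp_servers` set a sampling model on each MCP server before it is started. A usage sketch; the server command and model names are illustrative placeholders, and a real MCP server plus API key would be required.

import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# Hypothetical stdio MCP server; replace with a real command for your setup.
server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])
agent = Agent('openai:gpt-4o', mcp_servers=[server])


async def main() -> None:
    # Passing a model here assigns it as the sampling model on each MCP server
    # before the server is entered, per the change above.
    async with agent.run_mcp_servers('openai:gpt-4o-mini'):
        result = await agent.run('Use the MCP tools to answer: ...')
        print(result.output)


asyncio.run(main())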