pydantic_ai_slim-0.3.2-py3-none-any.whl → pydantic_ai_slim-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/_run_context.py ADDED
@@ -0,0 +1,56 @@
+from __future__ import annotations as _annotations
+
+import dataclasses
+from collections.abc import Sequence
+from dataclasses import field
+from typing import TYPE_CHECKING, Generic
+
+from typing_extensions import TypeVar
+
+from . import _utils, messages as _messages
+
+if TYPE_CHECKING:
+    from .models import Model
+    from .result import Usage
+
+AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
+"""Type variable for agent dependencies."""
+
+
+@dataclasses.dataclass(repr=False)
+class RunContext(Generic[AgentDepsT]):
+    """Information about the current call."""
+
+    deps: AgentDepsT
+    """Dependencies for the agent."""
+    model: Model
+    """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
+    prompt: str | Sequence[_messages.UserContent] | None
+    """The original user prompt passed to the run."""
+    messages: list[_messages.ModelMessage] = field(default_factory=list)
+    """Messages exchanged in the conversation so far."""
+    tool_call_id: str | None = None
+    """The ID of the tool call."""
+    tool_name: str | None = None
+    """Name of the tool being called."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""
+
+    def replace_with(
+        self,
+        retry: int | None = None,
+        tool_name: str | None | _utils.Unset = _utils.UNSET,
+    ) -> RunContext[AgentDepsT]:
+        # Create a new `RunContext` with the given `retry` and `tool_name` values.
+        kwargs = {}
+        if retry is not None:
+            kwargs['retry'] = retry
+        if tool_name is not _utils.UNSET:  # pragma: no branch
+            kwargs['tool_name'] = tool_name
+        return dataclasses.replace(self, **kwargs)
+
+    __repr__ = _utils.dataclasses_no_defaults_repr
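
The extracted `_run_context` module gives `RunContext` its own home (it previously lived in `tools.py`). A minimal sketch of `replace_with`'s copy-on-update behaviour, assuming pydantic-ai-slim 0.3.4 with `RunContext` still re-exported from `pydantic_ai.tools` (`replace_with` itself is an internal helper, not public API):

from pydantic_ai.models.test import TestModel
from pydantic_ai.tools import RunContext
from pydantic_ai.usage import Usage

ctx = RunContext(deps=None, model=TestModel(), usage=Usage(), prompt='What is the weather?')
bumped = ctx.replace_with(retry=ctx.retry + 1, tool_name='get_weather')
assert bumped.retry == 1 and bumped.tool_name == 'get_weather'
assert ctx.retry == 0  # `dataclasses.replace` returns a copy; the original context is untouched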
pydantic_ai/_system_prompt.py CHANGED
@@ -6,7 +6,8 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, cast
 
 from . import _utils
-from .tools import AgentDepsT, RunContext, SystemPromptFunc
+from ._run_context import AgentDepsT, RunContext
+from .tools import SystemPromptFunc
 
 
 @dataclass
pydantic_ai/_utils.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
 import asyncio
 import functools
 import inspect
+import re
 import time
 import uuid
 from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
@@ -16,7 +17,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
-from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
+from typing_extensions import (
+    ParamSpec,
+    TypeAlias,
+    TypeGuard,
+    TypeIs,
+    get_args,
+    get_origin,
+    is_typeddict,
+)
+from typing_inspection import typing_objects
+from typing_inspection.introspection import is_union_origin
 
 from pydantic_graph._utils import AbstractSpan
 
@@ -327,3 +338,102 @@ def is_async_callable(obj: Any) -> Any:
         obj = obj.func
 
     return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))  # type: ignore
+
+
+def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
+    """Update $refs in a schema to use the new names from name_mapping."""
+    if '$ref' in s:
+        ref = s['$ref']
+        if ref.startswith('#/$defs/'):  # pragma: no branch
+            original_name = ref[8:]  # Remove '#/$defs/'
+            new_name = name_mapping.get(original_name, original_name)
+            s['$ref'] = f'#/$defs/{new_name}'
+
+    # Recursively update refs in properties
+    if 'properties' in s:
+        props: dict[str, dict[str, Any]] = s['properties']
+        for prop in props.values():
+            _update_mapped_json_schema_refs(prop, name_mapping)
+
+    # Handle arrays
+    if 'items' in s and isinstance(s['items'], dict):
+        items: dict[str, Any] = s['items']
+        _update_mapped_json_schema_refs(items, name_mapping)
+    if 'prefixItems' in s:
+        prefix_items: list[dict[str, Any]] = s['prefixItems']
+        for item in prefix_items:
+            _update_mapped_json_schema_refs(item, name_mapping)
+
+    # Handle unions
+    for union_type in ['anyOf', 'oneOf']:
+        if union_type in s:
+            union_items: list[dict[str, Any]] = s[union_type]
+            for item in union_items:
+                _update_mapped_json_schema_refs(item, name_mapping)
+
+
+def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
+    """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`.
+
+    Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`.
+    """
+    all_defs: dict[str, dict[str, Any]] = {}
+    rewritten_schemas: list[dict[str, Any]] = []
+
+    for schema in schemas:
+        if '$defs' not in schema:
+            rewritten_schemas.append(schema)
+            continue
+
+        schema = schema.copy()
+        defs = schema.pop('$defs', None)
+        schema_name_mapping: dict[str, str] = {}
+
+        # Process definitions and build mapping
+        for name, def_schema in defs.items():
+            if name not in all_defs:
+                all_defs[name] = def_schema
+                schema_name_mapping[name] = name
+            elif def_schema != all_defs[name]:
+                new_name = name
+                if title := schema.get('title'):
+                    new_name = f'{title}_{name}'
+
+                i = 1
+                original_new_name = new_name
+                new_name = f'{new_name}_{i}'
+                while new_name in all_defs:
+                    i += 1
+                    new_name = f'{original_new_name}_{i}'
+
+                all_defs[new_name] = def_schema
+                schema_name_mapping[name] = new_name
+
+        _update_mapped_json_schema_refs(schema, schema_name_mapping)
+        rewritten_schemas.append(schema)
+
+    return rewritten_schemas, all_defs
+
+
+def strip_markdown_fences(text: str) -> str:
+    if text.startswith('{'):
+        return text
+
+    regex = r'```(?:\w+)?\n(\{.*\})\n```'
+    match = re.search(regex, text, re.DOTALL)
+    if match:
+        return match.group(1)
+
+    return text
+
+
+def get_union_args(tp: Any) -> tuple[Any, ...]:
+    """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
+    if typing_objects.is_typealiastype(tp):
+        tp = tp.__value__
+
+    origin = get_origin(tp)
+    if is_union_origin(origin):
+        return get_args(tp)
+    else:
+        return ()
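
These three helpers support the new structured-output machinery: merged `$defs` across multiple tool schemas, recovery of JSON from fenced model responses, and union introspection. A quick sketch of their behaviour, traced against the code above (they live in the private `pydantic_ai._utils` module, so the import path is an implementation detail):

from typing import Union

from pydantic_ai._utils import get_union_args, merge_json_schema_defs, strip_markdown_fences

schema_a = {
    'title': 'Foo',
    'properties': {'where': {'$ref': '#/$defs/Location'}},
    '$defs': {'Location': {'type': 'object', 'properties': {'city': {'type': 'string'}}}},
}
schema_b = {
    'title': 'Bar',
    'properties': {'where': {'$ref': '#/$defs/Location'}},
    '$defs': {'Location': {'type': 'object', 'properties': {'lat': {'type': 'number'}}}},
}

rewritten, defs = merge_json_schema_defs([schema_a, schema_b])
# The first `Location` keeps its name; the colliding definition is renamed to
# '<schema title>_<name>_<counter>' and its `$ref` is rewritten to match.
print(rewritten[1]['properties']['where']['$ref'])  # -> '#/$defs/Bar_Location_1'
print(sorted(defs))  # -> ['Bar_Location_1', 'Location']

# `strip_markdown_fences` recovers a JSON object from a fenced model response,
# and `get_union_args` unpacks union types:
print(strip_markdown_fences('```json\n{"city": "Paris"}\n```'))  # -> {"city": "Paris"}
print(get_union_args(Union[int, str]))  # -> (<class 'int'>, <class 'str'>)
print(get_union_args(int))  # -> ()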
pydantic_ai/agent.py CHANGED
@@ -14,6 +14,7 @@ from opentelemetry.trace import NoOpTracer, use_span
 from pydantic.json_schema import GenerateJsonSchema
 from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
 
+from pydantic_ai.profiles import ModelProfile
 from pydantic_graph import End, Graph, GraphRun, GraphRunContext
 from pydantic_graph._utils import get_event_loop
 
@@ -30,7 +31,8 @@ from . import (
 )
 from ._agent_graph import HistoryProcessor
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
-from .result import FinalResult, OutputDataT, StreamedRunResult
+from .output import OutputDataT, OutputSpec
+from .result import FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -91,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
 
     Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
-    and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT].
+    and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT].
 
     By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
 
@@ -128,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     be merged with this value, with the runtime argument taking priority.
     """
 
-    output_type: _output.OutputType[OutputDataT]
+    output_type: OutputSpec[OutputDataT]
     """
    The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
    """
@@ -141,7 +143,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
     _deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
     _deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
-    _output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False)
+    _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
     _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
     _instructions: str | None = dataclasses.field(repr=False)
     _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         model: models.Model | models.KnownModelName | str | None = None,
         *,
-        output_type: _output.OutputType[OutputDataT] = str,
+        output_type: OutputSpec[OutputDataT] = str,
         instructions: str
         | _system_prompt.SystemPromptFunc[AgentDepsT]
         | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
@@ -324,8 +326,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
             output_retries = result_retries
 
+        default_output_mode = (
+            self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
+        )
         self._output_schema = _output.OutputSchema[OutputDataT].build(
-            output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description
+            output_type,
+            default_mode=default_output_mode,
+            name=self._deprecated_result_tool_name,
+            description=self._deprecated_result_tool_description,
         )
         self._output_validators = []
 
@@ -382,7 +390,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT],
+        output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -412,7 +420,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT] | None = None,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -500,7 +508,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None,
         *,
-        output_type: _output.OutputType[RunOutputDataT],
+        output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -532,7 +540,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT] | None = None,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -631,7 +639,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
-        output_schema = self._prepare_output_schema(output_type)
+        output_schema = self._prepare_output_schema(output_type, model_used.profile)
 
         output_type_ = output_type or self.output_type
 
@@ -674,14 +682,20 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         )
 
         async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-            if self._instructions is None and not self._instructions_functions:
-                return None
+            parts = [
+                self._instructions,
+                *[await func.run(run_context) for func in self._instructions_functions],
+            ]
 
-            instructions = [self._instructions] if self._instructions else []
-            for instructions_runner in self._instructions_functions:
-                instructions.append(await instructions_runner.run(run_context))
-            concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
-            return concatenated_instructions.strip() if concatenated_instructions else None
+            model_profile = model_used.profile
+            if isinstance(output_schema, _output.PromptedOutputSchema):
+                instructions = output_schema.instructions(model_profile.prompted_output_template)
+                parts.append(instructions)
+
+            parts = [p for p in parts if p]
+            if not parts:
+                return None
+            return '\n\n'.join(parts).strip()
 
         # Copy the function tools so that retry state is agent-run-specific
         # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
@@ -705,6 +719,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             tracer=tracer,
             prepare_tools=self._prepare_tools,
             get_instructions=get_instructions,
+            instrumentation_settings=instrumentation_settings,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
@@ -779,7 +794,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT] | None = None,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -809,7 +824,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT] | None = None,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -892,7 +907,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent],
         *,
-        output_type: _output.OutputType[RunOutputDataT],
+        output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -923,7 +938,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
-        output_type: _output.OutputType[RunOutputDataT] | None = None,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
@@ -1002,10 +1017,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             async for maybe_part_event in streamed_response:
                 if isinstance(maybe_part_event, _messages.PartStartEvent):
                     new_part = maybe_part_event.part
-                    if isinstance(new_part, _messages.TextPart):
-                        if _output.allow_text_output(output_schema):
-                            return FinalResult(s, None, None)
-                    elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
+                    if isinstance(new_part, _messages.TextPart) and isinstance(
+                        output_schema, _output.TextOutputSchema
+                    ):
+                        return FinalResult(s, None, None)
+                    elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
+                        output_schema, _output.ToolOutputSchema
+                    ):  # pragma: no branch
                         for call, _ in output_schema.find_tool([new_part]):
                             return FinalResult(s, call.tool_name, call.tool_call_id)
             return None
@@ -1561,8 +1579,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         if tool.name in self._function_tools:
             raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
 
-        if self._output_schema and tool.name in self._output_schema.tools:
-            raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
+        if tool.name in self._output_schema.tools:
+            raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
 
         self._function_tools[tool.name] = tool
 
@@ -1637,18 +1655,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
 
     def _prepare_output_schema(
-        self, output_type: _output.OutputType[RunOutputDataT] | None
-    ) -> _output.OutputSchema[RunOutputDataT] | None:
+        self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
+    ) -> _output.OutputSchema[RunOutputDataT]:
         if output_type is not None:
             if self._output_validators:
                 raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
-            return _output.OutputSchema[RunOutputDataT].build(
+            schema = _output.OutputSchema[RunOutputDataT].build(
                 output_type,
-                self._deprecated_result_tool_name,
-                self._deprecated_result_tool_description,
+                name=self._deprecated_result_tool_name,
+                description=self._deprecated_result_tool_description,
+                default_mode=model_profile.default_structured_output_mode,
             )
         else:
-            return self._output_schema  # pyright: ignore[reportReturnType]
+            schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
+
+        schema.raise_if_unsupported(model_profile)
+
+        return schema  # pyright: ignore[reportReturnType]
 
     @staticmethod
     def is_model_request_node(
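
Taken together, the agent.py changes replace the private `_output.OutputType` with the public `OutputSpec` and thread the model profile into output-schema construction, so the default structured-output mode comes from the model rather than always being tool calls. A runnable sketch of the unchanged public surface (assuming pydantic-ai-slim 0.3.4; `TestModel` fabricates data that conforms to the schema):

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

class CityLocation(BaseModel):
    city: str
    country: str

agent = Agent(TestModel(), output_type=CityLocation)
result = agent.run_sync('Where were the 2012 Olympics held?')
print(result.output)  # TestModel generates stub values matching the schema, e.g. city='a' country='a'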
pydantic_ai/messages.py CHANGED
@@ -76,8 +76,11 @@ class SystemPromptPart:
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
-    def otel_event(self, _settings: InstrumentationSettings) -> Event:
-        return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
+    def otel_event(self, settings: InstrumentationSettings) -> Event:
+        return Event(
+            'gen_ai.system.message',
+            body={'role': 'system', **({'content': self.content} if settings.include_content else {})},
+        )
 
     __repr__ = _utils.dataclasses_no_defaults_repr
 
@@ -362,12 +365,12 @@ class UserPromptPart:
             content = []
             for part in self.content:
                 if isinstance(part, str):
-                    content.append(part)
+                    content.append(part if settings.include_content else {'kind': 'text'})
                 elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
-                    content.append({'kind': part.kind, 'url': part.url})
+                    content.append({'kind': part.kind, **({'url': part.url} if settings.include_content else {})})
                 elif isinstance(part, BinaryContent):
                     converted_part = {'kind': part.kind, 'media_type': part.media_type}
-                    if settings.include_binary_content:
+                    if settings.include_content and settings.include_binary_content:
                         converted_part['binary_content'] = base64.b64encode(part.data).decode()
                     content.append(converted_part)
                 else:
@@ -414,10 +417,15 @@ class ToolReturnPart:
         else:
             return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
-    def otel_event(self, _settings: InstrumentationSettings) -> Event:
+    def otel_event(self, settings: InstrumentationSettings) -> Event:
         return Event(
             'gen_ai.tool.message',
-            body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name},
+            body={
+                **({'content': self.content} if settings.include_content else {}),
+                'role': 'tool',
+                'id': self.tool_call_id,
+                'name': self.tool_name,
+            },
         )
 
     __repr__ = _utils.dataclasses_no_defaults_repr
@@ -473,14 +481,14 @@ class RetryPromptPart:
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
-    def otel_event(self, _settings: InstrumentationSettings) -> Event:
+    def otel_event(self, settings: InstrumentationSettings) -> Event:
         if self.tool_name is None:
             return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
         else:
             return Event(
                 'gen_ai.tool.message',
                 body={
-                    'content': self.model_response(),
+                    **({'content': self.model_response()} if settings.include_content else {}),
                     'role': 'tool',
                     'id': self.tool_call_id,
                     'name': self.tool_name,
@@ -657,7 +665,7 @@ class ModelResponse:
     vendor_id: str | None = None
     """Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
 
-    def otel_events(self) -> list[Event]:
+    def otel_events(self, settings: InstrumentationSettings) -> list[Event]:
         """Return OpenTelemetry events for the response."""
         result: list[Event] = []
 
@@ -683,7 +691,8 @@ class ModelResponse:
             elif isinstance(part, TextPart):
                 if body.get('content'):
                     body = new_event_body()
-                body['content'] = part.content
+                if settings.include_content:
+                    body['content'] = part.content
 
         return result
 
pydantic_ai/models/__init__.py CHANGED
@@ -20,9 +20,12 @@ from typing_extensions import Literal, TypeAliasType, TypedDict
 
 from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
 
+from .. import _utils
+from .._output import OutputObjectDefinition
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
+from ..output import OutputMode
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -300,13 +303,18 @@ KnownModelName = TypeAliasType(
 """
 
 
-@dataclass
+@dataclass(repr=False)
 class ModelRequestParameters:
     """Configuration for an agent's request to a model, specifically related to tools and output handling."""
 
     function_tools: list[ToolDefinition] = field(default_factory=list)
-    allow_text_output: bool = True
+
+    output_mode: OutputMode = 'text'
+    output_object: OutputObjectDefinition | None = None
     output_tools: list[ToolDefinition] = field(default_factory=list)
+    allow_text_output: bool = True
+
+    __repr__ = _utils.dataclasses_no_defaults_repr
 
 
 class Model(ABC):
@@ -351,6 +359,11 @@ class Model(ABC):
             function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
             output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
         )
+        if output_object := model_request_parameters.output_object:
+            model_request_parameters = replace(
+                model_request_parameters,
+                output_object=_customize_output_object(transformer, output_object),
+            )
 
         return model_request_parameters
 
@@ -718,3 +731,9 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
     if t.strict is None:
         t = replace(t, strict=schema_transformer.is_strict_compatible)
     return replace(t, parameters_json_schema=parameters_json_schema)
+
+
+def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
+    schema_transformer = transformer(o.json_schema, strict=True)
+    json_schema = schema_transformer.walk()
+    return replace(o, json_schema=json_schema)
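
A small sketch of the new `ModelRequestParameters` fields and the effect of combining `repr=False` with `dataclasses_no_defaults_repr` (assuming 'tool' is a valid `OutputMode` value in this release):

from pydantic_ai.models import ModelRequestParameters

params = ModelRequestParameters()
print(repr(params))  # -> ModelRequestParameters(): fields left at their defaults are omitted

params = ModelRequestParameters(output_mode='tool', allow_text_output=False)
print(repr(params))  # only the two non-default fields appear in the repr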
pydantic_ai/models/anthropic.py CHANGED
@@ -342,15 +342,13 @@ class AnthropicModel(Model):
                     if response_part.content:  # Only add non-empty text
                         assistant_content_params.append(BetaTextBlockParam(text=response_part.content, type='text'))
                 elif isinstance(response_part, ThinkingPart):
-                    # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
-                    # please open an issue. The below code is the code to send thinking to the provider.
-                    # assert response_part.signature is not None, 'Thinking part must have a signature'
-                    # assistant_content_params.append(
-                    #     BetaThinkingBlockParam(
-                    #         thinking=response_part.content, signature=response_part.signature, type='thinking'
-                    #     )
-                    # )
-                    pass
+                    # NOTE: We only send `ThinkingPart`s back to Anthropic, which otherwise raises an error.
+                    if response_part.signature is not None:  # pragma: no branch
+                        assistant_content_params.append(
+                            BetaThinkingBlockParam(
+                                thinking=response_part.content, signature=response_part.signature, type='thinking'
+                            )
+                        )
                 else:
                     tool_use_block_param = BetaToolUseBlockParam(
                         id=_guard_tool_call_id(t=response_part),
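
The effect of this change: a `ThinkingPart` carrying a signature is now replayed to the Anthropic API as a thinking block instead of being silently dropped. A sketch of the mapping (the signature value is hypothetical; `BetaThinkingBlockParam` is a TypedDict from the `anthropic` SDK, so calling it just builds a dict):

from anthropic.types.beta import BetaThinkingBlockParam

from pydantic_ai.messages import ThinkingPart

part = ThinkingPart(content='Let me reason step by step...', signature='sig_abc')
if part.signature is not None:
    block = BetaThinkingBlockParam(thinking=part.content, signature=part.signature, type='thinking')
    print(block)  # -> {'thinking': 'Let me reason step by step...', 'signature': 'sig_abc', 'type': 'thinking'}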
pydantic_ai/models/function.py CHANGED
@@ -11,6 +11,8 @@ from typing import Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
+from pydantic_ai.profiles import ModelProfileSpec
+
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
@@ -49,14 +51,27 @@ class FunctionModel(Model):
     _system: str = field(default='function', repr=False)
 
     @overload
-    def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
+    def __init__(
+        self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
+    ) -> None: ...
 
     @overload
-    def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
+    def __init__(
+        self,
+        *,
+        stream_function: StreamFunctionDef,
+        model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
+    ) -> None: ...
 
     @overload
     def __init__(
-        self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
+        self,
+        function: FunctionDef,
+        *,
+        stream_function: StreamFunctionDef,
+        model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
     ) -> None: ...
 
     def __init__(
@@ -65,6 +80,7 @@ class FunctionModel(Model):
         *,
         stream_function: StreamFunctionDef | None = None,
         model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
     ):
         """Initialize a `FunctionModel`.
 
@@ -74,6 +90,7 @@ class FunctionModel(Model):
             function: The function to call for non-streamed requests.
             stream_function: The function to call for streamed requests.
            model_name: The name of the model. If not provided, a name is generated from the function names.
+            profile: The model profile to use.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
@@ -83,6 +100,7 @@ class FunctionModel(Model):
         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
+        self._profile = profile
 
     async def request(
         self,
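
With the new `profile` parameter, a `FunctionModel` used in tests can carry a model profile, e.g. to exercise profile-dependent defaults such as `default_structured_output_mode` referenced by the agent.py hunks above. A sketch (assuming 'prompted' is one of the supported modes in this release):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.profiles import ModelProfile

def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A trivial model function that ignores the conversation and replies with fixed text.
    return ModelResponse(parts=[TextPart('hello')])

model = FunctionModel(echo, profile=ModelProfile(default_structured_output_mode='prompted'))
agent = Agent(model)
print(agent.run_sync('hi').output)  # -> hello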