pydantic-ai-slim 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -0,0 +1,56 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Sequence
5
+ from dataclasses import field
6
+ from typing import TYPE_CHECKING, Generic
7
+
8
+ from typing_extensions import TypeVar
9
+
10
+ from . import _utils, messages as _messages
11
+
12
+ if TYPE_CHECKING:
13
+ from .models import Model
14
+ from .result import Usage
15
+
16
+ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
17
+ """Type variable for agent dependencies."""
18
+
19
+
20
@dataclasses.dataclass(repr=False)
class RunContext(Generic[AgentDepsT]):
    """Information about the call currently being handled."""

    deps: AgentDepsT
    """The dependencies supplied to the agent."""
    model: Model
    """The model this run is using."""
    usage: Usage
    """The LLM usage recorded for the run."""
    prompt: str | Sequence[_messages.UserContent] | None
    """The user prompt the run was originally started with."""
    messages: list[_messages.ModelMessage] = field(default_factory=list)
    """The messages exchanged in the conversation up to this point."""
    tool_call_id: str | None = None
    """ID of the tool call."""
    tool_name: str | None = None
    """The name of the tool that is being called."""
    retry: int = 0
    """How many retries have happened so far."""
    run_step: int = 0
    """The step the run is currently on."""

    def replace_with(
        self,
        retry: int | None = None,
        tool_name: str | None | _utils.Unset = _utils.UNSET,
    ) -> RunContext[AgentDepsT]:
        """Return a copy of this context with a new `retry` value and/or `tool_name`.

        Args:
            retry: The new retry count; `None` keeps the current value.
            tool_name: The new tool name; `_utils.UNSET` keeps the current value, so that an
                explicit `None` can still be used to clear the name.

        Returns:
            A new `RunContext` built with `dataclasses.replace`; `self` is not modified.
        """
        changes = {}
        if tool_name is not _utils.UNSET:  # pragma: no branch
            changes['tool_name'] = tool_name
        if retry is not None:
            changes['retry'] = retry
        return dataclasses.replace(self, **changes)

    __repr__ = _utils.dataclasses_no_defaults_repr
@@ -6,7 +6,8 @@ from dataclasses import dataclass, field
6
6
  from typing import Any, Callable, Generic, cast
7
7
 
8
8
  from . import _utils
9
- from .tools import AgentDepsT, RunContext, SystemPromptFunc
9
+ from ._run_context import AgentDepsT, RunContext
10
+ from .tools import SystemPromptFunc
10
11
 
11
12
 
12
13
  @dataclass
pydantic_ai/_utils.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations as _annotations
3
3
  import asyncio
4
4
  import functools
5
5
  import inspect
6
+ import re
6
7
  import time
7
8
  import uuid
8
9
  from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
@@ -16,7 +17,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overlo
16
17
  from anyio.to_thread import run_sync
17
18
  from pydantic import BaseModel, TypeAdapter
18
19
  from pydantic.json_schema import JsonSchemaValue
19
- from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
20
+ from typing_extensions import (
21
+ ParamSpec,
22
+ TypeAlias,
23
+ TypeGuard,
24
+ TypeIs,
25
+ get_args,
26
+ get_origin,
27
+ is_typeddict,
28
+ )
29
+ from typing_inspection import typing_objects
30
+ from typing_inspection.introspection import is_union_origin
20
31
 
21
32
  from pydantic_graph._utils import AbstractSpan
22
33
 
@@ -327,3 +338,102 @@ def is_async_callable(obj: Any) -> Any:
327
338
  obj = obj.func
328
339
 
329
340
  return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
341
+
342
+
343
+ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
344
+ """Update $refs in a schema to use the new names from name_mapping."""
345
+ if '$ref' in s:
346
+ ref = s['$ref']
347
+ if ref.startswith('#/$defs/'): # pragma: no branch
348
+ original_name = ref[8:] # Remove '#/$defs/'
349
+ new_name = name_mapping.get(original_name, original_name)
350
+ s['$ref'] = f'#/$defs/{new_name}'
351
+
352
+ # Recursively update refs in properties
353
+ if 'properties' in s:
354
+ props: dict[str, dict[str, Any]] = s['properties']
355
+ for prop in props.values():
356
+ _update_mapped_json_schema_refs(prop, name_mapping)
357
+
358
+ # Handle arrays
359
+ if 'items' in s and isinstance(s['items'], dict):
360
+ items: dict[str, Any] = s['items']
361
+ _update_mapped_json_schema_refs(items, name_mapping)
362
+ if 'prefixItems' in s:
363
+ prefix_items: list[dict[str, Any]] = s['prefixItems']
364
+ for item in prefix_items:
365
+ _update_mapped_json_schema_refs(item, name_mapping)
366
+
367
+ # Handle unions
368
+ for union_type in ['anyOf', 'oneOf']:
369
+ if union_type in s:
370
+ union_items: list[dict[str, Any]] = s[union_type]
371
+ for item in union_items:
372
+ _update_mapped_json_schema_refs(item, name_mapping)
373
+
374
+
375
def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]:
    """Merge the `$defs` from different JSON schemas into a single deduplicated `$defs`.

    Definitions that share a name but differ in content are stored under a new, unique name
    (`<schema title>_<name>_<n>`, or `<name>_<n>` when the schema has no title), and the `$ref`s
    in the owning schema are rewritten to point to the new name. Definitions that share both
    name and content are stored once.

    Args:
        schemas: The JSON schemas to merge. They are not modified.

    Returns:
        A tuple of the rewritten schemas (with `$defs` removed) and a dictionary of the merged
        `$defs`. Schemas without `$defs` are returned as the same object that was passed in.
    """
    from copy import deepcopy

    all_defs: dict[str, dict[str, Any]] = {}
    rewritten_schemas: list[dict[str, Any]] = []

    for schema in schemas:
        if '$defs' not in schema:
            rewritten_schemas.append(schema)
            continue

        # The `$ref` rewrite below mutates nested subschemas in place, so a shallow copy would
        # leak those changes into the caller's schema; take a deep copy instead.
        schema = deepcopy(schema)
        defs: dict[str, dict[str, Any]] = schema.pop('$defs')
        schema_name_mapping: dict[str, str] = {}

        # Process definitions and build the mapping from original to merged names
        for name, def_schema in defs.items():
            if name not in all_defs:
                all_defs[name] = def_schema
                schema_name_mapping[name] = name
            elif def_schema != all_defs[name]:
                # Name collision with a different definition: pick the first free suffixed name
                title = schema.get('title')
                base_name = f'{title}_{name}' if title else name
                suffix = 1
                while (new_name := f'{base_name}_{suffix}') in all_defs:
                    suffix += 1

                all_defs[new_name] = def_schema
                schema_name_mapping[name] = new_name

        # NOTE: only refs in the schema body are rewritten; `$ref`s nested inside the
        # definitions themselves are left untouched.
        _update_mapped_json_schema_refs(schema, schema_name_mapping)
        rewritten_schemas.append(schema)

    return rewritten_schemas, all_defs
416
+
417
+
418
def strip_markdown_fences(text: str) -> str:
    """Extract a JSON object from a Markdown code fence, if `text` contains one.

    Text that already starts with `{` is returned unchanged. Otherwise the first fenced block
    (with an optional language tag) whose content starts with `{` and ends with `}` is searched
    for, and its content is returned. If there is no such block, `text` is returned unchanged.
    """
    if not text.startswith('{'):
        fenced = re.search(r'```(?:\w+)?\n(\{.*\})\n```', text, re.DOTALL)
        if fenced is not None:
            return fenced.group(1)

    return text
428
+
429
+
430
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Return the member types of `tp` if it is a union, otherwise an empty tuple.

    A type alias (as recognised by `typing_objects.is_typealiastype`) is first replaced by its
    `__value__`, so an alias of a union is treated like the union itself.
    """
    unwrapped = tp.__value__ if typing_objects.is_typealiastype(tp) else tp
    return get_args(unwrapped) if is_union_origin(get_origin(unwrapped)) else ()
pydantic_ai/agent.py CHANGED
@@ -14,6 +14,7 @@ from opentelemetry.trace import NoOpTracer, use_span
14
14
  from pydantic.json_schema import GenerateJsonSchema
15
15
  from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated
16
16
 
17
+ from pydantic_ai.profiles import ModelProfile
17
18
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
18
19
  from pydantic_graph._utils import get_event_loop
19
20
 
@@ -30,7 +31,8 @@ from . import (
30
31
  )
31
32
  from ._agent_graph import HistoryProcessor
32
33
  from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
33
- from .result import FinalResult, OutputDataT, StreamedRunResult
34
+ from .output import OutputDataT, OutputSpec
35
+ from .result import FinalResult, StreamedRunResult
34
36
  from .settings import ModelSettings, merge_model_settings
35
37
  from .tools import (
36
38
  AgentDepsT,
@@ -91,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
91
93
  """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
92
94
 
93
95
  Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
94
- and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT].
96
+ and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT].
95
97
 
96
98
  By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
97
99
 
@@ -128,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
128
130
  be merged with this value, with the runtime argument taking priority.
129
131
  """
130
132
 
131
- output_type: _output.OutputType[OutputDataT]
133
+ output_type: OutputSpec[OutputDataT]
132
134
  """
133
135
  The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
134
136
  """
@@ -141,7 +143,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
141
143
  _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
142
144
  _deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
143
145
  _deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
144
- _output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False)
146
+ _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
145
147
  _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
146
148
  _instructions: str | None = dataclasses.field(repr=False)
147
149
  _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
@@ -163,7 +165,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
163
165
  self,
164
166
  model: models.Model | models.KnownModelName | str | None = None,
165
167
  *,
166
- output_type: _output.OutputType[OutputDataT] = str,
168
+ output_type: OutputSpec[OutputDataT] = str,
167
169
  instructions: str
168
170
  | _system_prompt.SystemPromptFunc[AgentDepsT]
169
171
  | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
@@ -324,8 +326,14 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
324
326
  warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning)
325
327
  output_retries = result_retries
326
328
 
329
+ default_output_mode = (
330
+ self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
331
+ )
327
332
  self._output_schema = _output.OutputSchema[OutputDataT].build(
328
- output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description
333
+ output_type,
334
+ default_mode=default_output_mode,
335
+ name=self._deprecated_result_tool_name,
336
+ description=self._deprecated_result_tool_description,
329
337
  )
330
338
  self._output_validators = []
331
339
 
@@ -382,7 +390,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
382
390
  self,
383
391
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
384
392
  *,
385
- output_type: _output.OutputType[RunOutputDataT],
393
+ output_type: OutputSpec[RunOutputDataT],
386
394
  message_history: list[_messages.ModelMessage] | None = None,
387
395
  model: models.Model | models.KnownModelName | str | None = None,
388
396
  deps: AgentDepsT = None,
@@ -412,7 +420,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
412
420
  self,
413
421
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
414
422
  *,
415
- output_type: _output.OutputType[RunOutputDataT] | None = None,
423
+ output_type: OutputSpec[RunOutputDataT] | None = None,
416
424
  message_history: list[_messages.ModelMessage] | None = None,
417
425
  model: models.Model | models.KnownModelName | str | None = None,
418
426
  deps: AgentDepsT = None,
@@ -500,7 +508,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
500
508
  self,
501
509
  user_prompt: str | Sequence[_messages.UserContent] | None,
502
510
  *,
503
- output_type: _output.OutputType[RunOutputDataT],
511
+ output_type: OutputSpec[RunOutputDataT],
504
512
  message_history: list[_messages.ModelMessage] | None = None,
505
513
  model: models.Model | models.KnownModelName | str | None = None,
506
514
  deps: AgentDepsT = None,
@@ -532,7 +540,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
532
540
  self,
533
541
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
534
542
  *,
535
- output_type: _output.OutputType[RunOutputDataT] | None = None,
543
+ output_type: OutputSpec[RunOutputDataT] | None = None,
536
544
  message_history: list[_messages.ModelMessage] | None = None,
537
545
  model: models.Model | models.KnownModelName | str | None = None,
538
546
  deps: AgentDepsT = None,
@@ -631,7 +639,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
631
639
 
632
640
  deps = self._get_deps(deps)
633
641
  new_message_index = len(message_history) if message_history else 0
634
- output_schema = self._prepare_output_schema(output_type)
642
+ output_schema = self._prepare_output_schema(output_type, model_used.profile)
635
643
 
636
644
  output_type_ = output_type or self.output_type
637
645
 
@@ -674,14 +682,20 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
674
682
  )
675
683
 
676
684
  async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
677
- if self._instructions is None and not self._instructions_functions:
678
- return None
685
+ parts = [
686
+ self._instructions,
687
+ *[await func.run(run_context) for func in self._instructions_functions],
688
+ ]
679
689
 
680
- instructions = [self._instructions] if self._instructions else []
681
- for instructions_runner in self._instructions_functions:
682
- instructions.append(await instructions_runner.run(run_context))
683
- concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
684
- return concatenated_instructions.strip() if concatenated_instructions else None
690
+ model_profile = model_used.profile
691
+ if isinstance(output_schema, _output.PromptedOutputSchema):
692
+ instructions = output_schema.instructions(model_profile.prompted_output_template)
693
+ parts.append(instructions)
694
+
695
+ parts = [p for p in parts if p]
696
+ if not parts:
697
+ return None
698
+ return '\n\n'.join(parts).strip()
685
699
 
686
700
  # Copy the function tools so that retry state is agent-run-specific
687
701
  # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
@@ -779,7 +793,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
779
793
  self,
780
794
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
781
795
  *,
782
- output_type: _output.OutputType[RunOutputDataT] | None = None,
796
+ output_type: OutputSpec[RunOutputDataT] | None = None,
783
797
  message_history: list[_messages.ModelMessage] | None = None,
784
798
  model: models.Model | models.KnownModelName | str | None = None,
785
799
  deps: AgentDepsT = None,
@@ -809,7 +823,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
809
823
  self,
810
824
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
811
825
  *,
812
- output_type: _output.OutputType[RunOutputDataT] | None = None,
826
+ output_type: OutputSpec[RunOutputDataT] | None = None,
813
827
  message_history: list[_messages.ModelMessage] | None = None,
814
828
  model: models.Model | models.KnownModelName | str | None = None,
815
829
  deps: AgentDepsT = None,
@@ -892,7 +906,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
892
906
  self,
893
907
  user_prompt: str | Sequence[_messages.UserContent],
894
908
  *,
895
- output_type: _output.OutputType[RunOutputDataT],
909
+ output_type: OutputSpec[RunOutputDataT],
896
910
  message_history: list[_messages.ModelMessage] | None = None,
897
911
  model: models.Model | models.KnownModelName | str | None = None,
898
912
  deps: AgentDepsT = None,
@@ -923,7 +937,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
923
937
  self,
924
938
  user_prompt: str | Sequence[_messages.UserContent] | None = None,
925
939
  *,
926
- output_type: _output.OutputType[RunOutputDataT] | None = None,
940
+ output_type: OutputSpec[RunOutputDataT] | None = None,
927
941
  message_history: list[_messages.ModelMessage] | None = None,
928
942
  model: models.Model | models.KnownModelName | str | None = None,
929
943
  deps: AgentDepsT = None,
@@ -1002,10 +1016,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1002
1016
  async for maybe_part_event in streamed_response:
1003
1017
  if isinstance(maybe_part_event, _messages.PartStartEvent):
1004
1018
  new_part = maybe_part_event.part
1005
- if isinstance(new_part, _messages.TextPart):
1006
- if _output.allow_text_output(output_schema):
1007
- return FinalResult(s, None, None)
1008
- elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
1019
+ if isinstance(new_part, _messages.TextPart) and isinstance(
1020
+ output_schema, _output.TextOutputSchema
1021
+ ):
1022
+ return FinalResult(s, None, None)
1023
+ elif isinstance(new_part, _messages.ToolCallPart) and isinstance(
1024
+ output_schema, _output.ToolOutputSchema
1025
+ ): # pragma: no branch
1009
1026
  for call, _ in output_schema.find_tool([new_part]):
1010
1027
  return FinalResult(s, call.tool_name, call.tool_call_id)
1011
1028
  return None
@@ -1561,8 +1578,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1561
1578
  if tool.name in self._function_tools:
1562
1579
  raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
1563
1580
 
1564
- if self._output_schema and tool.name in self._output_schema.tools:
1565
- raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
1581
+ if tool.name in self._output_schema.tools:
1582
+ raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}')
1566
1583
 
1567
1584
  self._function_tools[tool.name] = tool
1568
1585
 
@@ -1637,18 +1654,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
1637
1654
  raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
1638
1655
 
1639
1656
  def _prepare_output_schema(
1640
- self, output_type: _output.OutputType[RunOutputDataT] | None
1641
- ) -> _output.OutputSchema[RunOutputDataT] | None:
1657
+ self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
1658
+ ) -> _output.OutputSchema[RunOutputDataT]:
1642
1659
  if output_type is not None:
1643
1660
  if self._output_validators:
1644
1661
  raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
1645
- return _output.OutputSchema[RunOutputDataT].build(
1662
+ schema = _output.OutputSchema[RunOutputDataT].build(
1646
1663
  output_type,
1647
- self._deprecated_result_tool_name,
1648
- self._deprecated_result_tool_description,
1664
+ name=self._deprecated_result_tool_name,
1665
+ description=self._deprecated_result_tool_description,
1666
+ default_mode=model_profile.default_structured_output_mode,
1649
1667
  )
1650
1668
  else:
1651
- return self._output_schema # pyright: ignore[reportReturnType]
1669
+ schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
1670
+
1671
+ schema.raise_if_unsupported(model_profile)
1672
+
1673
+ return schema # pyright: ignore[reportReturnType]
1652
1674
 
1653
1675
  @staticmethod
1654
1676
  def is_model_request_node(
@@ -20,9 +20,12 @@ from typing_extensions import Literal, TypeAliasType, TypedDict
20
20
 
21
21
  from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
22
22
 
23
+ from .. import _utils
24
+ from .._output import OutputObjectDefinition
23
25
  from .._parts_manager import ModelResponsePartsManager
24
26
  from ..exceptions import UserError
25
27
  from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
28
+ from ..output import OutputMode
26
29
  from ..profiles._json_schema import JsonSchemaTransformer
27
30
  from ..settings import ModelSettings
28
31
  from ..tools import ToolDefinition
@@ -300,13 +303,18 @@ KnownModelName = TypeAliasType(
300
303
  """
301
304
 
302
305
 
303
- @dataclass
306
+ @dataclass(repr=False)
304
307
  class ModelRequestParameters:
305
308
  """Configuration for an agent's request to a model, specifically related to tools and output handling."""
306
309
 
307
310
  function_tools: list[ToolDefinition] = field(default_factory=list)
308
- allow_text_output: bool = True
311
+
312
+ output_mode: OutputMode = 'text'
313
+ output_object: OutputObjectDefinition | None = None
309
314
  output_tools: list[ToolDefinition] = field(default_factory=list)
315
+ allow_text_output: bool = True
316
+
317
+ __repr__ = _utils.dataclasses_no_defaults_repr
310
318
 
311
319
 
312
320
  class Model(ABC):
@@ -351,6 +359,11 @@ class Model(ABC):
351
359
  function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
352
360
  output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
353
361
  )
362
+ if output_object := model_request_parameters.output_object:
363
+ model_request_parameters = replace(
364
+ model_request_parameters,
365
+ output_object=_customize_output_object(transformer, output_object),
366
+ )
354
367
 
355
368
  return model_request_parameters
356
369
 
@@ -718,3 +731,9 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit
718
731
  if t.strict is None:
719
732
  t = replace(t, strict=schema_transformer.is_strict_compatible)
720
733
  return replace(t, parameters_json_schema=parameters_json_schema)
734
+
735
+
736
+ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
737
+ schema_transformer = transformer(o.json_schema, strict=True)
738
+ son_schema = schema_transformer.walk()
739
+ return replace(o, json_schema=son_schema)
@@ -11,6 +11,8 @@ from typing import Callable, Union
11
11
 
12
12
  from typing_extensions import TypeAlias, assert_never, overload
13
13
 
14
+ from pydantic_ai.profiles import ModelProfileSpec
15
+
14
16
  from .. import _utils, usage
15
17
  from .._utils import PeekableAsyncStream
16
18
  from ..messages import (
@@ -49,14 +51,27 @@ class FunctionModel(Model):
49
51
  _system: str = field(default='function', repr=False)
50
52
 
51
53
  @overload
52
- def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
54
+ def __init__(
55
+ self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
56
+ ) -> None: ...
53
57
 
54
58
  @overload
55
- def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
59
+ def __init__(
60
+ self,
61
+ *,
62
+ stream_function: StreamFunctionDef,
63
+ model_name: str | None = None,
64
+ profile: ModelProfileSpec | None = None,
65
+ ) -> None: ...
56
66
 
57
67
  @overload
58
68
  def __init__(
59
- self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
69
+ self,
70
+ function: FunctionDef,
71
+ *,
72
+ stream_function: StreamFunctionDef,
73
+ model_name: str | None = None,
74
+ profile: ModelProfileSpec | None = None,
60
75
  ) -> None: ...
61
76
 
62
77
  def __init__(
@@ -65,6 +80,7 @@ class FunctionModel(Model):
65
80
  *,
66
81
  stream_function: StreamFunctionDef | None = None,
67
82
  model_name: str | None = None,
83
+ profile: ModelProfileSpec | None = None,
68
84
  ):
69
85
  """Initialize a `FunctionModel`.
70
86
 
@@ -74,6 +90,7 @@ class FunctionModel(Model):
74
90
  function: The function to call for non-streamed requests.
75
91
  stream_function: The function to call for streamed requests.
76
92
  model_name: The name of the model. If not provided, a name is generated from the function names.
93
+ profile: The model profile to use.
77
94
  """
78
95
  if function is None and stream_function is None:
79
96
  raise TypeError('Either `function` or `stream_function` must be provided')
@@ -83,6 +100,7 @@ class FunctionModel(Model):
83
100
  function_name = self.function.__name__ if self.function is not None else ''
84
101
  stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
85
102
  self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
103
+ self._profile = profile
86
104
 
87
105
  async def request(
88
106
  self,
@@ -16,6 +16,8 @@ from typing_extensions import NotRequired, TypedDict, assert_never
16
16
  from pydantic_ai.providers import Provider, infer_provider
17
17
 
18
18
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
19
+ from .._output import OutputObjectDefinition
20
+ from ..exceptions import UserError
19
21
  from ..messages import (
20
22
  BinaryContent,
21
23
  FileUrl,
@@ -203,12 +205,10 @@ class GeminiModel(Model):
203
205
  def _get_tool_config(
204
206
  self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
205
207
  ) -> _GeminiToolConfig | None:
206
- if model_request_parameters.allow_text_output:
207
- return None
208
- elif tools:
208
+ if not model_request_parameters.allow_text_output and tools:
209
209
  return _tool_config([t['name'] for t in tools['function_declarations']])
210
210
  else:
211
- return _tool_config([]) # pragma: no cover
211
+ return None
212
212
 
213
213
  @asynccontextmanager
214
214
  async def _make_request(
@@ -231,6 +231,18 @@ class GeminiModel(Model):
231
231
  request_data['toolConfig'] = tool_config
232
232
 
233
233
  generation_config = _settings_to_generation_config(model_settings)
234
+ if model_request_parameters.output_mode == 'native':
235
+ if tools:
236
+ raise UserError('Gemini does not support structured output and tools at the same time.')
237
+
238
+ generation_config['response_mime_type'] = 'application/json'
239
+
240
+ output_object = model_request_parameters.output_object
241
+ assert output_object is not None
242
+ generation_config['response_schema'] = self._map_response_schema(output_object)
243
+ elif model_request_parameters.output_mode == 'prompted' and not tools:
244
+ generation_config['response_mime_type'] = 'application/json'
245
+
234
246
  if generation_config:
235
247
  request_data['generationConfig'] = generation_config
236
248
 
@@ -376,6 +388,15 @@ class GeminiModel(Model):
376
388
  assert_never(item)
377
389
  return content
378
390
 
391
+ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]:
392
+ response_schema = o.json_schema.copy()
393
+ if o.name:
394
+ response_schema['title'] = o.name
395
+ if o.description:
396
+ response_schema['description'] = o.description
397
+
398
+ return response_schema
399
+
379
400
 
380
401
  def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
381
402
  config: _GeminiGenerationConfig = {}
@@ -577,6 +598,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
577
598
  frequency_penalty: float
578
599
  stop_sequences: list[str]
579
600
  thinking_config: ThinkingConfig
601
+ response_mime_type: str
602
+ response_schema: dict[str, Any]
580
603
 
581
604
 
582
605
  class _GeminiContent(TypedDict):