pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pydantic-ai-slim has been flagged as potentially problematic by the registry.
Files changed (70)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_agent_graph.py +310 -140
  3. pydantic_ai/_function_schema.py +5 -5
  4. pydantic_ai/_griffe.py +2 -1
  5. pydantic_ai/_otel_messages.py +2 -2
  6. pydantic_ai/_output.py +31 -35
  7. pydantic_ai/_parts_manager.py +4 -4
  8. pydantic_ai/_run_context.py +3 -1
  9. pydantic_ai/_system_prompt.py +2 -2
  10. pydantic_ai/_tool_manager.py +3 -22
  11. pydantic_ai/_utils.py +14 -26
  12. pydantic_ai/ag_ui.py +7 -8
  13. pydantic_ai/agent/__init__.py +70 -9
  14. pydantic_ai/agent/abstract.py +35 -4
  15. pydantic_ai/agent/wrapper.py +6 -0
  16. pydantic_ai/builtin_tools.py +2 -2
  17. pydantic_ai/common_tools/duckduckgo.py +4 -2
  18. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  19. pydantic_ai/durable_exec/temporal/_agent.py +23 -2
  20. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  21. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  22. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  23. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  24. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  25. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  26. pydantic_ai/exceptions.py +45 -2
  27. pydantic_ai/format_prompt.py +2 -2
  28. pydantic_ai/mcp.py +2 -2
  29. pydantic_ai/messages.py +73 -25
  30. pydantic_ai/models/__init__.py +5 -4
  31. pydantic_ai/models/anthropic.py +5 -5
  32. pydantic_ai/models/bedrock.py +58 -56
  33. pydantic_ai/models/cohere.py +3 -3
  34. pydantic_ai/models/fallback.py +2 -2
  35. pydantic_ai/models/function.py +25 -23
  36. pydantic_ai/models/gemini.py +9 -12
  37. pydantic_ai/models/google.py +3 -3
  38. pydantic_ai/models/groq.py +4 -4
  39. pydantic_ai/models/huggingface.py +4 -4
  40. pydantic_ai/models/instrumented.py +30 -16
  41. pydantic_ai/models/mcp_sampling.py +3 -1
  42. pydantic_ai/models/mistral.py +6 -6
  43. pydantic_ai/models/openai.py +18 -27
  44. pydantic_ai/models/test.py +24 -4
  45. pydantic_ai/output.py +27 -32
  46. pydantic_ai/profiles/__init__.py +3 -3
  47. pydantic_ai/profiles/groq.py +1 -1
  48. pydantic_ai/profiles/openai.py +25 -4
  49. pydantic_ai/providers/anthropic.py +2 -3
  50. pydantic_ai/providers/bedrock.py +3 -2
  51. pydantic_ai/result.py +144 -41
  52. pydantic_ai/retries.py +10 -29
  53. pydantic_ai/run.py +12 -5
  54. pydantic_ai/tools.py +126 -22
  55. pydantic_ai/toolsets/__init__.py +4 -1
  56. pydantic_ai/toolsets/_dynamic.py +4 -4
  57. pydantic_ai/toolsets/abstract.py +18 -2
  58. pydantic_ai/toolsets/approval_required.py +32 -0
  59. pydantic_ai/toolsets/combined.py +7 -12
  60. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  61. pydantic_ai/toolsets/filtered.py +1 -1
  62. pydantic_ai/toolsets/function.py +13 -4
  63. pydantic_ai/toolsets/wrapper.py +2 -1
  64. pydantic_ai/usage.py +7 -5
  65. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
  66. pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
  67. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  68. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
  69. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/function.py

@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import KW_ONLY, dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Any, Callable, Union
+from typing import Any, TypeAlias
 
-from typing_extensions import TypeAlias, assert_never, overload
+from typing_extensions import assert_never, overload
 
 from .. import _utils, usage
 from .._run_context import RunContext
@@ -44,8 +44,8 @@ class FunctionModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    function: FunctionDef | None = None
-    stream_function: StreamFunctionDef | None = None
+    function: FunctionDef | None
+    stream_function: StreamFunctionDef | None
 
     _model_name: str = field(repr=False)
     _system: str = field(default='function', repr=False)
@@ -120,10 +120,10 @@ class FunctionModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         agent_info = AgentInfo(
-            model_request_parameters.function_tools,
-            model_request_parameters.allow_text_output,
-            model_request_parameters.output_tools,
-            model_settings,
+            function_tools=model_request_parameters.function_tools,
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=model_request_parameters.output_tools,
+            model_settings=model_settings,
         )
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -149,10 +149,10 @@ class FunctionModel(Model):
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
-            model_request_parameters.function_tools,
-            model_request_parameters.allow_text_output,
-            model_request_parameters.output_tools,
-            model_settings,
+            function_tools=model_request_parameters.function_tools,
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=model_request_parameters.output_tools,
+            model_settings=model_settings,
         )
 
         assert self.stream_function is not None, (
@@ -182,7 +182,7 @@ class FunctionModel(Model):
         return self._system
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=True, kw_only=True)
 class AgentInfo:
     """Information about an agent.
 
@@ -212,13 +212,17 @@ class DeltaToolCall:
 
     name: str | None = None
     """Incremental change to the name of the tool."""
+
     json_args: str | None = None
    """Incremental change to the arguments as JSON"""
+
+    _: KW_ONLY
+
     tool_call_id: str | None = None
     """Incremental change to the tool call ID."""
 
 
-@dataclass
+@dataclass(kw_only=True)
 class DeltaThinkingPart:
     """Incremental change to a thinking part.
 
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
 """A mapping of thinking call IDs to incremental changes."""
 
-# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
-FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
 """A function used to generate a non-streamed response."""
 
-# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[
-    [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
+    [list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
 ]
 """A function used to generate a streamed response.
 
-While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
-really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
+While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
+really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
 
 E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
 """
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     for message in messages:
         if isinstance(message, ModelRequest):
             for part in message.parts:
-                if isinstance(part, (SystemPromptPart, UserPromptPart)):
+                if isinstance(part, SystemPromptPart | UserPromptPart):
                     request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
                     request_tokens += _estimate_string_tokens(part.model_response_str())
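For users of `FunctionModel`, the practical effect of these hunks is that `AgentInfo` is now keyword-only, `function`/`stream_function` no longer default to `None`, and the type aliases use PEP 604 unions. Below is a minimal sketch of a callback matching the new `FunctionDef` signature; the callback and its behaviour are made up for illustration, and the constructor call assumes `FunctionModel` still accepts the function positionally, which the hunks above do not show.

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # `info` now always carries function_tools / output_tools / allow_text_output /
    # model_settings, which FunctionModel passes by keyword as shown above.
    return ModelResponse(parts=[TextPart(content=f'received {len(messages)} message(s)')])


model = FunctionModel(echo_model)  # `function` no longer defaults to None
```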
pydantic_ai/models/gemini.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union, cast
+from typing import Annotated, Any, Literal, Protocol, cast
 from uuid import uuid4
 
 import httpx
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
 ]
 """Latest Gemini models."""
 
-GeminiModelName = Union[str, LatestGeminiModelNames]
+GeminiModelName = str | LatestGeminiModelNames
 """Possible Gemini model names.
 
 Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -615,7 +615,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         elif isinstance(item, TextPart):
             if item.content:
                 parts.append(_GeminiTextPart(text=item.content))
-        elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+        elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
             # This is currently never returned from gemini
             pass
         else:
@@ -735,16 +735,13 @@ def _part_discriminator(v: Any) -> str:
 
 # See <https://ai.google.dev/api/caching#Part>
 # we don't currently support other part types
-# TODO discriminator
 _GeminiPartUnion = Annotated[
-    Union[
-        Annotated[_GeminiTextPart, pydantic.Tag('text')],
-        Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
-        Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
-        Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
-        Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
-        Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
-    ],
+    Annotated[_GeminiTextPart, pydantic.Tag('text')]
+    | Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
+    | Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
+    | Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
+    | Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
+    | Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
     pydantic.Discriminator(_part_discriminator),
 ]
 
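The `_GeminiPartUnion` hunk only changes the union notation from a nested `Union[...]` to `|`; the `pydantic.Tag` plus callable `pydantic.Discriminator` pattern stays the same. Here is a generic illustration of that pattern with made-up models (`TextLike`, `CallLike`, `_discriminate`, and `PartUnion` are hypothetical, not pydantic-ai code):

```python
from typing import Annotated, Any

import pydantic


class TextLike(pydantic.BaseModel):
    text: str


class CallLike(pydantic.BaseModel):
    name: str


def _discriminate(value: Any) -> str:
    # Mirrors the role of `_part_discriminator`: pick a tag from the raw payload.
    return 'text' if 'text' in value else 'call'


PartUnion = Annotated[
    Annotated[TextLike, pydantic.Tag('text')] | Annotated[CallLike, pydantic.Tag('call')],
    pydantic.Discriminator(_discriminate),
]

print(pydantic.TypeAdapter(PartUnion).validate_python({'name': 'search'}))  # CallLike(name='search')
```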
pydantic_ai/models/google.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
 ]
 """Latest Gemini models."""
 
-GoogleModelName = Union[str, LatestGoogleModelNames]
+GoogleModelName = str | LatestGoogleModelNames
 """Possible Gemini model names.
 
 Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -349,7 +349,7 @@ class GoogleModel(Model):
             'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         }
         if timeout := model_settings.get('timeout'):
-            if isinstance(timeout, (int, float)):
+            if isinstance(timeout, int | float):
                 http_options['timeout'] = int(1000 * timeout)
             else:
                 raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
pydantic_ai/models/groq.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
 
@@ -88,7 +88,7 @@ PreviewGroqModelNames = Literal[
 ]
 """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
 
-GroqModelName = Union[str, ProductionGroqModelNames, PreviewGroqModelNames]
+GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
 """Possible Groq model names.
 
 Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
@@ -285,7 +285,7 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -347,7 +347,7 @@ class GroqModel(Model):
             elif isinstance(item, ThinkingPart):
                 # Skip thinking parts when mapping to Groq messages
                 continue
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from groq
                 pass
             else:
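The `items` to `parts=items` change here repeats in the Hugging Face, Mistral, and OpenAI hunks below: provider code now passes `parts` to `ModelResponse` by keyword. A hedged sketch of the keyword form for anyone constructing `ModelResponse` directly, for example in tests or a custom model (the tool name, arguments, and model name are placeholders):

```python
from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

response = ModelResponse(
    parts=[
        TextPart(content='Looking that up now.'),
        ToolCallPart(tool_name='get_weather', args={'city': 'London'}),  # placeholder tool
    ],
    model_name='example-model',  # optional metadata, as in the hunk above
)
```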
pydantic_ai/models/huggingface.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
 
@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
 """Latest Hugging Face models."""
 
 
-HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+HuggingFaceModelName = str | LatestHuggingFaceModelNames
 """Possible Hugging Face model names.
 
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
@@ -267,7 +267,7 @@ class HuggingFaceModel(Model):
         for c in tool_calls:
             items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -320,7 +320,7 @@ class HuggingFaceModel(Model):
                 # please open an issue. The below code is the code to send thinking to the provider.
                 # texts.append(f'<think>\n{item.content}\n</think>')
                 pass
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from huggingface
                 pass
             else:
pydantic_ai/models/instrumented.py

@@ -2,10 +2,11 @@ from __future__ import annotations
 
 import itertools
 import json
-from collections.abc import AsyncIterator, Iterator, Mapping
+import warnings
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal, cast
+from typing import Any, Literal, cast
 from urllib.parse import urlparse
 
 from opentelemetry._events import (
@@ -93,36 +94,41 @@ class InstrumentationSettings:
     def __init__(
         self,
         *,
-        event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         meter_provider: MeterProvider | None = None,
-        event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
-        version: Literal[1, 2] = 1,
+        version: Literal[1, 2] = 2,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
+        event_logger_provider: EventLoggerProvider | None = None,
     ):
         """Create instrumentation options.
 
         Args:
-            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
-                If `'logs'`, events are emitted as OpenTelemetry log-based events.
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
             meter_provider: The OpenTelemetry meter provider to use.
                 If not provided, the global meter provider is used.
                 Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
-            event_logger_provider: The OpenTelemetry event logger provider to use.
-                If not provided, the global event logger provider is used.
-                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
-                This is only used if `event_mode='logs'`.
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
-            version: Version of the data format.
-                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
-                Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
-                Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
+            version: Version of the data format. This is unrelated to the Pydantic AI package version.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
+                and will be removed in a future release.
+                The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
+                Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
+                - `gen_ai.system_instructions` for instructions passed to the agent.
+                - `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
+                - `pydantic_ai.all_messages` on agent run spans.
+            event_mode: The mode for emitting events in version 1.
+                If `'attributes'`, events are attached to the span as attributes.
+                If `'logs'`, events are emitted as OpenTelemetry log-based events.
+            event_logger_provider: The OpenTelemetry event logger provider to use.
+                If not provided, the global event logger provider is used.
+                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
+                This is only used if `event_mode='logs'` and `version=1`.
         """
         from pydantic_ai import __version__
 
@@ -136,6 +142,14 @@ class InstrumentationSettings:
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+
+        if event_mode == 'logs' and version != 1:
+            warnings.warn(
+                'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
+                stacklevel=2,
+            )
+            version = 1
+
         self.version = version
 
         # As specified in the OpenTelemetry GenAI metrics spec:
@@ -366,7 +380,7 @@ class InstrumentedModel(WrapperModel):
 
         if model_settings:
             for key in MODEL_SETTING_ATTRIBUTES:
-                if isinstance(value := model_settings.get(key), (float, int)):
+                if isinstance(value := model_settings.get(key), float | int):
                     attributes[f'gen_ai.request.{key}'] = value
 
         record_metrics: Callable[[], None] | None = None
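These `InstrumentationSettings` hunks flip the default data format to version 2 and demote `event_mode`/`event_logger_provider` to version 1 only options, with a warning when they are combined inconsistently. A hedged sketch of both paths (passing the settings via `Agent(..., instrument=...)` is assumed to work as in 0.8.x, and `'test'` is just a placeholder model):

```python
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

settings = InstrumentationSettings(include_binary_content=False)  # version now defaults to 2

# `event_mode='logs'` without `version=1` hits the new warning path above and
# falls back to the legacy version 1 format.
legacy = InstrumentationSettings(event_mode='logs')

agent = Agent('test', instrument=settings)
```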
pydantic_ai/models/mcp_sampling.py

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import KW_ONLY, dataclass
 from typing import TYPE_CHECKING, Any, cast
 
 from .. import _mcp, exceptions
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
     session: ServerSession
     """The MCP server session to use for sampling."""
 
+    _: KW_ONLY
+
     default_max_tokens: int = 16_384
     """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
 
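The `_: KW_ONLY` sentinel used in `MCPSamplingModel` here (and in `DeltaToolCall` above) is standard `dataclasses` behaviour: every field declared after it becomes keyword-only in the generated `__init__`. A generic illustration with a made-up dataclass, not pydantic-ai code:

```python
from dataclasses import KW_ONLY, dataclass


@dataclass
class Example:  # hypothetical class for illustration only
    session: str  # still accepted positionally
    _: KW_ONLY
    default_max_tokens: int = 16_384  # must now be passed by keyword


Example('abc', default_max_tokens=1024)  # ok
# Example('abc', 1024)  -> TypeError: too many positional arguments
```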
pydantic_ai/models/mistral.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast
+from typing import Any, Literal, cast
 
 import pydantic_core
 from httpx import Timeout
@@ -90,7 +90,7 @@ LatestMistralModelNames = Literal[
 ]
 """Latest Mistral models."""
 
-MistralModelName = Union[str, LatestMistralModelNames]
+MistralModelName = str | LatestMistralModelNames
 """Possible Mistral model names.
 
 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -117,7 +117,7 @@ class MistralModel(Model):
     """
 
     client: Mistral = field(repr=False)
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+    json_mode_schema_prompt: str
 
     _model_name: MistralModelName = field(repr=False)
     _provider: Provider[Mistral] = field(repr=False)
@@ -348,7 +348,7 @@ class MistralModel(Model):
             parts.append(tool)
 
         return ModelResponse(
-            parts,
+            parts=parts,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -515,7 +515,7 @@ class MistralModel(Model):
                 pass
             elif isinstance(part, ToolCallPart):
                 tool_calls.append(self._map_tool_call(part))
-            elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from mistral
                 pass
             else:
@@ -576,7 +576,7 @@ class MistralModel(Model):
         return MistralUserMessage(content=content)
 
 
-MistralToolCallId = Union[str, None]
+MistralToolCallId = str | None
 
 
 @dataclass
pydantic_ai/models/openai.py

@@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 
 from pydantic import ValidationError
 from typing_extensions import assert_never, deprecated
@@ -90,7 +90,7 @@ __all__ = (
     'OpenAIModelName',
 )
 
-OpenAIModelName = Union[str, AllModels]
+OpenAIModelName = str | AllModels
 """
 Possible OpenAI model names.
 
@@ -409,13 +409,6 @@ class OpenAIChatModel(Model):
         for setting in unsupported_model_settings:
             model_settings.pop(setting, None)
 
-        # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
-        sampling_settings = (
-            model_settings
-            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
-            else OpenAIChatModelSettings()
-        )
-
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -437,13 +430,13 @@ class OpenAIChatModel(Model):
                 web_search_options=web_search_options or NOT_GIVEN,
                 service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
                 prediction=model_settings.get('openai_prediction', NOT_GIVEN),
-                temperature=sampling_settings.get('temperature', NOT_GIVEN),
-                top_p=sampling_settings.get('top_p', NOT_GIVEN),
-                presence_penalty=sampling_settings.get('presence_penalty', NOT_GIVEN),
-                frequency_penalty=sampling_settings.get('frequency_penalty', NOT_GIVEN),
-                logit_bias=sampling_settings.get('logit_bias', NOT_GIVEN),
-                logprobs=sampling_settings.get('openai_logprobs', NOT_GIVEN),
-                top_logprobs=sampling_settings.get('openai_top_logprobs', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+                logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+                top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -512,7 +505,7 @@ class OpenAIChatModel(Model):
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
@@ -582,7 +575,7 @@ class OpenAIChatModel(Model):
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
             # OpenAI doesn't return built-in tool calls
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 pass
             else:
                 assert_never(item)
@@ -828,7 +821,7 @@ class OpenAIResponsesModel(Model):
             elif item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             provider_response_id=response.id,
@@ -918,11 +911,9 @@ class OpenAIResponsesModel(Model):
             text = text or {}
             text['verbosity'] = verbosity
 
-        sampling_settings = (
-            model_settings
-            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
-            else OpenAIResponsesModelSettings()
-        )
+        unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+        for setting in unsupported_model_settings:
+            model_settings.pop(setting, None)
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
@@ -936,8 +927,8 @@ class OpenAIResponsesModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 stream=stream,
-                temperature=sampling_settings.get('temperature', NOT_GIVEN),
-                top_p=sampling_settings.get('top_p', NOT_GIVEN),
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
                 truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
@@ -1049,7 +1040,7 @@ class OpenAIResponsesModel(Model):
             elif isinstance(item, ToolCallPart):
                 openai_messages.append(self._map_tool_call(item))
             # OpenAI doesn't return built-in tool calls
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                 pass
             elif isinstance(item, ThinkingPart):
                 # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
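With these hunks, both OpenAI model classes read sampling options such as `temperature` and `top_p` straight from the request's model settings and rely on the profile's `openai_unsupported_model_settings` list to strip anything the target model rejects. A hedged usage sketch (model name and values are placeholders, and an OpenAI API key in the environment is assumed):

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings

agent = Agent('openai:gpt-4o')
result = agent.run_sync(
    'Summarize the release in one sentence.',
    model_settings=ModelSettings(temperature=0.2, top_p=0.9),  # passed through unless unsupported
)
print(result.output)
```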
pydantic_ai/models/test.py

@@ -195,7 +195,10 @@ class TestModel(Model):
         # if there are tools, the first thing we want to do is call all of them
         if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
             return ModelResponse(
-                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
+                parts=[
+                    ToolCallPart(name, self.gen_tool_args(args), tool_call_id=f'pyd_ai_tool_call_id__{name}')
+                    for name, args in tool_calls
+                ],
                 model_name=self._model_name,
             )
 
@@ -220,6 +223,7 @@ class TestModel(Model):
                     output_wrapper.value
                     if isinstance(output_wrapper, _WrappedToolOutput) and output_wrapper.value is not None
                     else self.gen_tool_args(tool),
+                    tool_call_id=f'pyd_ai_tool_call_id__{tool.name}',
                 )
                 for tool in output_tools
                 if tool.name in new_retry_names
@@ -250,11 +254,27 @@ class TestModel(Model):
             output_tool = output_tools[self.seed % len(output_tools)]
             if custom_output_args is not None:
                 return ModelResponse(
-                    parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
+                    parts=[
+                        ToolCallPart(
+                            output_tool.name,
+                            custom_output_args,
+                            tool_call_id=f'pyd_ai_tool_call_id__{output_tool.name}',
+                        )
+                    ],
+                    model_name=self._model_name,
                 )
             else:
                 response_args = self.gen_tool_args(output_tool)
-                return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name)
+                return ModelResponse(
+                    parts=[
+                        ToolCallPart(
+                            output_tool.name,
+                            response_args,
+                            tool_call_id=f'pyd_ai_tool_call_id__{output_tool.name}',
+                        )
+                    ],
+                    model_name=self._model_name,
+                )
 
 
 @dataclass
@@ -293,7 +313,7 @@ class TestStreamedResponse(StreamedResponse):
                 yield self._parts_manager.handle_tool_call_part(
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
                 )
-            elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # NOTE: These parts are not generated by TestModel, but we need to handle them for type checking
                 assert False, f'Unexpected part type in TestModel: {type(part).__name__}'
             elif isinstance(part, ThinkingPart):  # pragma: no cover
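`TestModel` now stamps its synthetic tool calls with deterministic IDs of the form `pyd_ai_tool_call_id__<tool name>`, which makes message-history assertions stable across test runs. A hedged sketch of what that looks like in a test (the agent and tool are made up for illustration):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ToolCallPart
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())


@agent.tool_plain
def get_weather(city: str) -> str:  # placeholder tool
    return f'sunny in {city}'


result = agent.run_sync('What is the weather?')
tool_call_ids = [
    part.tool_call_id
    for message in result.all_messages()
    for part in message.parts
    if isinstance(part, ToolCallPart)
]
assert 'pyd_ai_tool_call_id__get_weather' in tool_call_ids
```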