pydantic-ai-slim 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between the two publicly released package versions as they appear in their public registry.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as possibly problematic (details are linked from the original listing).

Files changed (35):
  1. pydantic_ai/__init__.py +2 -1
  2. pydantic_ai/_a2a.py +3 -4
  3. pydantic_ai/_agent_graph.py +5 -2
  4. pydantic_ai/_output.py +130 -20
  5. pydantic_ai/_utils.py +6 -1
  6. pydantic_ai/agent.py +13 -10
  7. pydantic_ai/common_tools/duckduckgo.py +5 -2
  8. pydantic_ai/exceptions.py +2 -2
  9. pydantic_ai/messages.py +6 -4
  10. pydantic_ai/models/__init__.py +34 -1
  11. pydantic_ai/models/anthropic.py +5 -2
  12. pydantic_ai/models/bedrock.py +5 -2
  13. pydantic_ai/models/cohere.py +5 -2
  14. pydantic_ai/models/fallback.py +1 -0
  15. pydantic_ai/models/function.py +13 -2
  16. pydantic_ai/models/gemini.py +13 -10
  17. pydantic_ai/models/google.py +5 -2
  18. pydantic_ai/models/groq.py +5 -2
  19. pydantic_ai/models/huggingface.py +463 -0
  20. pydantic_ai/models/instrumented.py +12 -12
  21. pydantic_ai/models/mistral.py +6 -3
  22. pydantic_ai/models/openai.py +16 -4
  23. pydantic_ai/models/test.py +22 -1
  24. pydantic_ai/models/wrapper.py +6 -0
  25. pydantic_ai/output.py +65 -1
  26. pydantic_ai/providers/__init__.py +4 -0
  27. pydantic_ai/providers/google.py +2 -2
  28. pydantic_ai/providers/google_vertex.py +10 -5
  29. pydantic_ai/providers/huggingface.py +88 -0
  30. pydantic_ai/result.py +16 -5
  31. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/METADATA +7 -5
  32. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/RECORD +35 -33
  33. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/WHEEL +0 -0
  34. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/entry_points.txt +0 -0
  35. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py

@@ -127,6 +127,7 @@ class AnthropicModel(Model):
         *,
         provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize an Anthropic model.

@@ -136,13 +137,15 @@ class AnthropicModel(Model):
             provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
                 instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Default model settings for this model instance.
         """
         self._model_name = model_name

         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
@@ -253,7 +256,7 @@ class AnthropicModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma: lax no cover
+            raise  # pragma: no cover

     def _process_response(self, response: BetaMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
pydantic_ai/models/bedrock.py

@@ -202,6 +202,7 @@ class BedrockConverseModel(Model):
         *,
         provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Bedrock model.

@@ -213,13 +214,15 @@ class BedrockConverseModel(Model):
                 'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
                 created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name

         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = cast('BedrockRuntimeClient', provider.client)
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
@@ -660,4 +663,4 @@ class _AsyncIteratorWrapper(Generic[T]):
                 if type(e.__cause__) is StopIteration:
                     raise StopAsyncIteration
                 else:
-                    raise e  # pragma: lax no cover
+                    raise e  # pragma: no cover
pydantic_ai/models/cohere.py

@@ -111,6 +111,7 @@ class CohereModel(Model):
         *,
         provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize an Cohere model.

@@ -121,13 +122,15 @@ class CohereModel(Model):
                 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
                 created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name

         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
@@ -180,7 +183,7 @@ class CohereModel(Model):
         except ApiError as e:
             if (status_code := e.status_code) and status_code >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma: lax no cover
+            raise  # pragma: no cover

     def _process_response(self, response: ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
pydantic_ai/models/fallback.py

@@ -42,6 +42,7 @@ class FallbackModel(Model):
             fallback_models: The names or instances of the fallback models to use upon failure.
             fallback_on: A callable or tuple of exceptions that should trigger a fallback.
         """
+        super().__init__()
        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

         if isinstance(fallback_on, tuple):
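`FallbackModel` takes no `settings` of its own, but it still has to call the base initializer now that shared state lives there. The base class is defined in `pydantic_ai/models/__init__.py` (changed `+34 -1` in the file list, but its hunk is not shown in this listing); judging only from the `super().__init__(settings=..., profile=...)` call sites, it presumably reduces to something like this hypothetical sketch:

```python
from pydantic_ai.profiles import ModelProfileSpec
from pydantic_ai.settings import ModelSettings


class Model:
    """Hypothetical reduction of the new base initializer, inferred from the
    subclass call sites in this diff; not the actual upstream code."""

    def __init__(
        self,
        *,
        settings: ModelSettings | None = None,
        profile: ModelProfileSpec | None = None,
    ) -> None:
        self._settings = settings  # per-instance default settings
        self._profile = profile  # subclasses previously assigned this directly
```

This would explain the recurring pairs in the surrounding hunks where a direct `self._profile = ...` assignment is replaced by a `super().__init__(...)` call.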
pydantic_ai/models/function.py

@@ -52,7 +52,12 @@ class FunctionModel(Model):

     @overload
     def __init__(
-        self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
+        self,
+        function: FunctionDef,
+        *,
+        model_name: str | None = None,
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ) -> None: ...

     @overload
@@ -62,6 +67,7 @@ class FunctionModel(Model):
         stream_function: StreamFunctionDef,
         model_name: str | None = None,
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ) -> None: ...

     @overload
@@ -72,6 +78,7 @@ class FunctionModel(Model):
         stream_function: StreamFunctionDef,
         model_name: str | None = None,
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ) -> None: ...

     def __init__(
@@ -81,6 +88,7 @@ class FunctionModel(Model):
         stream_function: StreamFunctionDef | None = None,
         model_name: str | None = None,
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a `FunctionModel`.

@@ -91,16 +99,19 @@ class FunctionModel(Model):
             stream_function: The function to call for streamed requests.
             model_name: The name of the model. If not provided, a name is generated from the function names.
             profile: The model profile to use.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         if function is None and stream_function is None:
             raise TypeError('Either `function` or `stream_function` must be provided')
+
         self.function = function
         self.stream_function = stream_function

         function_name = self.function.__name__ if self.function is not None else ''
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
-        self._profile = profile
+
+        super().__init__(settings=settings, profile=profile)

     async def request(
         self,
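`FunctionModel` wraps a plain Python function as a model (typically for testing), so the new parameter means even stubbed models can carry default settings. A minimal sketch, assuming the usual `FunctionDef` signature; the `echo` function is invented for illustration:

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.settings import ModelSettings


def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Trivial stub: always answer with a fixed text part.
    return ModelResponse(parts=[TextPart('hello')])


model = FunctionModel(echo, settings=ModelSettings(max_tokens=16))
```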
pydantic_ai/models/gemini.py

@@ -133,6 +133,7 @@ class GeminiModel(Model):
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Gemini model.

@@ -142,6 +143,7 @@ class GeminiModel(Model):
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Default model settings for this model instance.
         """
         self._model_name = model_name
         self._provider = provider
@@ -151,7 +153,8 @@ class GeminiModel(Model):
         self._system = provider.name
         self.client = provider.client
         self._url = str(self.client.base_url)
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
@@ -250,7 +253,7 @@ class GeminiModel(Model):

         if gemini_labels := model_settings.get('gemini_labels'):
             if self._system == 'google-vertex':
-                request_data['labels'] = gemini_labels  # pragma: lax no cover
+                request_data['labels'] = gemini_labels

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -412,7 +415,7 @@ def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _Gemi
     if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
         config['frequency_penalty'] = frequency_penalty
     if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None:
-        config['thinking_config'] = thinkingConfig  # pragma: lax no cover
+        config['thinking_config'] = thinkingConfig
     return config

@@ -921,10 +924,10 @@ def _ensure_decodeable(content: bytearray) -> bytearray:

     This is a temporary workaround until https://github.com/pydantic/pydantic-core/issues/1633 is resolved
     """
-    while True:
-        try:
-            content.decode()
-        except UnicodeDecodeError:
-            content = content[:-1]  # this will definitely succeed before we run out of bytes
-        else:
-            return content
+    try:
+        content.decode()
+    except UnicodeDecodeError as e:
+        # e.start marks the start of the invalid decoded bytes, so cut up to before the first invalid byte
+        return content[: e.start]
+    else:
+        return content
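The rewrite replaces a byte-at-a-time trim loop with a single slice: `UnicodeDecodeError.start` is the offset of the first byte that failed to decode, so everything before it is already valid UTF-8. A standalone illustration of that stdlib behavior (not code from the package):

```python
# Truncate a multi-byte character to provoke the error the function handles.
content = bytearray('héllo'.encode('utf-8'))[:2]  # b'h\xc3', half of 'é'
try:
    content.decode()
except UnicodeDecodeError as e:
    print(content[: e.start])  # bytearray(b'h'): everything before the bad byte
```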
pydantic_ai/models/google.py

@@ -151,6 +151,7 @@ class GoogleModel(Model):
         *,
         provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Gemini model.

@@ -160,16 +161,18 @@ class GoogleModel(Model):
                 'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                 If not provided, a new provider will be created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: The model settings to use. Defaults to None.
         """
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = GoogleProvider(vertexai=provider == 'google-vertex')  # pragma: lax no cover
+            provider = GoogleProvider(vertexai=provider == 'google-vertex')

         self._provider = provider
         self._system = provider.name
         self.client = provider.client
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
pydantic_ai/models/groq.py

@@ -120,6 +120,7 @@ class GroqModel(Model):
         *,
         provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
         profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Groq model.

@@ -130,13 +131,15 @@ class GroqModel(Model):
                 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                 created using the other parameters.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name

         if isinstance(provider, str):
             provider = infer_provider(provider)
         self.client = provider.client
-        self._profile = profile or provider.model_profile
+
+        super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
@@ -245,7 +248,7 @@ class GroqModel(Model):
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma: lax no cover
+            raise  # pragma: no cover

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""