pydantic-ai-slim 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.
Files changed (35)
  1. pydantic_ai/__init__.py +2 -1
  2. pydantic_ai/_a2a.py +3 -4
  3. pydantic_ai/_agent_graph.py +5 -2
  4. pydantic_ai/_output.py +130 -20
  5. pydantic_ai/_utils.py +6 -1
  6. pydantic_ai/agent.py +13 -10
  7. pydantic_ai/common_tools/duckduckgo.py +5 -2
  8. pydantic_ai/exceptions.py +2 -2
  9. pydantic_ai/messages.py +6 -4
  10. pydantic_ai/models/__init__.py +34 -1
  11. pydantic_ai/models/anthropic.py +5 -2
  12. pydantic_ai/models/bedrock.py +5 -2
  13. pydantic_ai/models/cohere.py +5 -2
  14. pydantic_ai/models/fallback.py +1 -0
  15. pydantic_ai/models/function.py +13 -2
  16. pydantic_ai/models/gemini.py +13 -10
  17. pydantic_ai/models/google.py +5 -2
  18. pydantic_ai/models/groq.py +5 -2
  19. pydantic_ai/models/huggingface.py +463 -0
  20. pydantic_ai/models/instrumented.py +12 -12
  21. pydantic_ai/models/mistral.py +6 -3
  22. pydantic_ai/models/openai.py +16 -4
  23. pydantic_ai/models/test.py +22 -1
  24. pydantic_ai/models/wrapper.py +6 -0
  25. pydantic_ai/output.py +65 -1
  26. pydantic_ai/providers/__init__.py +4 -0
  27. pydantic_ai/providers/google.py +2 -2
  28. pydantic_ai/providers/google_vertex.py +10 -5
  29. pydantic_ai/providers/huggingface.py +88 -0
  30. pydantic_ai/result.py +16 -5
  31. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/METADATA +7 -5
  32. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/RECORD +35 -33
  33. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/WHEEL +0 -0
  34. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/entry_points.txt +0 -0
  35. {pydantic_ai_slim-0.4.1.dist-info → pydantic_ai_slim-0.4.3.dist-info}/licenses/LICENSE +0 -0
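
The headline change in this release is first-class Hugging Face support: a new HuggingFaceModel in pydantic_ai/models/huggingface.py, backed by a new provider in pydantic_ai/providers/huggingface.py, shown in full in the first hunk below. A minimal usage sketch, assuming the `huggingface` optional group is installed and a Hugging Face API token is configured for Inference Providers (the provider's exact environment variable is not shown in this diff):

    from pydantic_ai import Agent
    from pydantic_ai.models.huggingface import HuggingFaceModel

    # Model name taken from the LatestHuggingFaceModelNames literal added in this release.
    model = HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct')
    agent = Agent(model, instructions='Reply in one short sentence.')

    result = agent.run_sync('What is the capital of France?')
    print(result.output)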
pydantic_ai/models/huggingface.py (new file)
@@ -0,0 +1,463 @@
+ from __future__ import annotations as _annotations
+
+ import base64
+ from collections.abc import AsyncIterable, AsyncIterator
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import Literal, Union, cast, overload
+
+ from typing_extensions import assert_never
+
+ from pydantic_ai._thinking_part import split_content_into_text_and_thinking
+ from pydantic_ai.providers import Provider, infer_provider
+
+ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+ from ..messages import (
+     AudioUrl,
+     BinaryContent,
+     DocumentUrl,
+     ImageUrl,
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     ModelResponsePart,
+     ModelResponseStreamEvent,
+     RetryPromptPart,
+     SystemPromptPart,
+     TextPart,
+     ThinkingPart,
+     ToolCallPart,
+     ToolReturnPart,
+     UserPromptPart,
+     VideoUrl,
+ )
+ from ..settings import ModelSettings
+ from ..tools import ToolDefinition
+ from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+
+ try:
+     import aiohttp
+     from huggingface_hub import (
+         AsyncInferenceClient,
+         ChatCompletionInputMessage,
+         ChatCompletionInputMessageChunk,
+         ChatCompletionInputTool,
+         ChatCompletionInputToolCall,
+         ChatCompletionInputURL,
+         ChatCompletionOutput,
+         ChatCompletionOutputMessage,
+         ChatCompletionStreamOutput,
+     )
+     from huggingface_hub.errors import HfHubHTTPError
+
+ except ImportError as _import_error:
+     raise ImportError(
+         'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
+         'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
+     ) from _import_error
+
+ __all__ = (
+     'HuggingFaceModel',
+     'HuggingFaceModelSettings',
+ )
+
+
+ HFSystemPromptRole = Literal['system', 'user']
+
+ LatestHuggingFaceModelNames = Literal[
+     'deepseek-ai/DeepSeek-R1',
+     'meta-llama/Llama-3.3-70B-Instruct',
+     'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
+     'meta-llama/Llama-4-Scout-17B-16E-Instruct',
+     'Qwen/QwQ-32B',
+     'Qwen/Qwen2.5-72B-Instruct',
+     'Qwen/Qwen3-235B-A22B',
+     'Qwen/Qwen3-32B',
+ ]
+ """Latest Hugging Face models."""
+
+
+ HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+ """Possible Hugging Face model names.
+
+ You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+ """
+
+
+ class HuggingFaceModelSettings(ModelSettings, total=False):
+     """Settings used for a Hugging Face model request."""
+
+     # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+     # This class is a placeholder for any future huggingface-specific settings
+
+
+ @dataclass(init=False)
+ class HuggingFaceModel(Model):
+     """A model that uses Hugging Face Inference Providers.
+
+     Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API.
+
+     Apart from `__init__`, all methods are private or match those of the base class.
+     """
+
+     client: AsyncInferenceClient = field(repr=False)
+
+     _model_name: str = field(repr=False)
+     _system: str = field(default='huggingface', repr=False)
+
+     def __init__(
+         self,
+         model_name: str,
+         *,
+         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+     ):
+         """Initialize a Hugging Face model.
+
+         Args:
+             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
+                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+         """
+         self._model_name = model_name
+         self._provider = provider
+         if isinstance(provider, str):
+             provider = infer_provider(provider)
+         self.client = provider.client
+
+     async def request(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> ModelResponse:
+         check_allow_model_requests()
+         response = await self._completions_create(
+             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+         )
+         model_response = self._process_response(response)
+         model_response.usage.requests = 1
+         return model_response
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         check_allow_model_requests()
+         response = await self._completions_create(
+             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+         )
+         yield await self._process_streamed_response(response)
+
+     @property
+     def model_name(self) -> HuggingFaceModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str:
+         """The system / model provider."""
+         return self._system
+
+     @overload
+     async def _completions_create(
+         self,
+         messages: list[ModelMessage],
+         stream: Literal[True],
+         model_settings: HuggingFaceModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+     @overload
+     async def _completions_create(
+         self,
+         messages: list[ModelMessage],
+         stream: Literal[False],
+         model_settings: HuggingFaceModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> ChatCompletionOutput: ...
+
+     async def _completions_create(
+         self,
+         messages: list[ModelMessage],
+         stream: bool,
+         model_settings: HuggingFaceModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]:
+         tools = self._get_tools(model_request_parameters)
+
+         if not tools:
+             tool_choice: Literal['none', 'required', 'auto'] | None = None
+         elif not model_request_parameters.allow_text_output:
+             tool_choice = 'required'
+         else:
+             tool_choice = 'auto'
+
+         hf_messages = await self._map_messages(messages)
+
+         try:
+             return await self.client.chat.completions.create(  # type: ignore
+                 model=self._model_name,
+                 messages=hf_messages,  # type: ignore
+                 tools=tools,
+                 tool_choice=tool_choice or None,
+                 stream=stream,
+                 stop=model_settings.get('stop_sequences', None),
+                 temperature=model_settings.get('temperature', None),
+                 top_p=model_settings.get('top_p', None),
+                 seed=model_settings.get('seed', None),
+                 presence_penalty=model_settings.get('presence_penalty', None),
+                 frequency_penalty=model_settings.get('frequency_penalty', None),
+                 logit_bias=model_settings.get('logit_bias', None),  # type: ignore
+                 logprobs=model_settings.get('logprobs', None),
+                 top_logprobs=model_settings.get('top_logprobs', None),
+                 extra_body=model_settings.get('extra_body'),  # type: ignore
+             )
+         except aiohttp.ClientResponseError as e:
+             raise ModelHTTPError(
+                 status_code=e.status,
+                 model_name=self.model_name,
+                 body=e.response_error_payload,  # type: ignore
+             ) from e
+         except HfHubHTTPError as e:
+             raise ModelHTTPError(
+                 status_code=e.response.status_code,
+                 model_name=self.model_name,
+                 body=e.response.content,
+             ) from e
+
+     def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
+         """Process a non-streamed response, and prepare a message to return."""
+         if response.created:
+             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+         else:
+             timestamp = _now_utc()
+
+         choice = response.choices[0]
+         content = choice.message.content
+         tool_calls = choice.message.tool_calls
+
+         items: list[ModelResponsePart] = []
+
+         if content is not None:
+             items.extend(split_content_into_text_and_thinking(content))
+         if tool_calls is not None:
+             for c in tool_calls:
+                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+         return ModelResponse(
+             items,
+             usage=_map_usage(response),
+             model_name=response.model,
+             timestamp=timestamp,
+             vendor_id=response.id,
+         )
+
+     async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+         """Process a streamed response, and prepare a streaming response to return."""
+         peekable_response = _utils.PeekableAsyncStream(response)
+         first_chunk = await peekable_response.peek()
+         if isinstance(first_chunk, _utils.Unset):
+             raise UnexpectedModelBehavior(  # pragma: no cover
+                 'Streamed response ended without content or tool calls'
+             )
+
+         return HuggingFaceStreamedResponse(
+             _model_name=self._model_name,
+             _response=peekable_response,
+             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+         )
+
+     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
+         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+         if model_request_parameters.output_tools:
+             tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
+         return tools
+
+     async def _map_messages(
+         self, messages: list[ModelMessage]
+     ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+         """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
+         hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
+         for message in messages:
+             if isinstance(message, ModelRequest):
+                 async for item in self._map_user_message(message):
+                     hf_messages.append(item)
+             elif isinstance(message, ModelResponse):
+                 texts: list[str] = []
+                 tool_calls: list[ChatCompletionInputToolCall] = []
+                 for item in message.parts:
+                     if isinstance(item, TextPart):
+                         texts.append(item.content)
+                     elif isinstance(item, ToolCallPart):
+                         tool_calls.append(self._map_tool_call(item))
+                     elif isinstance(item, ThinkingPart):
+                         # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
+                         # please open an issue. The below code is the code to send thinking to the provider.
+                         # texts.append(f'<think>\n{item.content}\n</think>')
+                         pass
+                     else:
+                         assert_never(item)
+                 message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore
+                 if texts:
+                     # Note: model responses from this model should only have one text item, so the following
+                     # shouldn't merge multiple texts into one unless you switch models between runs:
+                     message_param['content'] = '\n\n'.join(texts)
+                 if tool_calls:
+                     message_param['tool_calls'] = tool_calls
+                 hf_messages.append(message_param)
+             else:
+                 assert_never(message)
+         if instructions := self._get_instructions(messages):
+             hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
+         return hf_messages
+
+     @staticmethod
+     def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
+         return ChatCompletionInputToolCall.parse_obj_as_instance(  # type: ignore
+             {
+                 'id': _guard_tool_call_id(t=t),
+                 'type': 'function',
+                 'function': {
+                     'name': t.tool_name,
+                     'arguments': t.args_as_json_str(),
+                 },
+             }
+         )
+
+     @staticmethod
+     def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
+         tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(  # type: ignore
+             {
+                 'type': 'function',
+                 'function': {
+                     'name': f.name,
+                     'description': f.description,
+                     'parameters': f.parameters_json_schema,
+                 },
+             }
+         )
+         if f.strict is not None:
+             tool_param['function']['strict'] = f.strict
+         return tool_param
+
+     async def _map_user_message(
+         self, message: ModelRequest
+     ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+         for part in message.parts:
+             if isinstance(part, SystemPromptPart):
+                 yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content})  # type: ignore
+             elif isinstance(part, UserPromptPart):
+                 yield await self._map_user_prompt(part)
+             elif isinstance(part, ToolReturnPart):
+                 yield ChatCompletionOutputMessage.parse_obj_as_instance(  # type: ignore
+                     {
+                         'role': 'tool',
+                         'tool_call_id': _guard_tool_call_id(t=part),
+                         'content': part.model_response_str(),
+                     }
+                 )
+             elif isinstance(part, RetryPromptPart):
+                 if part.tool_name is None:
+                     yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                         {'role': 'user', 'content': part.model_response()}
+                     )
+                 else:
+                     yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                         {
+                             'role': 'tool',
+                             'tool_call_id': _guard_tool_call_id(t=part),
+                             'content': part.model_response(),
+                         }
+                     )
+             else:
+                 assert_never(part)
+
+     @staticmethod
+     async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
+         content: str | list[ChatCompletionInputMessage]
+         if isinstance(part.content, str):
+             content = part.content
+         else:
+             content = []
+             for item in part.content:
+                 if isinstance(item, str):
+                     content.append(ChatCompletionInputMessageChunk(type='text', text=item))  # type: ignore
+                 elif isinstance(item, ImageUrl):
+                     url = ChatCompletionInputURL(url=item.url)  # type: ignore
+                     content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                 elif isinstance(item, BinaryContent):
+                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                     if item.is_image:
+                         url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}')  # type: ignore
+                         content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                     else:  # pragma: no cover
+                         raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                 elif isinstance(item, AudioUrl):
+                     raise NotImplementedError('AudioUrl is not supported for Hugging Face')
+                 elif isinstance(item, DocumentUrl):
+                     raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
+                 elif isinstance(item, VideoUrl):
+                     raise NotImplementedError('VideoUrl is not supported for Hugging Face')
+                 else:
+                     assert_never(item)
+         return ChatCompletionInputMessage(role='user', content=content)  # type: ignore
+
+
+ @dataclass
+ class HuggingFaceStreamedResponse(StreamedResponse):
+     """Implementation of `StreamedResponse` for Hugging Face models."""
+
+     _model_name: str
+     _response: AsyncIterable[ChatCompletionStreamOutput]
+     _timestamp: datetime
+
+     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+         async for chunk in self._response:
+             self._usage += _map_usage(chunk)
+
+             try:
+                 choice = chunk.choices[0]
+             except IndexError:
+                 continue
+
+             # Handle the text part of the response
+             content = choice.delta.content
+             if content is not None:
+                 yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+             for dtc in choice.delta.tool_calls or []:
+                 maybe_event = self._parts_manager.handle_tool_call_delta(
+                     vendor_part_id=dtc.index,
+                     tool_name=dtc.function and dtc.function.name,  # type: ignore
+                     args=dtc.function and dtc.function.arguments,
+                     tool_call_id=dtc.id,
+                 )
+                 if maybe_event is not None:
+                     yield maybe_event
+
+     @property
+     def model_name(self) -> str:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
+     def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
+         return self._timestamp
+
+
+ def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+     response_usage = response.usage
+     if response_usage is None:
+         return usage.Usage()
+
+     return usage.Usage(
+         request_tokens=response_usage.prompt_tokens,
+         response_tokens=response_usage.completion_tokens,
+         total_tokens=response_usage.total_tokens,
+         details=None,
+     )
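
Both error paths in `_completions_create` above (`aiohttp.ClientResponseError` and `HfHubHTTPError`) are normalized to pydantic-ai's `ModelHTTPError`, so Hugging Face HTTP failures can be handled the same way as for other providers. A short sketch of that, with an illustrative model name and prompt:

    import asyncio

    from pydantic_ai import Agent, ModelHTTPError
    from pydantic_ai.models.huggingface import HuggingFaceModel


    async def main() -> None:
        agent = Agent(HuggingFaceModel('Qwen/Qwen3-32B'))
        try:
            result = await agent.run('Summarize what Inference Providers are in one sentence.')
            print(result.output)
        except ModelHTTPError as exc:
            # 4xx/5xx responses from the Hugging Face client are re-raised like this (see the hunk above).
            print(f'Hugging Face request failed: {exc.status_code} {exc.body!r}')


    asyncio.run(main())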
pydantic_ai/models/instrumented.py
@@ -138,7 +138,7 @@ class InstrumentationSettings:
                  **tokens_histogram_kwargs,
                  explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
              )
-         except TypeError:  # pragma: lax no cover
+         except TypeError:
              # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
              self.tokens_histogram = self.meter.create_histogram(
                  **tokens_histogram_kwargs,  # pyright: ignore
@@ -182,15 +182,15 @@ GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
  GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
  
  
- @dataclass
+ @dataclass(init=False)
  class InstrumentedModel(WrapperModel):
      """Model which wraps another model so that requests are instrumented with OpenTelemetry.
  
      See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
      """
  
-     settings: InstrumentationSettings
-     """Configuration for instrumenting requests."""
+     instrumentation_settings: InstrumentationSettings
+     """Instrumentation settings for this model."""
  
      def __init__(
          self,
@@ -198,7 +198,7 @@ class InstrumentedModel(WrapperModel):
          options: InstrumentationSettings | None = None,
      ) -> None:
          super().__init__(wrapped)
-         self.settings = options or InstrumentationSettings()
+         self.instrumentation_settings = options or InstrumentationSettings()
  
      async def request(
          self,
@@ -260,7 +260,7 @@ class InstrumentedModel(WrapperModel):
  
          record_metrics: Callable[[], None] | None = None
          try:
-             with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+             with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
  
                  def finish(response: ModelResponse):
                      # FallbackModel updates these span attributes.
@@ -278,12 +278,12 @@ class InstrumentedModel(WrapperModel):
                              'gen_ai.response.model': response_model,
                          }
                          if response.usage.request_tokens:  # pragma: no branch
-                             self.settings.tokens_histogram.record(
+                             self.instrumentation_settings.tokens_histogram.record(
                                  response.usage.request_tokens,
                                  {**metric_attributes, 'gen_ai.token.type': 'input'},
                              )
                          if response.usage.response_tokens:  # pragma: no branch
-                             self.settings.tokens_histogram.record(
+                             self.instrumentation_settings.tokens_histogram.record(
                                  response.usage.response_tokens,
                                  {**metric_attributes, 'gen_ai.token.type': 'output'},
                              )
@@ -294,8 +294,8 @@ class InstrumentedModel(WrapperModel):
                      if not span.is_recording():
                          return
  
-                     events = self.settings.messages_to_otel_events(messages)
-                     for event in self.settings.messages_to_otel_events([response]):
+                     events = self.instrumentation_settings.messages_to_otel_events(messages)
+                     for event in self.instrumentation_settings.messages_to_otel_events([response]):
                          events.append(
                              Event(
                                  'gen_ai.choice',
@@ -328,9 +328,9 @@ class InstrumentedModel(WrapperModel):
                  record_metrics()
  
      def _emit_events(self, span: Span, events: list[Event]) -> None:
-         if self.settings.event_mode == 'logs':
+         if self.instrumentation_settings.event_mode == 'logs':
              for event in events:
-                 self.settings.event_logger.emit(event)
+                 self.instrumentation_settings.event_logger.emit(event)
          else:
              attr_name = 'events'
              span.set_attributes(
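
The only API change in the instrumented.py hunks is the rename of the public `settings` attribute on `InstrumentedModel` to `instrumentation_settings` (plus the switch to `@dataclass(init=False)`); all internal call sites are updated to match. Code that read the attribute directly needs the new name. A small sketch, assuming default `InstrumentationSettings` and the bundled `TestModel`:

    from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel
    from pydantic_ai.models.test import TestModel

    model = InstrumentedModel(TestModel(), InstrumentationSettings())
    # 0.4.1 exposed this as `model.settings`; 0.4.3 renames it:
    print(model.instrumentation_settings.event_mode)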
pydantic_ai/models/mistral.py
@@ -75,7 +75,7 @@ try:
      from mistralai.models.usermessage import UserMessage as MistralUserMessage
      from mistralai.types.basemodel import Unset as MistralUnset
      from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
- except ImportError as e:  # pragma: lax no cover
+ except ImportError as e:  # pragma: no cover
      raise ImportError(
          'Please install `mistral` to use the Mistral model, '
          'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -125,6 +125,7 @@ class MistralModel(Model):
          provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
          profile: ModelProfileSpec | None = None,
          json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
+         settings: ModelSettings | None = None,
      ):
          """Initialize a Mistral model.
  
@@ -135,6 +136,7 @@ class MistralModel(Model):
                  created using the other parameters.
              profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
              json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
+             settings: Model-specific settings that will be used as defaults for this model.
          """
          self._model_name = model_name
          self.json_mode_schema_prompt = json_mode_schema_prompt
@@ -142,7 +144,8 @@ class MistralModel(Model):
          if isinstance(provider, str):
              provider = infer_provider(provider)
          self.client = provider.client
-         self._profile = profile or provider.model_profile
+ 
+         super().__init__(settings=settings, profile=profile or provider.model_profile)
  
      @property
      def base_url(self) -> str:
@@ -214,7 +217,7 @@ class MistralModel(Model):
          except SDKError as e:
              if (status_code := e.status_code) >= 400:
                  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-             raise  # pragma: lax no cover
+             raise  # pragma: no cover
  
          assert response, 'A unexpected empty response from Mistral.'
          return response
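
`MistralModel.__init__` now accepts a `settings` argument and forwards it to the base `Model.__init__` together with the resolved profile (the OpenAI hunks below make the same change for `OpenAIModel` and `OpenAIResponsesModel`), so default model settings can be attached to the model instance itself. A minimal sketch, assuming the `mistral` optional group is installed and `MISTRAL_API_KEY` is set:

    from pydantic_ai import Agent
    from pydantic_ai.models.mistral import MistralModel
    from pydantic_ai.settings import ModelSettings

    # These act as defaults for the model instance; settings passed at the
    # agent or run level can still override them.
    model = MistralModel('mistral-small-latest', settings=ModelSettings(temperature=0.0, max_tokens=256))
    agent = Agent(model)

    result = agent.run_sync('Name three prime numbers.')
    print(result.output)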
pydantic_ai/models/openai.py
@@ -195,6 +195,7 @@ class OpenAIModel(Model):
          | Provider[AsyncOpenAI] = 'openai',
          profile: ModelProfileSpec | None = None,
          system_prompt_role: OpenAISystemPromptRole | None = None,
+         settings: ModelSettings | None = None,
      ):
          """Initialize an OpenAI model.
  
@@ -206,16 +207,18 @@ class OpenAIModel(Model):
              profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
              system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                  In the future, this may be inferred from the model name.
+             settings: Default model settings for this model instance.
          """
          self._model_name = model_name
  
          if isinstance(provider, str):
              provider = infer_provider(provider)
          self.client = provider.client
-         self._profile = profile or provider.model_profile
  
          self.system_prompt_role = system_prompt_role
  
+         super().__init__(settings=settings, profile=profile or provider.model_profile)
+ 
      @property
      def base_url(self) -> str:
          return str(self.client.base_url)
@@ -342,7 +345,7 @@ class OpenAIModel(Model):
          except APIStatusError as e:
              if (status_code := e.status_code) >= 400:
                  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-             raise  # pragma: lax no cover
+             raise  # pragma: no cover
  
      def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
          """Process a non-streamed response, and prepare a message to return."""
@@ -598,6 +601,7 @@ class OpenAIResponsesModel(Model):
          provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
          | Provider[AsyncOpenAI] = 'openai',
          profile: ModelProfileSpec | None = None,
+         settings: ModelSettings | None = None,
      ):
          """Initialize an OpenAI Responses model.
  
@@ -605,13 +609,15 @@ class OpenAIResponsesModel(Model):
              model_name: The name of the OpenAI model to use.
              provider: The provider to use. Defaults to `'openai'`.
              profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+             settings: Default model settings for this model instance.
          """
          self._model_name = model_name
  
          if isinstance(provider, str):
              provider = infer_provider(provider)
          self.client = provider.client
-         self._profile = profile or provider.model_profile
+ 
+         super().__init__(settings=settings, profile=profile or provider.model_profile)
  
      @property
      def model_name(self) -> OpenAIModelName:
@@ -775,7 +781,7 @@ class OpenAIResponsesModel(Model):
          except APIStatusError as e:
              if (status_code := e.status_code) >= 400:
                  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-             raise  # pragma: lax no cover
+             raise  # pragma: no cover
  
      def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
          reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -988,6 +994,12 @@ class OpenAIStreamedResponse(StreamedResponse):
              if content is not None:
                  yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
  
+             # Handle reasoning part of the response, present in DeepSeek models
+             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
+                 yield self._parts_manager.handle_thinking_delta(
+                     vendor_part_id='reasoning_content', content=reasoning_content
+                 )
+ 
              for dtc in choice.delta.tool_calls or []:
                  maybe_event = self._parts_manager.handle_tool_call_delta(
                      vendor_part_id=dtc.index,
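
The last hunk teaches `OpenAIStreamedResponse` to surface the `reasoning_content` field that DeepSeek models return on streamed deltas as thinking parts. A hedged sketch of observing that, assuming the `deepseek` provider string and `DEEPSEEK_API_KEY` are configured and that `deepseek-reasoner` returns reasoning content:

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelResponse, ThinkingPart
    from pydantic_ai.models.openai import OpenAIModel

    agent = Agent(OpenAIModel('deepseek-reasoner', provider='deepseek'))


    async def main() -> None:
        async with agent.run_stream('Why is the sky blue?') as run:
            async for chunk in run.stream_text(delta=True):
                print(chunk, end='', flush=True)
            # Reasoning deltas are collected as ThinkingPart items on the model response,
            # not emitted through stream_text().
            thinking = [
                part
                for message in run.all_messages()
                if isinstance(message, ModelResponse)
                for part in message.parts
                if isinstance(part, ThinkingPart)
            ]
            print(f'\ncaptured {len(thinking)} thinking part(s)')


    asyncio.run(main())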