pydantic-ai-slim 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pydantic_ai/models/gemini.py

@@ -17,10 +17,8 @@ from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    DocumentUrl,
-    ImageUrl,
+    FileUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -41,8 +39,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )
 
@@ -228,7 +226,7 @@ class GeminiModel(Model):
 
         if gemini_labels := model_settings.get('gemini_labels'):
             if self._system == 'google-vertex':
-                request_data['labels'] = gemini_labels
+                request_data['labels'] = gemini_labels  # pragma: lax no cover
 
         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -348,15 +346,19 @@ class GeminiModel(Model):
                     content.append(
                         _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                     )
-                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl, VideoUrl)):
-                    client = cached_async_http_client()
-                    response = await client.get(item.url, follow_redirects=True)
-                    response.raise_for_status()
-                    mime_type = response.headers['Content-Type'].split(';')[0]
-                    inline_data = _GeminiInlineDataPart(
-                        inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
-                    )
-                    content.append(inline_data)
+                elif isinstance(item, VideoUrl) and item.is_youtube:
+                    file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                    content.append(file_data)
+                elif isinstance(item, FileUrl):
+                    if self.system == 'google-gla' or item.force_download:
+                        downloaded_item = await download_item(item, data_format='base64')
+                        inline_data = _GeminiInlineDataPart(
+                            inline_data={'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
+                        )
+                        content.append(inline_data)
+                    else:
+                        file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                        content.append(file_data)
                 else:
                     assert_never(item)
         return content
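Note: judging from the call sites in this hunk and in the OpenAI hunks further down, the new `download_item` helper replaces the ad-hoc `cached_async_http_client()` fetches. It downloads the file behind a `FileUrl` and returns a mapping with the encoded payload (`data`) and its type (`data_type`); the `data_format`/`type_format` options select base64 vs. data-URI output and MIME type vs. file extension. The sketch below is an inferred contract based only on these call sites, not the real implementation:

    from pydantic_ai.messages import DocumentUrl
    from pydantic_ai.models import download_item  # location inferred from the `from . import` hunks

    async def inline_pdf(url: str) -> dict[str, str]:
        item = DocumentUrl(url=url)
        # data_format='base64' -> {'data': '<base64 payload>', 'data_type': '<mime type>'}
        # data_format='base64_uri', type_format='extension' -> {'data': 'data:<mime>;base64,...', 'data_type': '<extension>'}
        return await download_item(item, data_format='base64')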
@@ -366,6 +368,8 @@ def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _Gemi
     config: _GeminiGenerationConfig = {}
     if (max_tokens := model_settings.get('max_tokens')) is not None:
         config['max_output_tokens'] = max_tokens
+    if (stop_sequences := model_settings.get('stop_sequences')) is not None:
+        config['stop_sequences'] = stop_sequences  # pragma: no cover
     if (temperature := model_settings.get('temperature')) is not None:
         config['temperature'] = temperature
     if (top_p := model_settings.get('top_p')) is not None:
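With this change the standard `stop_sequences` setting is forwarded to the Gemini generation config (and, in the `GoogleModel` hunk below, to the Google GenAI config as well). A minimal usage sketch; the model name is illustrative:

    from pydantic_ai import Agent

    agent = Agent('google-gla:gemini-2.0-flash')  # model name is illustrative
    result = agent.run_sync(
        'Count to ten',
        model_settings={'stop_sequences': ['5']},  # generation stops when a stop sequence is produced
    )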
pydantic_ai/models/google.py

@@ -14,10 +14,8 @@ from pydantic_ai.providers import Provider
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    DocumentUrl,
-    ImageUrl,
+    FileUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,8 +36,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )
 
@@ -260,6 +258,7 @@ class GoogleModel(Model):
             temperature=model_settings.get('temperature'),
             top_p=model_settings.get('top_p'),
             max_output_tokens=model_settings.get('max_tokens'),
+            stop_sequences=model_settings.get('stop_sequences'),
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
             safety_settings=model_settings.get('google_safety_settings'),
@@ -371,13 +370,15 @@ class GoogleModel(Model):
                     # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
                     content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
-                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl, VideoUrl)):
-                    client = cached_async_http_client()
-                    response = await client.get(item.url, follow_redirects=True)
-                    response.raise_for_status()
-                    # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
-                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                    content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
+                elif isinstance(item, VideoUrl) and item.is_youtube:
+                    content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}})
+                elif isinstance(item, FileUrl):
+                    if self.system == 'google-gla' or item.force_download:
+                        downloaded_item = await download_item(item, data_format='base64')
+                        inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
+                        content.append({'inline_data': inline_data})  # type: ignore
+                    else:
+                        content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}})
                 else:
                     assert_never(item)
         return content
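The net effect of the two `FileUrl` hunks: YouTube `VideoUrl`s are passed to Gemini by URI, other URLs are passed by URI on Vertex but downloaded and inlined on the Generative Language API (`google-gla`), and `force_download` lets callers opt into inlining explicitly. A rough usage sketch; the model name and URLs are placeholders, and passing `force_download` as a constructor argument is an assumption based on the `item.force_download` access above:

    from pydantic_ai import Agent
    from pydantic_ai.messages import DocumentUrl, VideoUrl

    agent = Agent('google-vertex:gemini-2.0-flash')  # placeholder model name
    result = agent.run_sync([
        'Summarize these sources',
        VideoUrl(url='https://youtu.be/XXXXXXXXXXX'),  # YouTube URL, sent as file_data by URI
        DocumentUrl(url='https://example.com/report.pdf', force_download=True),  # downloaded and inlined as base64
    ])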
pydantic_ai/models/groq.py

@@ -4,13 +4,13 @@ import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id
+from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
 from ..messages import (
     BinaryContent,
     DocumentUrl,
@@ -246,7 +246,7 @@ class GroqModel(Model):
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+        timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
@@ -270,7 +270,7 @@ class GroqModel(Model):
         return GroqStreamedResponse(
             _response=peekable_response,
             _model_name=self._model_name,
-            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+            _timestamp=number_to_datetime(first_chunk.created),
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
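The repeated `datetime.fromtimestamp(..., tz=timezone.utc)` calls (here and in the Mistral and OpenAI hunks below) are replaced by a shared `_utils.number_to_datetime` helper. Its body is not part of this diff; presumably it is roughly equivalent to the code it replaces, something like:

    from datetime import datetime, timezone

    # Assumed behaviour only -- the real helper lives in pydantic_ai._utils and may also
    # normalize unusual inputs (e.g. millisecond timestamps).
    def number_to_datetime(created: int | float) -> datetime:
        return datetime.fromtimestamp(created, tz=timezone.utc)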
pydantic_ai/models/instrumented.py

@@ -13,6 +13,7 @@ from opentelemetry._events import (
     EventLoggerProvider,  # pyright: ignore[reportPrivateImportUsage]
     get_event_logger_provider,  # pyright: ignore[reportPrivateImportUsage]
 )
+from opentelemetry.metrics import MeterProvider, get_meter_provider
 from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter
@@ -49,6 +50,10 @@ MODEL_SETTING_ATTRIBUTES: tuple[
 
 ANY_ADAPTER = TypeAdapter[Any](Any)
 
+# These are in the spec:
+# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
+TOKEN_HISTOGRAM_BOUNDARIES = (1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864)
+
 
 def instrument_model(model: Model, instrument: InstrumentationSettings | bool) -> Model:
     """Instrument a model with OpenTelemetry/logfire."""
@@ -84,6 +89,7 @@ class InstrumentationSettings:
         *,
         event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
+        meter_provider: MeterProvider | None = None,
         event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
     ):
@@ -95,6 +101,9 @@ class InstrumentationSettings:
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
+            meter_provider: The OpenTelemetry meter provider to use.
+                If not provided, the global meter provider is used.
+                Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
             event_logger_provider: The OpenTelemetry event logger provider to use.
                 If not provided, the global event logger provider is used.
                 Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
@@ -104,12 +113,33 @@ class InstrumentationSettings:
         from pydantic_ai import __version__
 
         tracer_provider = tracer_provider or get_tracer_provider()
+        meter_provider = meter_provider or get_meter_provider()
         event_logger_provider = event_logger_provider or get_event_logger_provider()
-        self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
-        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
+        scope_name = 'pydantic-ai'
+        self.tracer = tracer_provider.get_tracer(scope_name, __version__)
+        self.meter = meter_provider.get_meter(scope_name, __version__)
+        self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__)
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
 
+        # As specified in the OpenTelemetry GenAI metrics spec:
+        # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
+        tokens_histogram_kwargs = dict(
+            name='gen_ai.client.token.usage',
+            unit='{token}',
+            description='Measures number of input and output tokens used',
+        )
+        try:
+            self.tokens_histogram = self.meter.create_histogram(
+                **tokens_histogram_kwargs,
+                explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
+            )
+        except TypeError:
+            # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
+            self.tokens_histogram = self.meter.create_histogram(
+                **tokens_histogram_kwargs,  # pyright: ignore
+            )
+
     def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
         """Convert a list of model messages to OpenTelemetry events.
 
@@ -224,38 +254,74 @@ class InstrumentedModel(WrapperModel):
                 if isinstance(value := model_settings.get(key), (float, int)):
                     attributes[f'gen_ai.request.{key}'] = value
 
-        with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
-
-            def finish(response: ModelResponse):
-                if not span.is_recording():
-                    return
-
-                events = self.settings.messages_to_otel_events(messages)
-                for event in self.settings.messages_to_otel_events([response]):
-                    events.append(
-                        Event(
-                            'gen_ai.choice',
-                            body={
-                                # TODO finish_reason
-                                'index': 0,
-                                'message': event.body,
-                            },
+        record_metrics: Callable[[], None] | None = None
+        try:
+            with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+
+                def finish(response: ModelResponse):
+                    # FallbackModel updates these span attributes.
+                    attributes.update(getattr(span, 'attributes', {}))
+                    request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
+                    system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+
+                    response_model = response.model_name or request_model
+
+                    def _record_metrics():
+                        metric_attributes = {
+                            GEN_AI_SYSTEM_ATTRIBUTE: system,
+                            'gen_ai.operation.name': operation,
+                            'gen_ai.request.model': request_model,
+                            'gen_ai.response.model': response_model,
+                        }
+                        if response.usage.request_tokens:  # pragma: no branch
+                            self.settings.tokens_histogram.record(
+                                response.usage.request_tokens,
+                                {**metric_attributes, 'gen_ai.token.type': 'input'},
+                            )
+                        if response.usage.response_tokens:  # pragma: no branch
+                            self.settings.tokens_histogram.record(
+                                response.usage.response_tokens,
+                                {**metric_attributes, 'gen_ai.token.type': 'output'},
+                            )
+
+                    nonlocal record_metrics
+                    record_metrics = _record_metrics
+
+                    if not span.is_recording():
+                        return
+
+                    events = self.settings.messages_to_otel_events(messages)
+                    for event in self.settings.messages_to_otel_events([response]):
+                        events.append(
+                            Event(
+                                'gen_ai.choice',
+                                body={
+                                    # TODO finish_reason
+                                    'index': 0,
+                                    'message': event.body,
+                                },
+                            )
                         )
+                    span.set_attributes(
+                        {
+                            **response.usage.opentelemetry_attributes(),
+                            'gen_ai.response.model': response_model,
+                        }
                     )
-                new_attributes: dict[str, AttributeValue] = response.usage.opentelemetry_attributes()  # pyright: ignore[reportAssignmentType]
-                attributes.update(getattr(span, 'attributes', {}))
-                request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
-                new_attributes['gen_ai.response.model'] = response.model_name or request_model
-                span.set_attributes(new_attributes)
-                span.update_name(f'{operation} {request_model}')
-                for event in events:
-                    event.attributes = {
-                        GEN_AI_SYSTEM_ATTRIBUTE: attributes[GEN_AI_SYSTEM_ATTRIBUTE],
-                        **(event.attributes or {}),
-                    }
-                self._emit_events(span, events)
-
-            yield finish
+                    span.update_name(f'{operation} {request_model}')
+                    for event in events:
+                        event.attributes = {
+                            GEN_AI_SYSTEM_ATTRIBUTE: system,
+                            **(event.attributes or {}),
+                        }
+                    self._emit_events(span, events)
+
+                yield finish
+        finally:
+            if record_metrics:
+                # We only want to record metrics after the span is finished,
+                # to prevent them from being redundantly recorded in the span itself by logfire.
+                record_metrics()
 
     def _emit_events(self, span: Span, events: list[Event]) -> None:
         if self.settings.event_mode == 'logs':
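With the hunks above, `InstrumentationSettings` creates a `gen_ai.client.token.usage` histogram alongside the tracer, and `InstrumentedModel` records input/output token counts on it once each request span has finished. A minimal sketch of wiring in a custom meter provider using the standard OTel SDK (reader and exporter choice are up to the application, and the model name is illustrative):

    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import InMemoryMetricReader

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentationSettings

    reader = InMemoryMetricReader()
    settings = InstrumentationSettings(meter_provider=MeterProvider(metric_readers=[reader]))

    agent = Agent('openai:gpt-4o', instrument=settings)  # model name is illustrative
    # After a run, reader.get_metrics_data() includes the gen_ai.client.token.usage histogram.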
pydantic_ai/models/mistral.py

@@ -4,7 +4,7 @@ import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Any, Literal, Union, cast
 
 import pydantic_core
@@ -12,7 +12,7 @@ from httpx import Timeout
 from typing_extensions import assert_never
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
-from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
+from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
     BinaryContent,
     DocumentUrl,
@@ -312,7 +312,7 @@ class MistralModel(Model):
         assert response.choices, 'Unexpected empty response choice.'
 
         if response.created:
-            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+            timestamp = number_to_datetime(response.created)
         else:
             timestamp = _now_utc()
 
@@ -347,9 +347,9 @@ class MistralModel(Model):
         )
 
         if first_chunk.data.created:
-            timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
+            timestamp = number_to_datetime(first_chunk.data.created)
         else:
-            timestamp = datetime.now(tz=timezone.utc)
+            timestamp = _now_utc()
 
         return MistralStreamedResponse(
             _response=peekable_response,
pydantic_ai/models/openai.py

@@ -5,7 +5,7 @@ import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
@@ -14,7 +14,7 @@ from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.providers import Provider, infer_provider
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id
+from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -40,8 +40,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )
 
@@ -116,6 +116,13 @@ class OpenAIModelSettings(ModelSettings, total=False):
     See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
     """
 
+    openai_service_tier: Literal['auto', 'default', 'flex']
+    """The service tier to use for the model request.
+
+    Currently supported values are `auto`, `default`, and `flex`.
+    For more information, see [OpenAI's service tiers documentation](https://platform.openai.com/docs/api-reference/chat/object#chat/object-service_tier).
+    """
+
 
 class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
     """Settings used for an OpenAI Responses model request.
@@ -170,7 +177,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
@@ -274,6 +281,12 @@ class OpenAIModel(Model):
 
         openai_messages = await self._map_messages(messages)
 
+        sampling_settings = (
+            model_settings
+            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
+            else OpenAIModelSettings()
+        )
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -287,17 +300,18 @@ class OpenAIModel(Model):
                 stream_options={'include_usage': True} if stream else NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-                temperature=model_settings.get('temperature', NOT_GIVEN),
-                top_p=model_settings.get('top_p', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 seed=model_settings.get('seed', NOT_GIVEN),
-                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-                logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
-                top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
                 user=model_settings.get('openai_user', NOT_GIVEN),
+                service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
+                temperature=sampling_settings.get('temperature', NOT_GIVEN),
+                top_p=sampling_settings.get('top_p', NOT_GIVEN),
+                presence_penalty=sampling_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=sampling_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=sampling_settings.get('logit_bias', NOT_GIVEN),
+                logprobs=sampling_settings.get('openai_logprobs', NOT_GIVEN),
+                top_logprobs=sampling_settings.get('openai_top_logprobs', NOT_GIVEN),
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -308,7 +322,7 @@ class OpenAIModel(Model):
 
     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+        timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         vendor_details: dict[str, Any] | None = None
@@ -358,7 +372,7 @@ class OpenAIModel(Model):
         return OpenAIStreamedResponse(
             _model_name=self._model_name,
             _response=peekable_response,
-            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+            _timestamp=number_to_datetime(first_chunk.created),
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -485,21 +499,21 @@ class OpenAIModel(Model):
                     else:  # pragma: no cover
                         raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
                 elif isinstance(item, AudioUrl):
-                    client = cached_async_http_client()
-                    response = await client.get(item.url)
-                    response.raise_for_status()
-                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                    audio_format: Any = response.headers['content-type'].removeprefix('audio/')
-                    audio = InputAudio(data=base64_encoded, format=audio_format)
+                    downloaded_item = await download_item(item, data_format='base64', type_format='extension')
+                    assert downloaded_item['data_type'] in (
+                        'wav',
+                        'mp3',
+                    ), f'Unsupported audio format: {downloaded_item["data_type"]}'
+                    audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
                     content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                 elif isinstance(item, DocumentUrl):
-                    client = cached_async_http_client()
-                    response = await client.get(item.url)
-                    response.raise_for_status()
-                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                    media_type = response.headers.get('content-type').split(';')[0]
-                    file_data = f'data:{media_type};base64,{base64_encoded}'
-                    file = File(file=FileFile(file_data=file_data, filename=f'filename.{item.format}'), type='file')
+                    downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
+                    file = File(
+                        file=FileFile(
+                            file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}'
+                        ),
+                        type='file',
+                    )
                     content.append(file)
                 elif isinstance(item, VideoUrl):  # pragma: no cover
                     raise NotImplementedError('VideoUrl is not supported for OpenAI')
@@ -593,7 +607,7 @@ class OpenAIResponsesModel(Model):
 
     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
+        timestamp = number_to_datetime(response.created_at)
         items: list[ModelResponsePart] = []
         items.append(TextPart(response.output_text))
         for item in response.output:
@@ -614,7 +628,7 @@ class OpenAIResponsesModel(Model):
         return OpenAIResponsesStreamedResponse(
             _model_name=self._model_name,
             _response=peekable_response,
-            _timestamp=datetime.fromtimestamp(first_chunk.response.created_at, tz=timezone.utc),
+            _timestamp=number_to_datetime(first_chunk.response.created_at),
         )
 
     @overload
@@ -656,6 +670,12 @@ class OpenAIResponsesModel(Model):
         instructions, openai_messages = await self._map_messages(messages)
         reasoning = self._get_reasoning(model_settings)
 
+        sampling_settings = (
+            model_settings
+            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
+            else OpenAIResponsesModelSettings()
+        )
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -668,8 +688,8 @@ class OpenAIResponsesModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 stream=stream,
-                temperature=model_settings.get('temperature', NOT_GIVEN),
-                top_p=model_settings.get('top_p', NOT_GIVEN),
+                temperature=sampling_settings.get('temperature', NOT_GIVEN),
+                top_p=sampling_settings.get('top_p', NOT_GIVEN),
                 truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 reasoning=reasoning,
@@ -805,27 +825,21 @@ class OpenAIResponsesModel(Model):
                         responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto')
                     )
                 elif isinstance(item, AudioUrl):  # pragma: no cover
-                    client = cached_async_http_client()
-                    response = await client.get(item.url)
-                    response.raise_for_status()
-                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
                     content.append(
                         responses.ResponseInputFileParam(
                             type='input_file',
-                            file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                            file_data=downloaded_item['data'],
+                            filename=f'filename.{downloaded_item["data_type"]}',
                         )
                     )
                 elif isinstance(item, DocumentUrl):
-                    client = cached_async_http_client()
-                    response = await client.get(item.url)
-                    response.raise_for_status()
-                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                    media_type = response.headers.get('content-type').split(';')[0]
+                    downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
                     content.append(
                         responses.ResponseInputFileParam(
                             type='input_file',
-                            file_data=f'data:{media_type};base64,{base64_encoded}',
-                            filename=f'filename.{item.format}',
+                            file_data=downloaded_item['data'],
+                            filename=f'filename.{downloaded_item["data_type"]}',
                         )
                     )
                 elif isinstance(item, VideoUrl):  # pragma: no cover
pydantic_ai/profiles/openai.py

@@ -15,13 +15,20 @@ class OpenAIModelProfile(ModelProfile):
     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """
 
-    # This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions
     openai_supports_strict_tool_definition: bool = True
+    """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
+
+    openai_supports_sampling_settings: bool = True
+    """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
 
 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
-    return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
+    is_reasoning_model = model_name.startswith('o')
+    return OpenAIModelProfile(
+        json_schema_transformer=OpenAIJsonSchemaTransformer,
+        openai_supports_sampling_settings=not is_reasoning_model,
+    )
 
 
 _STRICT_INCOMPATIBLE_KEYS = [
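Since `openai_model_profile` now flags any model whose name starts with 'o' as a reasoning model, `OpenAIModel` drops sampling settings such as `temperature` and `top_p` for those models (see the `sampling_settings` hunks above). A quick check of the flag, based only on the code in this hunk (model names are illustrative):

    from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile

    profile = OpenAIModelProfile.from_profile(openai_model_profile('o3-mini'))
    assert profile.openai_supports_sampling_settings is False

    profile = OpenAIModelProfile.from_profile(openai_model_profile('gpt-4o'))
    assert profile.openai_supports_sampling_settings is True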
pydantic_ai/providers/__init__.py

@@ -48,7 +48,7 @@ class Provider(ABC, Generic[InterfaceClient]):
         return None  # pragma: no cover
 
 
-def infer_provider(provider: str) -> Provider[Any]:
+def infer_provider(provider: str) -> Provider[Any]:  # noqa: C901
     """Infer the provider from the provider name."""
     if provider == 'openai':
         from .openai import OpenAIProvider
@@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .together import TogetherProvider
 
         return TogetherProvider()
+    elif provider == 'heroku':
+        from .heroku import HerokuProvider
+
+        return HerokuProvider()
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
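'heroku' joins the provider shortcuts accepted by `infer_provider` and by `OpenAIModel` (see the widened `Literal` above). A rough usage sketch; the model name is a placeholder, and credential handling belongs to the new `pydantic_ai.providers.heroku` module, which is not shown in this diff:

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel

    model = OpenAIModel('claude-3-7-sonnet', provider='heroku')  # placeholder model name
    agent = Agent(model)  # API key and base URL are resolved by HerokuProvider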
pydantic_ai/providers/google_vertex.py

@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
         return self._client
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
-        return google_model_profile(model_name)
+        return google_model_profile(model_name)  # pragma: lax no cover
 
     @overload
     def __init__(