pydantic-ai-slim 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- pydantic_ai/_agent_graph.py +0 -4
- pydantic_ai/_function_schema.py +4 -4
- pydantic_ai/_output.py +1 -1
- pydantic_ai/_utils.py +5 -1
- pydantic_ai/agent.py +5 -6
- pydantic_ai/ext/__init__.py +0 -0
- pydantic_ai/ext/langchain.py +61 -0
- pydantic_ai/mcp.py +57 -15
- pydantic_ai/messages.py +43 -13
- pydantic_ai/models/__init__.py +95 -3
- pydantic_ai/models/anthropic.py +3 -11
- pydantic_ai/models/bedrock.py +23 -15
- pydantic_ai/models/gemini.py +18 -14
- pydantic_ai/models/google.py +12 -11
- pydantic_ai/models/groq.py +4 -4
- pydantic_ai/models/instrumented.py +98 -32
- pydantic_ai/models/mistral.py +5 -5
- pydantic_ai/models/openai.py +56 -42
- pydantic_ai/profiles/openai.py +9 -2
- pydantic_ai/providers/__init__.py +5 -1
- pydantic_ai/providers/google_vertex.py +1 -1
- pydantic_ai/providers/heroku.py +82 -0
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +53 -6
- {pydantic_ai_slim-0.2.15.dist-info → pydantic_ai_slim-0.2.17.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-0.2.15.dist-info → pydantic_ai_slim-0.2.17.dist-info}/RECORD +29 -26
- {pydantic_ai_slim-0.2.15.dist-info → pydantic_ai_slim-0.2.17.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.2.15.dist-info → pydantic_ai_slim-0.2.17.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.2.15.dist-info → pydantic_ai_slim-0.2.17.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/gemini.py
CHANGED
@@ -17,10 +17,8 @@ from pydantic_ai.providers import Provider, infer_provider

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    ImageUrl,
+    FileUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -41,8 +39,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )

@@ -228,7 +226,7 @@ class GeminiModel(Model):

         if gemini_labels := model_settings.get('gemini_labels'):
             if self._system == 'google-vertex':
-                request_data['labels'] = gemini_labels
+                request_data['labels'] = gemini_labels  # pragma: lax no cover

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
@@ -348,15 +346,19 @@ class GeminiModel(Model):
                 content.append(
                     _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                 )
-            elif isinstance(item,
-                inline_data=
+            elif isinstance(item, VideoUrl) and item.is_youtube:
+                file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                content.append(file_data)
+            elif isinstance(item, FileUrl):
+                if self.system == 'google-gla' or item.force_download:
+                    downloaded_item = await download_item(item, data_format='base64')
+                    inline_data = _GeminiInlineDataPart(
+                        inline_data={'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
+                    )
+                    content.append(inline_data)
+                else:
+                    file_data = _GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
+                    content.append(file_data)
             else:
                 assert_never(item)
         return content
@@ -366,6 +368,8 @@ def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
     config: _GeminiGenerationConfig = {}
     if (max_tokens := model_settings.get('max_tokens')) is not None:
         config['max_output_tokens'] = max_tokens
+    if (stop_sequences := model_settings.get('stop_sequences')) is not None:
+        config['stop_sequences'] = stop_sequences  # pragma: no cover
     if (temperature := model_settings.get('temperature')) is not None:
         config['temperature'] = temperature
     if (top_p := model_settings.get('top_p')) is not None:
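Illustrative usage, not part of the diff: with these changes, Gemini requests forward YouTube URLs as `file_data` and download other file URLs (always on the Generative Language API, or when forced) via the new `download_item` helper. The model name and URLs below are placeholders, and passing `force_download` as a constructor argument is an assumption based on the `pydantic_ai/messages.py` changes in this release.

    from pydantic_ai import Agent
    from pydantic_ai.messages import DocumentUrl, VideoUrl

    agent = Agent('google-gla:gemini-2.0-flash')

    # YouTube links are passed through by URI instead of being downloaded.
    video_result = agent.run_sync(
        ['Summarize this video.', VideoUrl(url='https://youtu.be/<video-id>')]
    )

    # Other file URLs are fetched and inlined as base64 when required (or forced).
    doc_result = agent.run_sync(
        ['Summarize this document.', DocumentUrl(url='https://example.com/report.pdf', force_download=True)]
    )
    print(doc_result.output)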
pydantic_ai/models/google.py
CHANGED
@@ -14,10 +14,8 @@ from pydantic_ai.providers import Provider

 from .. import UnexpectedModelBehavior, _utils, usage
 from ..messages import (
-    AudioUrl,
     BinaryContent,
-    ImageUrl,
+    FileUrl,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -38,8 +36,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )

@@ -260,6 +258,7 @@ class GoogleModel(Model):
             temperature=model_settings.get('temperature'),
             top_p=model_settings.get('top_p'),
             max_output_tokens=model_settings.get('max_tokens'),
+            stop_sequences=model_settings.get('stop_sequences'),
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
             safety_settings=model_settings.get('google_safety_settings'),
@@ -371,13 +370,15 @@ class GoogleModel(Model):
                 # NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
                 base64_encoded = base64.b64encode(item.data).decode('utf-8')
                 content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}})  # type: ignore
-            elif isinstance(item,
+            elif isinstance(item, VideoUrl) and item.is_youtube:
+                content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}})
+            elif isinstance(item, FileUrl):
+                if self.system == 'google-gla' or item.force_download:
+                    downloaded_item = await download_item(item, data_format='base64')
+                    inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
+                    content.append({'inline_data': inline_data})  # type: ignore
+                else:
+                    content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}})
             else:
                 assert_never(item)
         return content
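Illustrative usage, not part of the diff: `stop_sequences` from the generic `ModelSettings` is now forwarded to Google requests (and, per the hunk above, to Gemini as well). A minimal sketch, with a placeholder model name:

    from pydantic_ai import Agent

    agent = Agent('google-gla:gemini-2.0-flash')
    # Generation stops as soon as the model emits '5'.
    result = agent.run_sync('Count from 1 to 10.', model_settings={'stop_sequences': ['5']})
    print(result.output)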
pydantic_ai/models/groq.py
CHANGED
@@ -4,13 +4,13 @@ import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime
+from datetime import datetime
 from typing import Literal, Union, cast, overload

 from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id
+from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
 from ..messages import (
     BinaryContent,
     DocumentUrl,
@@ -246,7 +246,7 @@ class GroqModel(Model):

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp =
+        timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         if choice.message.content is not None:
@@ -270,7 +270,7 @@ class GroqModel(Model):
         return GroqStreamedResponse(
             _response=peekable_response,
             _model_name=self._model_name,
-            _timestamp=
+            _timestamp=number_to_datetime(first_chunk.created),
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
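`number_to_datetime` itself lives in `pydantic_ai/_utils.py` and is not shown in this diff. A functionally equivalent sketch of what the call sites above rely on (a Unix `created` timestamp converted to an aware UTC datetime) would be:

    from datetime import datetime, timezone

    def number_to_datetime(created: int | float) -> datetime:
        # Providers return `created` as seconds since the Unix epoch.
        return datetime.fromtimestamp(created, tz=timezone.utc)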
pydantic_ai/models/instrumented.py
CHANGED
@@ -13,6 +13,7 @@ from opentelemetry._events import (
     EventLoggerProvider,  # pyright: ignore[reportPrivateImportUsage]
     get_event_logger_provider,  # pyright: ignore[reportPrivateImportUsage]
 )
+from opentelemetry.metrics import MeterProvider, get_meter_provider
 from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter
@@ -49,6 +50,10 @@ MODEL_SETTING_ATTRIBUTES: tuple[

 ANY_ADAPTER = TypeAdapter[Any](Any)

+# These are in the spec:
+# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
+TOKEN_HISTOGRAM_BOUNDARIES = (1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864)
+

 def instrument_model(model: Model, instrument: InstrumentationSettings | bool) -> Model:
     """Instrument a model with OpenTelemetry/logfire."""
@@ -84,6 +89,7 @@ class InstrumentationSettings:
         *,
         event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
+        meter_provider: MeterProvider | None = None,
         event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
     ):
@@ -95,6 +101,9 @@ class InstrumentationSettings:
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
+            meter_provider: The OpenTelemetry meter provider to use.
+                If not provided, the global meter provider is used.
+                Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
             event_logger_provider: The OpenTelemetry event logger provider to use.
                 If not provided, the global event logger provider is used.
                 Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
@@ -104,12 +113,33 @@ class InstrumentationSettings:
         from pydantic_ai import __version__

         tracer_provider = tracer_provider or get_tracer_provider()
+        meter_provider = meter_provider or get_meter_provider()
         event_logger_provider = event_logger_provider or get_event_logger_provider()
-        self.
+        scope_name = 'pydantic-ai'
+        self.tracer = tracer_provider.get_tracer(scope_name, __version__)
+        self.meter = meter_provider.get_meter(scope_name, __version__)
+        self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__)
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content

+        # As specified in the OpenTelemetry GenAI metrics spec:
+        # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
+        tokens_histogram_kwargs = dict(
+            name='gen_ai.client.token.usage',
+            unit='{token}',
+            description='Measures number of input and output tokens used',
+        )
+        try:
+            self.tokens_histogram = self.meter.create_histogram(
+                **tokens_histogram_kwargs,
+                explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
+            )
+        except TypeError:
+            # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
+            self.tokens_histogram = self.meter.create_histogram(
+                **tokens_histogram_kwargs,  # pyright: ignore
+            )
+
     def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
         """Convert a list of model messages to OpenTelemetry events.

@@ -224,38 +254,74 @@ class InstrumentedModel(WrapperModel):
             if isinstance(value := model_settings.get(key), (float, int)):
                 attributes[f'gen_ai.request.{key}'] = value

+        record_metrics: Callable[[], None] | None = None
+        try:
+            with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
+
+                def finish(response: ModelResponse):
+                    # FallbackModel updates these span attributes.
+                    attributes.update(getattr(span, 'attributes', {}))
+                    request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
+                    system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+
+                    response_model = response.model_name or request_model
+
+                    def _record_metrics():
+                        metric_attributes = {
+                            GEN_AI_SYSTEM_ATTRIBUTE: system,
+                            'gen_ai.operation.name': operation,
+                            'gen_ai.request.model': request_model,
+                            'gen_ai.response.model': response_model,
+                        }
+                        if response.usage.request_tokens:  # pragma: no branch
+                            self.settings.tokens_histogram.record(
+                                response.usage.request_tokens,
+                                {**metric_attributes, 'gen_ai.token.type': 'input'},
+                            )
+                        if response.usage.response_tokens:  # pragma: no branch
+                            self.settings.tokens_histogram.record(
+                                response.usage.response_tokens,
+                                {**metric_attributes, 'gen_ai.token.type': 'output'},
+                            )
+
+                    nonlocal record_metrics
+                    record_metrics = _record_metrics
+
+                    if not span.is_recording():
+                        return
+
+                    events = self.settings.messages_to_otel_events(messages)
+                    for event in self.settings.messages_to_otel_events([response]):
+                        events.append(
+                            Event(
+                                'gen_ai.choice',
+                                body={
+                                    # TODO finish_reason
+                                    'index': 0,
+                                    'message': event.body,
+                                },
+                            )
                         )
+                    span.set_attributes(
+                        {
+                            **response.usage.opentelemetry_attributes(),
+                            'gen_ai.response.model': response_model,
+                        }
                     )
+                    span.update_name(f'{operation} {request_model}')
+                    for event in events:
+                        event.attributes = {
+                            GEN_AI_SYSTEM_ATTRIBUTE: system,
+                            **(event.attributes or {}),
+                        }
+                    self._emit_events(span, events)
+
+                yield finish
+        finally:
+            if record_metrics:
+                # We only want to record metrics after the span is finished,
+                # to prevent them from being redundantly recorded in the span itself by logfire.
+                record_metrics()

     def _emit_events(self, span: Span, events: list[Event]) -> None:
         if self.settings.event_mode == 'logs':
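Illustrative usage, not part of the diff: the new `meter_provider` argument lets callers route the `gen_ai.client.token.usage` histogram to a specific OpenTelemetry SDK; by default the global meter provider (e.g. the one set by `logfire.configure()`) is used. The SDK wiring below is a sketch.

    from opentelemetry.sdk.metrics import MeterProvider

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentationSettings

    settings = InstrumentationSettings(meter_provider=MeterProvider())
    agent = Agent('openai:gpt-4o', instrument=settings)
    # Each model request records gen_ai.client.token.usage twice:
    # once with gen_ai.token.type='input' and once with 'output'.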
pydantic_ai/models/mistral.py
CHANGED
@@ -4,7 +4,7 @@ import base64
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime
+from datetime import datetime
 from typing import Any, Literal, Union, cast

 import pydantic_core
@@ -12,7 +12,7 @@ from httpx import Timeout
 from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
-from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
+from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
     BinaryContent,
     DocumentUrl,
@@ -312,7 +312,7 @@ class MistralModel(Model):
         assert response.choices, 'Unexpected empty response choice.'

         if response.created:
-            timestamp =
+            timestamp = number_to_datetime(response.created)
         else:
             timestamp = _now_utc()

@@ -347,9 +347,9 @@ class MistralModel(Model):
         )

         if first_chunk.data.created:
-            timestamp =
+            timestamp = number_to_datetime(first_chunk.data.created)
         else:
-            timestamp =
+            timestamp = _now_utc()

         return MistralStreamedResponse(
             _response=peekable_response,
pydantic_ai/models/openai.py
CHANGED
@@ -5,7 +5,7 @@ import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
-from datetime import datetime
+from datetime import datetime
 from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never
@@ -14,7 +14,7 @@ from pydantic_ai.profiles.openai import OpenAIModelProfile
 from pydantic_ai.providers import Provider, infer_provider

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id
+from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -40,8 +40,8 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
+    download_item,
     get_user_agent,
 )

@@ -116,6 +116,13 @@ class OpenAIModelSettings(ModelSettings, total=False):
     See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
     """

+    openai_service_tier: Literal['auto', 'default', 'flex']
+    """The service tier to use for the model request.
+
+    Currently supported values are `auto`, `default`, and `flex`.
+    For more information, see [OpenAI's service tiers documentation](https://platform.openai.com/docs/api-reference/chat/object#chat/object-service_tier).
+    """
+

 class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
     """Settings used for an OpenAI Responses model request.
@@ -170,7 +177,7 @@ class OpenAIModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
@@ -274,6 +281,12 @@ class OpenAIModel(Model):

         openai_messages = await self._map_messages(messages)

+        sampling_settings = (
+            model_settings
+            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
+            else OpenAIModelSettings()
+        )
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -287,17 +300,18 @@ class OpenAIModel(Model):
                 stream_options={'include_usage': True} if stream else NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
-                temperature=model_settings.get('temperature', NOT_GIVEN),
-                top_p=model_settings.get('top_p', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 seed=model_settings.get('seed', NOT_GIVEN),
-                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
-                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
-                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-                logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
-                top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
                 user=model_settings.get('openai_user', NOT_GIVEN),
+                service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
+                temperature=sampling_settings.get('temperature', NOT_GIVEN),
+                top_p=sampling_settings.get('top_p', NOT_GIVEN),
+                presence_penalty=sampling_settings.get('presence_penalty', NOT_GIVEN),
+                frequency_penalty=sampling_settings.get('frequency_penalty', NOT_GIVEN),
+                logit_bias=sampling_settings.get('logit_bias', NOT_GIVEN),
+                logprobs=sampling_settings.get('openai_logprobs', NOT_GIVEN),
+                top_logprobs=sampling_settings.get('openai_top_logprobs', NOT_GIVEN),
                 extra_headers=extra_headers,
                 extra_body=model_settings.get('extra_body'),
             )
@@ -308,7 +322,7 @@ class OpenAIModel(Model):

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp =
+        timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         vendor_details: dict[str, Any] | None = None
@@ -358,7 +372,7 @@ class OpenAIModel(Model):
         return OpenAIStreamedResponse(
             _model_name=self._model_name,
             _response=peekable_response,
-            _timestamp=
+            _timestamp=number_to_datetime(first_chunk.created),
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -485,21 +499,21 @@ class OpenAIModel(Model):
                 else:  # pragma: no cover
                     raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
             elif isinstance(item, AudioUrl):
-                audio = InputAudio(data=
+                downloaded_item = await download_item(item, data_format='base64', type_format='extension')
+                assert downloaded_item['data_type'] in (
+                    'wav',
+                    'mp3',
+                ), f'Unsupported audio format: {downloaded_item["data_type"]}'
+                audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
                 content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
             elif isinstance(item, DocumentUrl):
+                downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
+                file = File(
+                    file=FileFile(
+                        file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}'
+                    ),
+                    type='file',
+                )
                 content.append(file)
             elif isinstance(item, VideoUrl):  # pragma: no cover
                 raise NotImplementedError('VideoUrl is not supported for OpenAI')
@@ -593,7 +607,7 @@ class OpenAIResponsesModel(Model):

     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
-        timestamp =
+        timestamp = number_to_datetime(response.created_at)
         items: list[ModelResponsePart] = []
         items.append(TextPart(response.output_text))
         for item in response.output:
@@ -614,7 +628,7 @@ class OpenAIResponsesModel(Model):
         return OpenAIResponsesStreamedResponse(
             _model_name=self._model_name,
             _response=peekable_response,
-            _timestamp=
+            _timestamp=number_to_datetime(first_chunk.response.created_at),
         )

     @overload
@@ -656,6 +670,12 @@ class OpenAIResponsesModel(Model):
         instructions, openai_messages = await self._map_messages(messages)
         reasoning = self._get_reasoning(model_settings)

+        sampling_settings = (
+            model_settings
+            if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
+            else OpenAIResponsesModelSettings()
+        )
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -668,8 +688,8 @@ class OpenAIResponsesModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 stream=stream,
-                temperature=
-                top_p=
+                temperature=sampling_settings.get('temperature', NOT_GIVEN),
+                top_p=sampling_settings.get('top_p', NOT_GIVEN),
                 truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 reasoning=reasoning,
@@ -805,27 +825,21 @@ class OpenAIResponsesModel(Model):
                     responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto')
                 )
             elif isinstance(item, AudioUrl):  # pragma: no cover
-                response = await client.get(item.url)
-                response.raise_for_status()
-                base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
                 content.append(
                     responses.ResponseInputFileParam(
                         type='input_file',
-                        file_data=
+                        file_data=downloaded_item['data'],
+                        filename=f'filename.{downloaded_item["data_type"]}',
                     )
                 )
             elif isinstance(item, DocumentUrl):
-                response = await client.get(item.url)
-                response.raise_for_status()
-                base64_encoded = base64.b64encode(response.content).decode('utf-8')
-                media_type = response.headers.get('content-type').split(';')[0]
+                downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
                 content.append(
                     responses.ResponseInputFileParam(
                         type='input_file',
-                        file_data=
-                        filename=f'filename.{
+                        file_data=downloaded_item['data'],
+                        filename=f'filename.{downloaded_item["data_type"]}',
                     )
                 )
             elif isinstance(item, VideoUrl):  # pragma: no cover
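Illustrative usage, not part of the diff: the new `openai_service_tier` setting is passed straight through as `service_tier` on chat completion requests. A minimal sketch:

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModelSettings

    agent = Agent('openai:o3-mini')
    result = agent.run_sync(
        'Summarize the change in one sentence.',
        model_settings=OpenAIModelSettings(openai_service_tier='flex'),
    )
    print(result.output)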
pydantic_ai/profiles/openai.py
CHANGED
@@ -15,13 +15,20 @@ class OpenAIModelProfile(ModelProfile):
     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """

-    # This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions
     openai_supports_strict_tool_definition: bool = True
+    """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
+
+    openai_supports_sampling_settings: bool = True
+    """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""


 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""
+    is_reasoning_model = model_name.startswith('o')
+    return OpenAIModelProfile(
+        json_schema_transformer=OpenAIJsonSchemaTransformer,
+        openai_supports_sampling_settings=not is_reasoning_model,
+    )


 _STRICT_INCOMPATIBLE_KEYS = [
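The effect of the new profile flag, derived from `openai_model_profile` above (reasoning models are detected by the `o` prefix):

    from pydantic_ai.profiles.openai import openai_model_profile

    assert openai_model_profile('o3-mini').openai_supports_sampling_settings is False
    assert openai_model_profile('gpt-4o').openai_supports_sampling_settings is True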
pydantic_ai/providers/__init__.py
CHANGED
@@ -48,7 +48,7 @@ class Provider(ABC, Generic[InterfaceClient]):
         return None  # pragma: no cover


-def infer_provider(provider: str) -> Provider[Any]:
+def infer_provider(provider: str) -> Provider[Any]:  # noqa: C901
     """Infer the provider from the provider name."""
     if provider == 'openai':
         from .openai import OpenAIProvider
@@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .together import TogetherProvider

         return TogetherProvider()
+    elif provider == 'heroku':
+        from .heroku import HerokuProvider
+
+        return HerokuProvider()
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
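Illustrative usage, not part of the diff: `'heroku'` can now be passed as a provider name for the OpenAI-compatible model class. The model name below is a placeholder, and credential configuration is handled by the new `pydantic_ai/providers/heroku.py` (not shown in this excerpt).

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel

    model = OpenAIModel('claude-3-7-sonnet', provider='heroku')
    agent = Agent(model)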
pydantic_ai/providers/google_vertex.py
CHANGED
@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
         return self._client

     def model_profile(self, model_name: str) -> ModelProfile | None:
-        return google_model_profile(model_name)
+        return google_model_profile(model_name)  # pragma: lax no cover

     @overload
     def __init__(