pydantic-ai-slim 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- pydantic_ai/_otel_messages.py +67 -0
- pydantic_ai/agent/__init__.py +11 -4
- pydantic_ai/builtin_tools.py +1 -0
- pydantic_ai/durable_exec/temporal/_model.py +4 -0
- pydantic_ai/messages.py +109 -18
- pydantic_ai/models/__init__.py +27 -9
- pydantic_ai/models/anthropic.py +20 -8
- pydantic_ai/models/bedrock.py +16 -10
- pydantic_ai/models/cohere.py +3 -1
- pydantic_ai/models/function.py +5 -0
- pydantic_ai/models/gemini.py +8 -1
- pydantic_ai/models/google.py +21 -4
- pydantic_ai/models/groq.py +8 -0
- pydantic_ai/models/huggingface.py +8 -0
- pydantic_ai/models/instrumented.py +103 -42
- pydantic_ai/models/mistral.py +8 -0
- pydantic_ai/models/openai.py +80 -36
- pydantic_ai/models/test.py +7 -0
- pydantic_ai/profiles/__init__.py +1 -1
- pydantic_ai/profiles/harmony.py +13 -0
- pydantic_ai/profiles/openai.py +6 -1
- pydantic_ai/profiles/qwen.py +8 -0
- pydantic_ai/providers/__init__.py +5 -1
- pydantic_ai/providers/anthropic.py +11 -8
- pydantic_ai/providers/azure.py +1 -1
- pydantic_ai/providers/cerebras.py +96 -0
- pydantic_ai/providers/cohere.py +2 -2
- pydantic_ai/providers/deepseek.py +4 -4
- pydantic_ai/providers/fireworks.py +3 -3
- pydantic_ai/providers/github.py +4 -4
- pydantic_ai/providers/grok.py +3 -3
- pydantic_ai/providers/groq.py +3 -3
- pydantic_ai/providers/heroku.py +3 -3
- pydantic_ai/providers/mistral.py +3 -3
- pydantic_ai/providers/moonshotai.py +3 -6
- pydantic_ai/providers/ollama.py +1 -1
- pydantic_ai/providers/openrouter.py +4 -4
- pydantic_ai/providers/together.py +3 -3
- pydantic_ai/providers/vercel.py +4 -4
- pydantic_ai/retries.py +154 -42
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/RECORD +45 -42
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/openai.py
CHANGED

@@ -82,8 +82,10 @@ except ImportError as _import_error:

 __all__ = (
     'OpenAIModel',
+    'OpenAIChatModel',
     'OpenAIResponsesModel',
     'OpenAIModelSettings',
+    'OpenAIChatModelSettings',
     'OpenAIResponsesModelSettings',
     'OpenAIModelName',
 )

@@ -101,7 +103,7 @@ allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
     """


-class OpenAIModelSettings(ModelSettings, total=False):
+class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""

     # ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

@@ -139,7 +141,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
     """


-class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
+@deprecated('Use `OpenAIChatModelSettings` instead.')
+class OpenAIModelSettings(OpenAIChatModelSettings, total=False):
+    """Deprecated alias for `OpenAIChatModelSettings`."""
+
+
+class OpenAIResponsesModelSettings(OpenAIChatModelSettings, total=False):
     """Settings used for an OpenAI Responses model request.

     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

@@ -185,7 +192,7 @@ class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):


 @dataclass(init=False)
-class OpenAIModel(Model):
+class OpenAIChatModel(Model):
     """A model that uses the OpenAI API.

     Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.

@@ -204,18 +211,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -229,18 +238,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -253,18 +264,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -322,7 +335,7 @@ class OpenAIModel(Model):
     ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._completions_create(
-            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+            messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
         return model_response

@@ -337,7 +350,7 @@ class OpenAIModel(Model):
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
-            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+            messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

@@ -347,7 +360,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: Literal[True],
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]: ...

@@ -356,7 +369,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: Literal[False],
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion: ...

@@ -364,7 +377,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: bool,
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)

@@ -392,10 +405,15 @@ class OpenAIModel(Model):
         ):  # pragma: no branch
             response_format = {'type': 'json_object'}

+        unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+        for setting in unsupported_model_settings:
+            model_settings.pop(setting, None)
+
+        # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
         sampling_settings = (
             model_settings
             if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
-            else OpenAIModelSettings()
+            else OpenAIChatModelSettings()
         )

         try:

@@ -500,6 +518,7 @@ class OpenAIModel(Model):
             timestamp=timestamp,
             provider_details=vendor_details,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )

     async def _process_streamed_response(

@@ -519,6 +538,7 @@ class OpenAIModel(Model):
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:

@@ -571,6 +591,8 @@ class OpenAIModel(Model):
                 # Note: model responses from this model should only have one text item, so the following
                 # shouldn't merge multiple texts into one unless you switch models between runs:
                 message_param['content'] = '\n\n'.join(texts)
+            else:
+                message_param['content'] = None
             if tool_calls:
                 message_param['tool_calls'] = tool_calls
             openai_messages.append(message_param)

@@ -632,9 +654,7 @@ class OpenAIModel(Model):
                 )
             elif isinstance(part, RetryPromptPart):
                 if part.tool_name is None:
-                    yield chat.ChatCompletionUserMessageParam(
-                        role='user', content=part.model_response()
-                    )
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                 else:
                     yield chat.ChatCompletionToolMessageParam(
                         role='tool',

@@ -702,6 +722,16 @@ class OpenAIModel(Model):
         return chat.ChatCompletionUserMessageParam(role='user', content=content)


+@deprecated(
+    '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
+    "uses OpenAI's newer Responses API. Use that unless you're using an OpenAI Chat Completions-compatible API, or "
+    "require a feature that the Responses API doesn't support yet like audio."
+)
+@dataclass(init=False)
+class OpenAIModel(OpenAIChatModel):
+    """Deprecated alias for `OpenAIChatModel`."""
+
+
 @dataclass(init=False)
 class OpenAIResponsesModel(Model):
     """A model that uses the OpenAI Responses API.

@@ -803,6 +833,7 @@ class OpenAIResponsesModel(Model):
             model_name=response.model,
             provider_request_id=response.id,
             timestamp=timestamp,
+            provider_name=self._provider.name,
         )

     async def _process_streamed_response(

@@ -822,6 +853,7 @@ class OpenAIResponsesModel(Model):
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.response.created_at),
+            _provider_name=self._provider.name,
         )

     @overload

@@ -1137,6 +1169,7 @@ class OpenAIStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
+    _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:

@@ -1180,6 +1213,11 @@ class OpenAIStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""

@@ -1193,6 +1231,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
     _model_name: OpenAIModelName
     _response: AsyncIterable[responses.ResponseStreamEvent]
     _timestamp: datetime
+    _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:

@@ -1313,6 +1352,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
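The rename is backward compatible: `OpenAIModel` and `OpenAIModelSettings` survive as deprecated aliases of `OpenAIChatModel` and `OpenAIChatModelSettings`. A minimal migration sketch (the model name is illustrative and `OPENAI_API_KEY` is assumed to be set):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings

# Before 0.7.5 these were `OpenAIModel` / `OpenAIModelSettings`;
# the old names still work but are now marked deprecated.
model = OpenAIChatModel('gpt-4o', provider='openai')
settings = OpenAIChatModelSettings(temperature=0.0)

agent = Agent(model, model_settings=settings)
result = agent.run_sync('What is the capital of France?')
print(result.output)
```

Per the new deprecation message, `OpenAIResponsesModel` is now the recommended class for OpenAI's own API; `OpenAIChatModel` remains the right choice for Chat Completions-compatible providers or features the Responses API doesn't support yet.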
pydantic_ai/models/test.py
CHANGED

@@ -131,6 +131,7 @@ class TestModel(Model):
             _model_name=self._model_name,
             _structured_response=model_response,
             _messages=messages,
+            _provider_name=self._system,
         )

     @property

@@ -263,6 +264,7 @@ class TestStreamedResponse(StreamedResponse):
     _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
+    _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

     def __post_init__(self, _messages: Iterable[ModelMessage]):

@@ -305,6 +307,11 @@ class TestStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name

+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
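`TestModel` now threads a provider name through to its streamed responses, matching the `provider_name` property added to the real models above. A small sketch of observing the field on a run's message history, assuming `ModelResponse.provider_name` is the corresponding field on stored messages (it may be unset for models that don't populate it):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel())
result = agent.run_sync('hello')

# Each model response records which provider produced it
# (may be None for code paths that don't set it).
for message in result.all_messages():
    if isinstance(message, ModelResponse):
        print(message.provider_name)
```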
pydantic_ai/profiles/__init__.py
CHANGED

@@ -52,7 +52,7 @@ class ModelProfile:
     This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
     which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.

-    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    This is currently only used by `OpenAIChatModel`, `HuggingFaceModel`, and `GroqModel`.
     """

     @classmethod
pydantic_ai/profiles/harmony.py
ADDED

@@ -0,0 +1,13 @@
+from __future__ import annotations as _annotations
+
+from . import ModelProfile
+from .openai import OpenAIModelProfile, openai_model_profile
+
+
+def harmony_model_profile(model_name: str) -> ModelProfile | None:
+    """The model profile for the OpenAI Harmony Response format.
+
+    See <https://cookbook.openai.com/articles/openai-harmony> for more details.
+    """
+    profile = openai_model_profile(model_name)
+    return OpenAIModelProfile(openai_supports_tool_choice_required=False).update(profile)
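A quick sketch of what the new profile resolves to; the `gpt-oss` model name is illustrative:

```python
from pydantic_ai.profiles.harmony import harmony_model_profile

# Harmony-format models take the regular OpenAI profile as a base,
# but don't accept `tool_choice="required"`.
profile = harmony_model_profile('gpt-oss-120b')
print(profile.openai_supports_tool_choice_required)  # False
```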
pydantic_ai/profiles/openai.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations

 import re
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal

@@ -12,7 +13,7 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']

 @dataclass
 class OpenAIModelProfile(ModelProfile):
-    """Profile for models used with `OpenAIModel`.
+    """Profile for models used with `OpenAIChatModel`.

     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """

@@ -20,9 +21,13 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_strict_tool_definition: bool = True
     """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""

+    # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""

+    openai_unsupported_model_settings: Sequence[str] = ()
+    """A list of model settings that are not supported by the model."""
+
     # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
     # `tool_choice="required"`. This flag lets the calling model know whether it's
     # safe to pass that value along. Default is `True` to preserve existing
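Together with the pop loop added to `OpenAIChatModel._completions_create` above, this gives profiles a declarative way to strip settings an OpenAI-"compatible" endpoint rejects. A hedged sketch of a custom profile; the setting names here are examples, not a list from any real provider:

```python
from pydantic_ai.profiles.openai import OpenAIModelProfile

# Hypothetical profile for an OpenAI-"compatible" endpoint that rejects
# penalty settings; `OpenAIChatModel` pops these keys from the request's
# model settings before building the API call.
profile = OpenAIModelProfile(
    openai_unsupported_model_settings=('frequency_penalty', 'presence_penalty'),
)
# Pass it via the `profile` argument of `OpenAIChatModel(...)`.
```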
pydantic_ai/profiles/qwen.py
CHANGED

@@ -1,10 +1,18 @@
 from __future__ import annotations as _annotations

+from ..profiles.openai import OpenAIModelProfile
 from . import InlineDefsJsonSchemaTransformer, ModelProfile


 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
+    if model_name.startswith('qwen-3-coder'):
+        return OpenAIModelProfile(
+            json_schema_transformer=InlineDefsJsonSchemaTransformer,
+            openai_supports_tool_choice_required=False,
+            openai_supports_strict_tool_definition=False,
+            ignore_streamed_leading_whitespace=True,
+        )
     return ModelProfile(
         json_schema_transformer=InlineDefsJsonSchemaTransformer,
         ignore_streamed_leading_whitespace=True,
pydantic_ai/providers/__init__.py
CHANGED

@@ -20,7 +20,7 @@ class Provider(ABC, Generic[InterfaceClient]):

     Each provider only supports a specific interface. A interface can be supported by multiple providers.

-    For example, the `OpenAIModel` interface can be supported by the `OpenAIProvider` and the `DeepSeekProvider`.
+    For example, the `OpenAIChatModel` interface can be supported by the `OpenAIProvider` and the `DeepSeekProvider`.
     """

     _client: InterfaceClient

@@ -95,6 +95,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .mistral import MistralProvider

         return MistralProvider
+    elif provider == 'cerebras':
+        from .cerebras import CerebrasProvider
+
+        return CerebrasProvider
     elif provider == 'cohere':
         from .cohere import CohereProvider

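With the registry entry in place, the string name resolves like any other provider. A small sketch:

```python
from pydantic_ai.providers import infer_provider_class

# The new branch above maps the 'cerebras' string to the provider class.
provider_cls = infer_provider_class('cerebras')
print(provider_cls.__name__)  # CerebrasProvider
```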
pydantic_ai/providers/anthropic.py
CHANGED

@@ -1,9 +1,10 @@
 from __future__ import annotations as _annotations

 import os
-from typing import overload
+from typing import Union, overload

 import httpx
+from typing_extensions import TypeAlias

 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client

@@ -12,15 +13,18 @@ from pydantic_ai.profiles.anthropic import anthropic_model_profile
 from pydantic_ai.providers import Provider

 try:
-    from anthropic import AsyncAnthropic
-except ImportError as _import_error:
+    from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+except ImportError as _import_error:
     raise ImportError(
         'Please install the `anthropic` package to use the Anthropic provider, '
         'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
     ) from _import_error


-class AnthropicProvider(Provider[AsyncAnthropic]):
+AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
+
+
+class AnthropicProvider(Provider[AsyncAnthropicClient]):
     """Provider for Anthropic API."""

     @property

@@ -32,14 +36,14 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
         return str(self._client.base_url)

     @property
-    def client(self) -> AsyncAnthropic:
+    def client(self) -> AsyncAnthropicClient:
         return self._client

     def model_profile(self, model_name: str) -> ModelProfile | None:
         return anthropic_model_profile(model_name)

     @overload
-    def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
+    def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

     @overload
     def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...

@@ -48,7 +52,7 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
         self,
         *,
         api_key: str | None = None,
-        anthropic_client: AsyncAnthropic | None = None,
+        anthropic_client: AsyncAnthropicClient | None = None,
         http_client: httpx.AsyncClient | None = None,
     ) -> None:
         """Create a new Anthropic provider.

@@ -71,7 +75,6 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
                 'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
                 'to use the Anthropic provider.'
             )
-
         if http_client is not None:
             self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
         else:
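Widening the provider's client type to `AsyncAnthropicClient` means an `AsyncAnthropicBedrock` client can now back the provider. A sketch, assuming AWS credentials and region are configured in the environment and using an illustrative Bedrock model ID:

```python
from anthropic import AsyncAnthropicBedrock

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

# AsyncAnthropicBedrock picks up AWS credentials/region from the environment.
provider = AnthropicProvider(anthropic_client=AsyncAnthropicBedrock())
model = AnthropicModel('anthropic.claude-3-5-sonnet-20240620-v1:0', provider=provider)
agent = Agent(model)
```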
pydantic_ai/providers/azure.py
CHANGED

@@ -65,7 +65,7 @@ class AzureProvider(Provider[AsyncOpenAI]):

         profile = profile_func(model_name)

-        # As AzureProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # As AzureProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
         # we need to maintain that behavior unless json_schema_transformer is set explicitly
         return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

pydantic_ai/providers/cerebras.py
ADDED

@@ -0,0 +1,96 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+import httpx
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.harmony import harmony_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Cerebras provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class CerebrasProvider(Provider[AsyncOpenAI]):
+    """Provider for Cerebras API."""
+
+    @property
+    def name(self) -> str:
+        return 'cerebras'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://api.cerebras.ai/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile}
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
+        # Cerebras doesn't support some model settings.
+        unsupported_model_settings = (
+            'frequency_penalty',
+            'logit_bias',
+            'presence_penalty',
+            'parallel_tool_calls',
+            'service_tier',
+        )
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+            openai_unsupported_model_settings=unsupported_model_settings,
+        ).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
+                'to use the Cerebras provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='cerebras')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
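Usage mirrors the other OpenAI-compatible providers; the model name below is illustrative:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.cerebras import CerebrasProvider

# String form: reads CEREBRAS_API_KEY from the environment.
model = OpenAIChatModel('llama-3.3-70b', provider='cerebras')

# Equivalent explicit form:
model = OpenAIChatModel('llama-3.3-70b', provider=CerebrasProvider(api_key='your-api-key'))

agent = Agent(model)
```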
pydantic_ai/providers/cohere.py
CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

 import os

-from httpx import AsyncClient as AsyncHTTPClient
+import httpx

 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client

@@ -43,7 +43,7 @@ class CohereProvider(Provider[AsyncClientV2]):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         """Create a new Cohere provider.

pydantic_ai/providers/deepseek.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import os
 from typing import overload

-from httpx import AsyncClient as AsyncHTTPClient
+import httpx
 from openai import AsyncOpenAI

 from pydantic_ai.exceptions import UserError

@@ -40,7 +40,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         profile = deepseek_model_profile(model_name)

-        # As DeepSeekProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # As DeepSeekProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
         # we need to maintain that behavior unless json_schema_transformer is set explicitly.
         # This was not the case when using a DeepSeek model with another model class (e.g. BedrockConverseModel or GroqModel),
         # so we won't do this in `deepseek_model_profile` unless we learn it's always needed.

@@ -53,7 +53,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
     def __init__(self, *, api_key: str) -> None: ...

     @overload
-    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

     @overload
     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

@@ -63,7 +63,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
         if not api_key and openai_client is None:
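Across the providers touched here, `http_client` parameters are now annotated as plain `httpx.AsyncClient` instead of the old `AsyncClient as AsyncHTTPClient` re-import; runtime behavior is unchanged. A sketch of passing a custom client:

```python
import httpx

from pydantic_ai.providers.deepseek import DeepSeekProvider

# Any configured httpx.AsyncClient works, e.g. with a custom timeout.
http_client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
provider = DeepSeekProvider(api_key='your-api-key', http_client=http_client)
```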
pydantic_ai/providers/fireworks.py
CHANGED

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import os
 from typing import overload

-from httpx import AsyncClient as AsyncHTTPClient
+import httpx
 from openai import AsyncOpenAI

 from pydantic_ai.exceptions import UserError

@@ -71,7 +71,7 @@ class FireworksProvider(Provider[AsyncOpenAI]):
     def __init__(self, *, api_key: str) -> None: ...

     @overload
-    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

     @overload
     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

@@ -81,7 +81,7 @@ class FireworksProvider(Provider[AsyncOpenAI]):
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         api_key = api_key or os.getenv('FIREWORKS_API_KEY')
         if not api_key and openai_client is None: