pydantic-ai-slim 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. pydantic_ai/_otel_messages.py +67 -0
  2. pydantic_ai/agent/__init__.py +11 -4
  3. pydantic_ai/builtin_tools.py +1 -0
  4. pydantic_ai/durable_exec/temporal/_model.py +4 -0
  5. pydantic_ai/messages.py +109 -18
  6. pydantic_ai/models/__init__.py +27 -9
  7. pydantic_ai/models/anthropic.py +20 -8
  8. pydantic_ai/models/bedrock.py +16 -10
  9. pydantic_ai/models/cohere.py +3 -1
  10. pydantic_ai/models/function.py +5 -0
  11. pydantic_ai/models/gemini.py +8 -1
  12. pydantic_ai/models/google.py +21 -4
  13. pydantic_ai/models/groq.py +8 -0
  14. pydantic_ai/models/huggingface.py +8 -0
  15. pydantic_ai/models/instrumented.py +103 -42
  16. pydantic_ai/models/mistral.py +8 -0
  17. pydantic_ai/models/openai.py +80 -36
  18. pydantic_ai/models/test.py +7 -0
  19. pydantic_ai/profiles/__init__.py +1 -1
  20. pydantic_ai/profiles/harmony.py +13 -0
  21. pydantic_ai/profiles/openai.py +6 -1
  22. pydantic_ai/profiles/qwen.py +8 -0
  23. pydantic_ai/providers/__init__.py +5 -1
  24. pydantic_ai/providers/anthropic.py +11 -8
  25. pydantic_ai/providers/azure.py +1 -1
  26. pydantic_ai/providers/cerebras.py +96 -0
  27. pydantic_ai/providers/cohere.py +2 -2
  28. pydantic_ai/providers/deepseek.py +4 -4
  29. pydantic_ai/providers/fireworks.py +3 -3
  30. pydantic_ai/providers/github.py +4 -4
  31. pydantic_ai/providers/grok.py +3 -3
  32. pydantic_ai/providers/groq.py +3 -3
  33. pydantic_ai/providers/heroku.py +3 -3
  34. pydantic_ai/providers/mistral.py +3 -3
  35. pydantic_ai/providers/moonshotai.py +3 -6
  36. pydantic_ai/providers/ollama.py +1 -1
  37. pydantic_ai/providers/openrouter.py +4 -4
  38. pydantic_ai/providers/together.py +3 -3
  39. pydantic_ai/providers/vercel.py +4 -4
  40. pydantic_ai/retries.py +154 -42
  41. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/METADATA +4 -4
  42. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/RECORD +45 -42
  43. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/openai.py

@@ -82,8 +82,10 @@ except ImportError as _import_error:
 
 __all__ = (
     'OpenAIModel',
+    'OpenAIChatModel',
     'OpenAIResponsesModel',
     'OpenAIModelSettings',
+    'OpenAIChatModelSettings',
     'OpenAIResponsesModelSettings',
     'OpenAIModelName',
 )

@@ -101,7 +103,7 @@ allows this model to be used more easily with other model types (ie, Ollama, Dee
 """
 
 
-class OpenAIModelSettings(ModelSettings, total=False):
+class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""
 
     # ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

@@ -139,7 +141,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
     """
 
 
-class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
+@deprecated('Use `OpenAIChatModelSettings` instead.')
+class OpenAIModelSettings(OpenAIChatModelSettings, total=False):
+    """Deprecated alias for `OpenAIChatModelSettings`."""
+
+
+class OpenAIResponsesModelSettings(OpenAIChatModelSettings, total=False):
     """Settings used for an OpenAI Responses model request.
 
     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
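The settings rename is a drop-in change; the old name survives as a deprecated alias. A minimal migration sketch (the `temperature` field comes from the base `ModelSettings` and is only illustrative):

```python
from pydantic_ai.models.openai import OpenAIChatModelSettings

# Previously: OpenAIModelSettings(temperature=0.2)
settings = OpenAIChatModelSettings(temperature=0.2)
```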
@@ -185,7 +192,7 @@ class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
 
 
 @dataclass(init=False)
-class OpenAIModel(Model):
+class OpenAIChatModel(Model):
     """A model that uses the OpenAI API.
 
     Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the API.

@@ -204,18 +211,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -229,18 +238,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

@@ -253,18 +264,20 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai',
-            'deepseek',
             'azure',
-            'openrouter',
-            'moonshotai',
-            'vercel',
-            'grok',
+            'deepseek',
+            'cerebras',
             'fireworks',
-            'together',
-            'heroku',
             'github',
+            'grok',
+            'heroku',
+            'moonshotai',
             'ollama',
+            'openai',
+            'openai-chat',
+            'openrouter',
+            'together',
+            'vercel',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
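Two new literal values, `'cerebras'` and `'openai-chat'`, join the reordered provider list. A usage sketch (assumes `CEREBRAS_API_KEY` is set; the model name is illustrative):

```python
from pydantic_ai.models.openai import OpenAIChatModel

# 'cerebras' resolves to the new CerebrasProvider added further down in this diff.
model = OpenAIChatModel('gpt-oss-120b', provider='cerebras')
```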
@@ -322,7 +335,7 @@ class OpenAIModel(Model):
     ) -> ModelResponse:
         check_allow_model_requests()
         response = await self._completions_create(
-            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+            messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
         return model_response

@@ -337,7 +350,7 @@ class OpenAIModel(Model):
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
-            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+            messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

@@ -347,7 +360,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: Literal[True],
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[ChatCompletionChunk]: ...
 

@@ -356,7 +369,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: Literal[False],
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion: ...
 

@@ -364,7 +377,7 @@ class OpenAIModel(Model):
         self,
         messages: list[ModelMessage],
         stream: bool,
-        model_settings: OpenAIModelSettings,
+        model_settings: OpenAIChatModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)

@@ -392,10 +405,15 @@ class OpenAIModel(Model):
         ):  # pragma: no branch
             response_format = {'type': 'json_object'}
 
+        unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+        for setting in unsupported_model_settings:
+            model_settings.pop(setting, None)
+
+        # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
         sampling_settings = (
             model_settings
             if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
-            else OpenAIModelSettings()
+            else OpenAIChatModelSettings()
         )
 
         try:
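The new stripping loop consults the profile's `openai_unsupported_model_settings` field (introduced below in `pydantic_ai/profiles/openai.py`), popping any listed setting before the request is built. A rough sketch of the effect, with an illustrative profile and settings dict:

```python
from pydantic_ai.profiles.openai import OpenAIModelProfile

profile = OpenAIModelProfile(openai_unsupported_model_settings=('logit_bias', 'service_tier'))
model_settings = {'temperature': 0.2, 'logit_bias': {'50256': -100}}

# Mirrors the loop added in the hunk above.
for setting in profile.openai_unsupported_model_settings:
    model_settings.pop(setting, None)

assert model_settings == {'temperature': 0.2}  # unsupported keys are silently dropped
```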
@@ -500,6 +518,7 @@ class OpenAIModel(Model):
             timestamp=timestamp,
             provider_details=vendor_details,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )
 
     async def _process_streamed_response(

@@ -519,6 +538,7 @@ class OpenAIModel(Model):
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:

@@ -571,6 +591,8 @@ class OpenAIModel(Model):
                 # Note: model responses from this model should only have one text item, so the following
                 # shouldn't merge multiple texts into one unless you switch models between runs:
                 message_param['content'] = '\n\n'.join(texts)
+            else:
+                message_param['content'] = None
             if tool_calls:
                 message_param['tool_calls'] = tool_calls
             openai_messages.append(message_param)

@@ -632,9 +654,7 @@ class OpenAIModel(Model):
                 )
             elif isinstance(part, RetryPromptPart):
                 if part.tool_name is None:
-                    yield chat.ChatCompletionUserMessageParam(  # pragma: no cover
-                        role='user', content=part.model_response()
-                    )
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                 else:
                     yield chat.ChatCompletionToolMessageParam(
                         role='tool',

@@ -702,6 +722,16 @@ class OpenAIModel(Model):
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
 
 
+@deprecated(
+    '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
+    "uses OpenAI's newer Responses API. Use that unless you're using an OpenAI Chat Completions-compatible API, or "
+    "require a feature that the Responses API doesn't support yet like audio."
+)
+@dataclass(init=False)
+class OpenAIModel(OpenAIChatModel):
+    """Deprecated alias for `OpenAIChatModel`."""
+
+
 @dataclass(init=False)
 class OpenAIResponsesModel(Model):
     """A model that uses the OpenAI Responses API.
@@ -803,6 +833,7 @@ class OpenAIResponsesModel(Model):
             model_name=response.model,
             provider_request_id=response.id,
             timestamp=timestamp,
+            provider_name=self._provider.name,
         )
 
     async def _process_streamed_response(

@@ -822,6 +853,7 @@ class OpenAIResponsesModel(Model):
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=number_to_datetime(first_chunk.response.created_at),
+            _provider_name=self._provider.name,
         )
 
     @overload

@@ -1137,6 +1169,7 @@ class OpenAIStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
+    _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:

@@ -1180,6 +1213,11 @@ class OpenAIStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""

@@ -1193,6 +1231,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
     _model_name: OpenAIModelName
     _response: AsyncIterable[responses.ResponseStreamEvent]
     _timestamp: datetime
+    _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:

@@ -1313,6 +1352,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/models/test.py

@@ -131,6 +131,7 @@ class TestModel(Model):
             _model_name=self._model_name,
             _structured_response=model_response,
             _messages=messages,
+            _provider_name=self._system,
         )
 
     @property

@@ -263,6 +264,7 @@ class TestStreamedResponse(StreamedResponse):
     _model_name: str
     _structured_response: ModelResponse
     _messages: InitVar[Iterable[ModelMessage]]
+    _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
 
     def __post_init__(self, _messages: Iterable[ModelMessage]):

@@ -305,6 +307,11 @@ class TestStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/profiles/__init__.py

@@ -52,7 +52,7 @@ class ModelProfile:
     This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
     which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
 
-    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    This is currently only used by `OpenAIChatModel`, `HuggingFaceModel`, and `GroqModel`.
     """
 
     @classmethod
pydantic_ai/profiles/harmony.py (new file)

@@ -0,0 +1,13 @@
+from __future__ import annotations as _annotations
+
+from . import ModelProfile
+from .openai import OpenAIModelProfile, openai_model_profile
+
+
+def harmony_model_profile(model_name: str) -> ModelProfile | None:
+    """The model profile for the OpenAI Harmony Response format.
+
+    See <https://cookbook.openai.com/articles/openai-harmony> for more details.
+    """
+    profile = openai_model_profile(model_name)
+    return OpenAIModelProfile(openai_supports_tool_choice_required=False).update(profile)
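A sketch of what the new profile yields for a Harmony-format model (the model name is illustrative):

```python
from pydantic_ai.profiles.harmony import harmony_model_profile

profile = harmony_model_profile('gpt-oss-120b')
# Harmony models go through the OpenAI profile but must not receive tool_choice='required'.
assert profile is not None and profile.openai_supports_tool_choice_required is False
```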
pydantic_ai/profiles/openai.py

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import re
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal
 

@@ -12,7 +13,7 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 
 @dataclass
 class OpenAIModelProfile(ModelProfile):
-    """Profile for models used with OpenAIModel.
+    """Profile for models used with `OpenAIChatModel`.
 
     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """

@@ -20,9 +21,13 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_strict_tool_definition: bool = True
     """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions."""
 
+    # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
 
+    openai_unsupported_model_settings: Sequence[str] = ()
+    """A list of model settings that are not supported by the model."""
+
     # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
     # `tool_choice="required"`. This flag lets the calling model know whether it's
     # safe to pass that value along. Default is `True` to preserve existing
pydantic_ai/profiles/qwen.py

@@ -1,10 +1,18 @@
 from __future__ import annotations as _annotations
 
+from ..profiles.openai import OpenAIModelProfile
 from . import InlineDefsJsonSchemaTransformer, ModelProfile
 
 
 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
+    if model_name.startswith('qwen-3-coder'):
+        return OpenAIModelProfile(
+            json_schema_transformer=InlineDefsJsonSchemaTransformer,
+            openai_supports_tool_choice_required=False,
+            openai_supports_strict_tool_definition=False,
+            ignore_streamed_leading_whitespace=True,
+        )
     return ModelProfile(
         json_schema_transformer=InlineDefsJsonSchemaTransformer,
         ignore_streamed_leading_whitespace=True,
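`qwen-3-coder` models now get an OpenAI-flavored profile that disables strict tool definitions and `tool_choice='required'`; other Qwen models keep the plain `ModelProfile`. A sketch with an illustrative model name:

```python
from pydantic_ai.profiles.qwen import qwen_model_profile

profile = qwen_model_profile('qwen-3-coder-480b')
assert profile.openai_supports_strict_tool_definition is False
```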
pydantic_ai/providers/__init__.py

@@ -20,7 +20,7 @@ class Provider(ABC, Generic[InterfaceClient]):
 
     Each provider only supports a specific interface. A interface can be supported by multiple providers.
 
-    For example, the OpenAIModel interface can be supported by the OpenAIProvider and the DeepSeekProvider.
+    For example, the `OpenAIChatModel` interface can be supported by the `OpenAIProvider` and the `DeepSeekProvider`.
     """
 
     _client: InterfaceClient

@@ -95,6 +95,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .mistral import MistralProvider
 
         return MistralProvider
+    elif provider == 'cerebras':
+        from .cerebras import CerebrasProvider
+
+        return CerebrasProvider
     elif provider == 'cohere':
         from .cohere import CohereProvider
 
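With the new branch in place, the provider string resolves without importing the Cerebras module up front:

```python
from pydantic_ai.providers import infer_provider_class

provider_cls = infer_provider_class('cerebras')  # lazily imports and returns CerebrasProvider
```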
pydantic_ai/providers/anthropic.py

@@ -1,9 +1,10 @@
 from __future__ import annotations as _annotations
 
 import os
-from typing import overload
+from typing import Union, overload
 
 import httpx
+from typing_extensions import TypeAlias
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client

@@ -12,15 +13,18 @@ from pydantic_ai.profiles.anthropic import anthropic_model_profile
 from pydantic_ai.providers import Provider
 
 try:
-    from anthropic import AsyncAnthropic
-except ImportError as _import_error:  # pragma: no cover
+    from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+except ImportError as _import_error:
     raise ImportError(
         'Please install the `anthropic` package to use the Anthropic provider, '
         'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
     ) from _import_error
 
 
-class AnthropicProvider(Provider[AsyncAnthropic]):
+AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
+
+
+class AnthropicProvider(Provider[AsyncAnthropicClient]):
     """Provider for Anthropic API."""
 
     @property

@@ -32,14 +36,14 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
         return str(self._client.base_url)
 
     @property
-    def client(self) -> AsyncAnthropic:
+    def client(self) -> AsyncAnthropicClient:
         return self._client
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
         return anthropic_model_profile(model_name)
 
     @overload
-    def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
+    def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...
 
     @overload
     def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...

@@ -48,7 +52,7 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
         self,
         *,
         api_key: str | None = None,
-        anthropic_client: AsyncAnthropic | None = None,
+        anthropic_client: AsyncAnthropicClient | None = None,
         http_client: httpx.AsyncClient | None = None,
     ) -> None:
         """Create a new Anthropic provider.

@@ -71,7 +75,6 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
                 'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
                 'to use the Anthropic provider.'
             )
-
         if http_client is not None:
             self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
         else:
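The widened `AsyncAnthropicClient` alias means a Bedrock-backed client can now be passed directly. A hedged sketch (assumes AWS credentials and region are configured in the environment):

```python
from anthropic import AsyncAnthropicBedrock
from pydantic_ai.providers.anthropic import AnthropicProvider

provider = AnthropicProvider(anthropic_client=AsyncAnthropicBedrock())
```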
pydantic_ai/providers/azure.py

@@ -65,7 +65,7 @@ class AzureProvider(Provider[AsyncOpenAI]):
 
         profile = profile_func(model_name)
 
-        # As AzureProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # As AzureProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
         # we need to maintain that behavior unless json_schema_transformer is set explicitly
         return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
 
pydantic_ai/providers/cerebras.py (new file)

@@ -0,0 +1,96 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+import httpx
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.harmony import harmony_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Cerebras provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class CerebrasProvider(Provider[AsyncOpenAI]):
+    """Provider for Cerebras API."""
+
+    @property
+    def name(self) -> str:
+        return 'cerebras'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://api.cerebras.ai/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile}
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features,
+        # Cerebras doesn't support some model settings.
+        unsupported_model_settings = (
+            'frequency_penalty',
+            'logit_bias',
+            'presence_penalty',
+            'parallel_tool_calls',
+            'service_tier',
+        )
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+            openai_unsupported_model_settings=unsupported_model_settings,
+        ).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
+                'to use the Cerebras provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='cerebras')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
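Putting the new provider together with `OpenAIChatModel` (model name illustrative; reads `CEREBRAS_API_KEY` from the environment):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.cerebras import CerebrasProvider

model = OpenAIChatModel('llama-3.3-70b', provider=CerebrasProvider())
agent = Agent(model)
```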
pydantic_ai/providers/cohere.py

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 import os
 
-from httpx import AsyncClient as AsyncHTTPClient
+import httpx
 
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client

@@ -43,7 +43,7 @@ class CohereProvider(Provider[AsyncClientV2]):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         """Create a new Cohere provider.
 
pydantic_ai/providers/deepseek.py

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import os
 from typing import overload
 
-from httpx import AsyncClient as AsyncHTTPClient
+import httpx
 from openai import AsyncOpenAI
 
 from pydantic_ai.exceptions import UserError

@@ -40,7 +40,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         profile = deepseek_model_profile(model_name)
 
-        # As DeepSeekProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # As DeepSeekProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
         # we need to maintain that behavior unless json_schema_transformer is set explicitly.
         # This was not the case when using a DeepSeek model with another model class (e.g. BedrockConverseModel or GroqModel),
         # so we won't do this in `deepseek_model_profile` unless we learn it's always needed.

@@ -53,7 +53,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
     def __init__(self, *, api_key: str) -> None: ...
 
     @overload
-    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
 
     @overload
     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

@@ -63,7 +63,7 @@ class DeepSeekProvider(Provider[AsyncOpenAI]):
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
         if not api_key and openai_client is None:
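The `AsyncHTTPClient` alias is replaced by the plain `httpx.AsyncClient` annotation across providers; behavior is unchanged and custom clients are passed as before (placeholder key shown):

```python
import httpx
from pydantic_ai.providers.deepseek import DeepSeekProvider

provider = DeepSeekProvider(api_key='...', http_client=httpx.AsyncClient(timeout=30))
```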
pydantic_ai/providers/fireworks.py

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import os
 from typing import overload
 
-from httpx import AsyncClient as AsyncHTTPClient
+import httpx
 from openai import AsyncOpenAI
 
 from pydantic_ai.exceptions import UserError

@@ -71,7 +71,7 @@ class FireworksProvider(Provider[AsyncOpenAI]):
     def __init__(self, *, api_key: str) -> None: ...
 
     @overload
-    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
 
     @overload
     def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

@@ -81,7 +81,7 @@ class FireworksProvider(Provider[AsyncOpenAI]):
         *,
         api_key: str | None = None,
         openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
     ) -> None:
         api_key = api_key or os.getenv('FIREWORKS_API_KEY')
         if not api_key and openai_client is None: