pydantic-ai-slim 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. pydantic_ai/_agent_graph.py +60 -57
  2. pydantic_ai/_cli.py +18 -3
  3. pydantic_ai/_parts_manager.py +5 -4
  4. pydantic_ai/_run_context.py +2 -2
  5. pydantic_ai/_tool_manager.py +50 -29
  6. pydantic_ai/ag_ui.py +4 -4
  7. pydantic_ai/agent/__init__.py +69 -84
  8. pydantic_ai/agent/abstract.py +16 -18
  9. pydantic_ai/agent/wrapper.py +4 -6
  10. pydantic_ai/direct.py +4 -4
  11. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  12. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  13. pydantic_ai/messages.py +16 -6
  14. pydantic_ai/models/__init__.py +5 -5
  15. pydantic_ai/models/anthropic.py +47 -46
  16. pydantic_ai/models/bedrock.py +25 -27
  17. pydantic_ai/models/cohere.py +20 -25
  18. pydantic_ai/models/fallback.py +15 -15
  19. pydantic_ai/models/function.py +7 -9
  20. pydantic_ai/models/gemini.py +43 -39
  21. pydantic_ai/models/google.py +59 -40
  22. pydantic_ai/models/groq.py +23 -19
  23. pydantic_ai/models/huggingface.py +27 -23
  24. pydantic_ai/models/instrumented.py +4 -4
  25. pydantic_ai/models/mcp_sampling.py +1 -2
  26. pydantic_ai/models/mistral.py +24 -22
  27. pydantic_ai/models/openai.py +101 -45
  28. pydantic_ai/models/test.py +4 -5
  29. pydantic_ai/profiles/__init__.py +10 -1
  30. pydantic_ai/profiles/deepseek.py +1 -1
  31. pydantic_ai/profiles/moonshotai.py +1 -1
  32. pydantic_ai/profiles/openai.py +13 -3
  33. pydantic_ai/profiles/qwen.py +4 -1
  34. pydantic_ai/providers/__init__.py +4 -0
  35. pydantic_ai/providers/huggingface.py +27 -0
  36. pydantic_ai/providers/ollama.py +105 -0
  37. pydantic_ai/providers/openai.py +1 -1
  38. pydantic_ai/providers/openrouter.py +2 -0
  39. pydantic_ai/result.py +6 -6
  40. pydantic_ai/run.py +4 -11
  41. pydantic_ai/tools.py +9 -9
  42. pydantic_ai/usage.py +229 -67
  43. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
  44. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +47 -46
  45. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
  46. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
  47. {pydantic_ai_slim-0.7.1.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/mistral.py

@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-    _system: str = field(default='mistral_ai', repr=False)
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.client.sdk_configuration.get_server_details()[0]
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@
             parts.append(tool)

         return ModelResponse(
-            parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Usage(
-            request_tokens=response.usage.prompt_tokens,
-            response_tokens=response.usage.completion_tokens,
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return Usage()  # pragma: no cover
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
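The `Usage` → `RequestUsage` change above is part of the wider usage rework in 0.7.3 (see pydantic_ai/usage.py in the file list): per-request token counts are now named `input_tokens`/`output_tokens` rather than `request_tokens`/`response_tokens`, and models no longer set `usage.requests = 1` themselves. A minimal sketch of reading usage after this change, assuming the aggregated run usage exposes the same field names (the model string is a placeholder):

    from pydantic_ai import Agent

    agent = Agent('mistral:mistral-small-latest')  # placeholder model
    result = agent.run_sync('Say hello')

    run_usage = result.usage()
    print(run_usage.input_tokens, run_usage.output_tokens)  # formerly request_tokens / response_tokens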
pydantic_ai/models/openai.py

@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload

 from pydantic import ValidationError
-from typing_extensions import assert_never
+from typing_extensions import assert_never, deprecated

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
@@ -40,7 +40,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
-from ..profiles.openai import OpenAIModelProfile
+from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -100,8 +100,6 @@ Using this more broad type for the model name instead of the ChatModel definitio
 allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """

-OpenAISystemPromptRole = Literal['system', 'developer', 'user']
-

 class OpenAIModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""
@@ -196,10 +194,59 @@ class OpenAIModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)

     _model_name: OpenAIModelName = field(repr=False)
-    _system: str = field(default='openai', repr=False)
+    _provider: Provider[AsyncOpenAI] = field(repr=False)
+
+    @overload
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
+            'ollama',
+        ]
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
+    ) -> None: ...
+
+    @deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
+    @overload
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
+            'ollama',
+        ]
+        | Provider[AsyncOpenAI] = 'openai',
+        profile: ModelProfileSpec | None = None,
+        system_prompt_role: OpenAISystemPromptRole | None = None,
+        settings: ModelSettings | None = None,
+    ) -> None: ...

     def __init__(
         self,
@@ -217,6 +264,7 @@
             'together',
             'heroku',
             'github',
+            'ollama',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -239,16 +287,33 @@

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

-        self.system_prompt_role = system_prompt_role
-
         super().__init__(settings=settings, profile=profile or provider.model_profile)

+        if system_prompt_role is not None:
+            self.profile = OpenAIModelProfile(openai_system_prompt_role=system_prompt_role).update(self.profile)
+
     @property
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
+    @property
+    @deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
+    def system_prompt_role(self) -> OpenAISystemPromptRole | None:
+        return OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -260,7 +325,6 @@
             messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -278,16 +342,6 @@
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> OpenAIModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -444,8 +498,8 @@
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            vendor_details=vendor_details,
-            vendor_id=response.id,
+            provider_details=vendor_details,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -561,9 +615,10 @@
     async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
             if isinstance(part, SystemPromptPart):
-                if self.system_prompt_role == 'developer':
+                system_prompt_role = OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
+                if system_prompt_role == 'developer':
                     yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
-                elif self.system_prompt_role == 'user':
+                elif system_prompt_role == 'user':
                     yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
                 else:
                     yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
@@ -659,10 +714,9 @@ class OpenAIResponsesModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     _model_name: OpenAIModelName = field(repr=False)
-    _system: str = field(default='openai', repr=False)
+    _provider: Provider[AsyncOpenAI] = field(repr=False)

     def __init__(
         self,
@@ -685,6 +739,7 @@

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -696,8 +751,8 @@

     @property
     def system(self) -> str:
-        """The system / model provider."""
-        return self._system
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -746,7 +801,7 @@
             items,
             usage=_map_usage(response),
             model_name=response.model,
-            vendor_id=response.id,
+            provider_request_id=response.id,
             timestamp=timestamp,
         )

@@ -1094,11 +1149,12 @@ class OpenAIStreamedResponse(StreamedResponse):

             # Handle the text part of the response
             content = choice.delta.content
-            if content:
+            if content is not None:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -1263,10 +1319,10 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage:
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()
     elif isinstance(response_usage, responses.ResponseUsage):
         details: dict[str, int] = {
             key: value
@@ -1276,29 +1332,29 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
             if isinstance(value, int)
         }
         details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens
-        details['cached_tokens'] = response_usage.input_tokens_details.cached_tokens
-        return usage.Usage(
-            request_tokens=response_usage.input_tokens,
-            response_tokens=response_usage.output_tokens,
-            total_tokens=response_usage.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response_usage.input_tokens,
+            output_tokens=response_usage.output_tokens,
+            cache_read_tokens=response_usage.input_tokens_details.cached_tokens,
             details=details,
         )
     else:
         details = {
             key: value
             for key, value in response_usage.model_dump(
-                exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}
+                exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}
             ).items()
             if isinstance(value, int)
         }
+        u = usage.RequestUsage(
+            input_tokens=response_usage.prompt_tokens,
+            output_tokens=response_usage.completion_tokens,
+            details=details,
+        )
         if response_usage.completion_tokens_details is not None:
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
+            u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0
         if response_usage.prompt_tokens_details is not None:
-            details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
-        return usage.Usage(
-            requests=1,
-            request_tokens=response_usage.prompt_tokens,
-            response_tokens=response_usage.completion_tokens,
-            total_tokens=response_usage.total_tokens,
-            details=details,
-        )
+            u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
+            u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
+        return u
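Since `system_prompt_role` is deprecated on `OpenAIModel` (and removed as a field on `OpenAIResponsesModel`), the role now lives on the model profile. A migration sketch under that assumption; the model name is a placeholder:

    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.profiles.openai import OpenAIModelProfile

    # 0.7.1 style, now deprecated per the overload above:
    # model = OpenAIModel('gpt-4o', system_prompt_role='developer')

    # 0.7.3 style: carry the role on the profile instead.
    model = OpenAIModel(
        'gpt-4o',
        profile=OpenAIModelProfile(openai_system_prompt_role='developer'),
    )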
pydantic_ai/models/test.py

@@ -31,7 +31,7 @@ from ..messages import (
 from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 from . import Model, ModelRequestParameters, StreamedResponse
 from .function import _estimate_string_tokens, _estimate_usage  # pyright: ignore[reportPrivateUsage]

@@ -113,7 +113,6 @@ class TestModel(Model):
         self.last_model_request_parameters = model_request_parameters
         model_response = self._request(messages, model_settings, model_request_parameters)
         model_response.usage = _estimate_usage([*messages, model_response])
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -141,7 +140,7 @@

     @property
     def system(self) -> str:
-        """The system / model provider."""
+        """The model provider."""
         return self._system

     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
@@ -468,6 +467,6 @@ class _JsonSchemaTestData:
         return s


-def _get_string_usage(text: str) -> Usage:
+def _get_string_usage(text: str) -> RequestUsage:
     response_tokens = _estimate_string_tokens(text)
-    return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
+    return RequestUsage(output_tokens=response_tokens)
pydantic_ai/profiles/__init__.py

@@ -20,7 +20,7 @@ __all__ = [

 @dataclass
 class ModelProfile:
-    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+    """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""

     supports_tools: bool = True
     """Whether the model supports tools."""
@@ -46,6 +46,15 @@ class ModelProfile:
     thinking_tags: tuple[str, str] = ('<think>', '</think>')
     """The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""

+    ignore_streamed_leading_whitespace: bool = False
+    """Whether to ignore leading whitespace when streaming a response.
+
+    This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
+    which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
+
+    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    """
+
     @classmethod
     def from_profile(cls, profile: ModelProfile | None) -> Self:
         """Build a ModelProfile subclass instance from a ModelProfile instance."""
pydantic_ai/profiles/deepseek.py

@@ -5,4 +5,4 @@ from . import ModelProfile

 def deepseek_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a DeepSeek model."""
-    return None
+    return ModelProfile(ignore_streamed_leading_whitespace='r1' in model_name)
pydantic_ai/profiles/moonshotai.py

@@ -5,4 +5,4 @@ from . import ModelProfile

 def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a MoonshotAI model."""
-    return None
+    return ModelProfile(ignore_streamed_leading_whitespace=True)
pydantic_ai/profiles/openai.py

@@ -2,11 +2,13 @@ from __future__ import annotations as _annotations

 import re
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Literal

 from . import ModelProfile
 from ._json_schema import JsonSchema, JsonSchemaTransformer

+OpenAISystemPromptRole = Literal['system', 'developer', 'user']
+

 @dataclass
 class OpenAIModelProfile(ModelProfile):
@@ -26,8 +28,10 @@ class OpenAIModelProfile(ModelProfile):
     # safe to pass that value along. Default is `True` to preserve existing
     # behaviour for OpenAI itself and most providers.
     openai_supports_tool_choice_required: bool = True
-    """Whether the provider accepts the value ``tool_choice='required'`` in the
-    request payload."""
+    """Whether the provider accepts the value ``tool_choice='required'`` in the request payload."""
+
+    openai_system_prompt_role: OpenAISystemPromptRole | None = None
+    """The role to use for the system prompt message. If not provided, defaults to `'system'`."""


 def openai_model_profile(model_name: str) -> ModelProfile:
@@ -36,11 +40,17 @@ def openai_model_profile(model_name: str) -> ModelProfile:
     # Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
     # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
     # when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
+
+    # The o1-mini model doesn't support the `system` role, so we default to `user`.
+    # See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
+    openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
+
     return OpenAIModelProfile(
         json_schema_transformer=OpenAIJsonSchemaTransformer,
         supports_json_schema_output=True,
         supports_json_object_output=True,
         openai_supports_sampling_settings=not is_reasoning_model,
+        openai_system_prompt_role=openai_system_prompt_role,
     )

pydantic_ai/profiles/qwen.py

@@ -5,4 +5,7 @@ from . import InlineDefsJsonSchemaTransformer, ModelProfile

 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
-    return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
+    return ModelProfile(
+        json_schema_transformer=InlineDefsJsonSchemaTransformer,
+        ignore_streamed_leading_whitespace=True,
+    )
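A sketch of what the updated profile functions now report, covering both the new `openai_system_prompt_role` default and the `ignore_streamed_leading_whitespace` flag (model names are illustrative):

    from pydantic_ai.profiles.deepseek import deepseek_model_profile
    from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
    from pydantic_ai.profiles.openai import openai_model_profile
    from pydantic_ai.profiles.qwen import qwen_model_profile

    # o1-mini rejects the 'system' role, so its profile now asks for 'user'; other models keep the default.
    print(openai_model_profile('o1-mini').openai_system_prompt_role)  # 'user'
    print(openai_model_profile('gpt-4o').openai_system_prompt_role)   # None, i.e. the 'system' role

    # The streamed-leading-whitespace workaround is switched on per model family.
    print(deepseek_model_profile('deepseek-r1').ignore_streamed_leading_whitespace)    # True ('r1' in the name)
    print(deepseek_model_profile('deepseek-chat').ignore_streamed_leading_whitespace)  # False
    print(moonshotai_model_profile('kimi-k2').ignore_streamed_leading_whitespace)      # True
    print(qwen_model_profile('qwen3').ignore_streamed_leading_whitespace)              # True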
pydantic_ai/providers/__init__.py

@@ -123,6 +123,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .huggingface import HuggingFaceProvider

         return HuggingFaceProvider
+    elif provider == 'ollama':
+        from .ollama import OllamaProvider
+
+        return OllamaProvider
     elif provider == 'github':
         from .github import GitHubProvider

pydantic_ai/providers/huggingface.py

@@ -6,6 +6,13 @@ from typing import overload
 from httpx import AsyncClient

 from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile

 try:
     from huggingface_hub import AsyncInferenceClient
@@ -33,6 +40,26 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]):
     def client(self) -> AsyncInferenceClient:
         return self._client

+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'deepseek-ai': deepseek_model_profile,
+            'google': google_model_profile,
+            'qwen': qwen_model_profile,
+            'meta-llama': meta_model_profile,
+            'mistralai': mistral_model_profile,
+            'moonshotai': moonshotai_model_profile,
+        }
+
+        if '/' not in model_name:
+            return None
+
+        model_name = model_name.lower()
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            return provider_to_profile[provider](model_name)
+
+        return None
+
     @overload
     def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
     @overload
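The new `model_profile` method keys off the organization prefix of namespaced Hugging Face model IDs. A sketch of the lookup; the constructor arguments are illustrative (see the overloads above) and the API key is a placeholder:

    from pydantic_ai.providers.huggingface import HuggingFaceProvider

    provider = HuggingFaceProvider(api_key='hf_xxx')  # placeholder key
    print(provider.model_profile('Qwen/Qwen3-235B-A22B'))  # matched via the 'qwen' prefix
    print(provider.model_profile('gpt2'))                  # no '/' namespace, so None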
pydantic_ai/providers/ollama.py (new file)

@@ -0,0 +1,105 @@
+from __future__ import annotations as _annotations
+
+import os
+
+import httpx
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Ollama provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class OllamaProvider(Provider[AsyncOpenAI]):
+    """Provider for local or remote Ollama API."""
+
+    @property
+    def name(self) -> str:
+        return 'ollama'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'gemma': google_model_profile,
+            'qwen': qwen_model_profile,
+            'qwq': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'command': cohere_model_profile,
+        }
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # As OllamaProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        """Create a new Ollama provider.
+
+        Args:
+            base_url: The base url for the Ollama requests. If not provided, the `OLLAMA_BASE_URL` environment variable
+                will be used if available.
+            api_key: The API key to use for authentication, if not provided, the `OLLAMA_API_KEY` environment variable
+                will be used if available.
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
+            self._client = openai_client
+        else:
+            base_url = base_url or os.getenv('OLLAMA_BASE_URL')
+            if not base_url:
+                raise UserError(
+                    'Set the `OLLAMA_BASE_URL` environment variable or pass it via `OllamaProvider(base_url=...)`'
+                    'to use the Ollama provider.'
+                )
+
+            # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+            # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
+            api_key = api_key or os.getenv('OLLAMA_API_KEY') or 'api-key-not-set'
+
+            if http_client is not None:
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            else:
+                http_client = cached_async_http_client(provider='ollama')
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
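Combined with the `'ollama'` entries added to `OpenAIModel` and `infer_provider` above, the new provider talks to the OpenAI-compatible endpoint that Ollama exposes. A minimal usage sketch; the model name and URL are placeholders for a locally pulled model:

    from pydantic_ai import Agent
    from pydantic_ai.models.openai import OpenAIModel
    from pydantic_ai.providers.ollama import OllamaProvider

    model = OpenAIModel(
        'qwen3:8b',  # any model available on the local Ollama instance
        provider=OllamaProvider(base_url='http://localhost:11434/v1'),
    )
    agent = Agent(model)
    print(agent.run_sync('Hello!').output)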