pydantic-ai-slim 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -8,14 +8,6 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Any, Literal, Union, cast, overload
 
-from anthropic.types.beta import (
-    BetaCitationsDelta,
-    BetaCodeExecutionToolResultBlock,
-    BetaCodeExecutionToolResultBlockParam,
-    BetaInputJSONDelta,
-    BetaServerToolUseBlockParam,
-    BetaWebSearchToolResultBlockParam,
-)
 from typing_extensions import assert_never
 
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
@@ -47,24 +39,21 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
 
 try:
     from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
     from anthropic.types.beta import (
         BetaBase64PDFBlockParam,
         BetaBase64PDFSourceParam,
+        BetaCitationsDelta,
         BetaCodeExecutionTool20250522Param,
+        BetaCodeExecutionToolResultBlock,
+        BetaCodeExecutionToolResultBlockParam,
         BetaContentBlock,
         BetaContentBlockParam,
         BetaImageBlockParam,
+        BetaInputJSONDelta,
         BetaMessage,
         BetaMessageParam,
         BetaMetadataParam,
@@ -78,6 +67,7 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaServerToolUseBlock,
+        BetaServerToolUseBlockParam,
         BetaSignatureDelta,
         BetaTextBlock,
         BetaTextBlockParam,
@@ -94,6 +84,7 @@ try:
         BetaToolUseBlockParam,
         BetaWebSearchTool20250305Param,
         BetaWebSearchToolResultBlock,
+        BetaWebSearchToolResultBlockParam,
     )
     from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
     from anthropic.types.model_param import ModelParam
@@ -246,7 +237,9 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tools += self._get_builtin_tools(model_request_parameters)
+        builtin_tools, tool_headers = self._get_builtin_tools(model_request_parameters)
+        tools += builtin_tools
+
         tool_choice: BetaToolChoiceParam | None
 
         if not tools:
@@ -264,8 +257,10 @@ class AnthropicModel(Model):
 
         try:
             extra_headers = model_settings.get('extra_headers', {})
+            for k, v in tool_headers.items():
+                extra_headers.setdefault(k, v)
             extra_headers.setdefault('User-Agent', get_user_agent())
-            extra_headers.setdefault('anthropic-beta', 'code-execution-2025-05-22')
+
             return await self.client.beta.messages.create(
                 max_tokens=model_settings.get('max_tokens', 4096),
                 system=system_prompt or NOT_GIVEN,
@@ -352,8 +347,11 @@ class AnthropicModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
-    def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> tuple[list[BetaToolUnionParam], dict[str, str]]:
         tools: list[BetaToolUnionParam] = []
+        extra_headers: dict[str, str] = {}
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -361,18 +359,20 @@ class AnthropicModel(Model):
                     BetaWebSearchTool20250305Param(
                         name='web_search',
                         type='web_search_20250305',
+                        max_uses=tool.max_uses,
                         allowed_domains=tool.allowed_domains,
                         blocked_domains=tool.blocked_domains,
                         user_location=user_location,
                     )
                 )
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                extra_headers['anthropic-beta'] = 'code-execution-2025-05-22'
                 tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
             else:  # pragma: no cover
                 raise UserError(
                     f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                 )
-        return tools
+        return tools, extra_headers
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
@@ -648,7 +648,7 @@ class BedrockStreamedResponse(StreamedResponse):
                 )
                 if 'text' in delta:
                     maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
-                    if maybe_event is not None:
+                    if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 if 'toolUse' in delta:
                     tool_use = delta['toolUse']
@@ -11,6 +11,7 @@ from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel
 
 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:
@@ -65,8 +66,9 @@ class FallbackModel(Model):
 
         for model in self.models:
            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
            try:
-                response = await model.request(messages, model_settings, customized_model_request_parameters)
+                response = await model.request(messages, merged_settings, customized_model_request_parameters)
            except Exception as exc:
                if self._fallback_on(exc):
                    exceptions.append(exc)
@@ -91,10 +93,13 @@ class FallbackModel(Model):
 
         for model in self.models:
            customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
-                        model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
+                        model.request_stream(
+                            messages, merged_settings, customized_model_request_parameters, run_context
+                        )
                    )
                except Exception as exc:
                    if self._fallback_on(exc):
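A small sketch (illustrative model names and settings, not from the diff) of the behavior change: `FallbackModel` now merges each wrapped model's own default settings with the run-level settings, instead of passing only the run-level settings through.

```python
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

# Illustrative model names and settings.
primary = OpenAIModel('gpt-4o', settings={'temperature': 0.1})
fallback = AnthropicModel('claude-3-5-sonnet-latest', settings={'max_tokens': 2048})
agent = Agent(FallbackModel(primary, fallback))

# As of 0.7.2, each attempted model sees merge_model_settings(model.settings, run_settings):
# the run-level temperature below overrides primary's default, while fallback keeps its
# max_tokens default if the first model raises and the fallback is tried.
result = agent.run_sync('Hello!', model_settings={'temperature': 0.7})
```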
@@ -52,6 +52,7 @@ try:
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensConfigDict,
         ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
@@ -59,6 +60,7 @@ try:
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GenerationConfigDict,
         GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
@@ -188,6 +190,59 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.Usage:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        contents, generation_config = await self._build_content_and_config(
+            messages, model_settings, model_request_parameters
+        )
+
+        # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`,
+        # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway.
+        generation_config = cast(dict[str, Any], generation_config)
+
+        config = CountTokensConfigDict(
+            http_options=generation_config.get('http_options'),
+        )
+        if self.system != 'google-gla':
+            # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
+            config.update(
+                system_instruction=generation_config.get('system_instruction'),
+                tools=cast(list[ToolDict], generation_config.get('tools')),
+                # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
+                generation_config=GenerationConfigDict(
+                    temperature=generation_config.get('temperature'),
+                    top_p=generation_config.get('top_p'),
+                    max_output_tokens=generation_config.get('max_output_tokens'),
+                    stop_sequences=generation_config.get('stop_sequences'),
+                    presence_penalty=generation_config.get('presence_penalty'),
+                    frequency_penalty=generation_config.get('frequency_penalty'),
+                    thinking_config=generation_config.get('thinking_config'),
+                    media_resolution=generation_config.get('media_resolution'),
+                    response_mime_type=generation_config.get('response_mime_type'),
+                    response_schema=generation_config.get('response_schema'),
+                ),
+            )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        if response.total_tokens is None:
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Total tokens missing from Gemini response', str(response)
+            )
+        return usage.Usage(
+            request_tokens=response.total_tokens,
+            total_tokens=response.total_tokens,
+        )
+
     @asynccontextmanager
     async def request_stream(
         self,
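A rough sketch (not part of the diff) of calling the new `count_tokens` method directly. The model name is illustrative, Google credentials are assumed to be configured, and `ModelRequestParameters()` is assumed to be constructible with its defaults.

```python
from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.google import GoogleModel

async def main() -> None:
    # Illustrative model name; requires e.g. GEMINI_API_KEY for the google-gla provider.
    model = GoogleModel('gemini-1.5-flash')
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello, Gemini!')])]
    token_usage = await model.count_tokens(messages, None, ModelRequestParameters())
    print(token_usage.request_tokens, token_usage.total_tokens)
```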
@@ -265,16 +320,23 @@ class GoogleModel(Model):
         model_settings: GoogleModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
-        tools = self._get_tools(model_request_parameters)
+        contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
 
+    async def _build_content_and_config(
+        self,
+        messages: list[ModelMessage],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[list[ContentUnionDict], GenerateContentConfigDict]:
+        tools = self._get_tools(model_request_parameters)
         response_mime_type = None
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError('Gemini does not support structured output and tools at the same time.')
-
             response_mime_type = 'application/json'
-
             output_object = model_request_parameters.output_object
             assert output_object is not None
             response_schema = self._map_response_schema(output_object)
@@ -311,9 +373,7 @@ class GoogleModel(Model):
             response_mime_type=response_mime_type,
             response_schema=response_schema,
         )
-
-        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
-        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+        return contents, config
 
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
@@ -457,6 +457,7 @@ class GroqStreamedResponse(StreamedResponse):
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -35,7 +35,7 @@ from ..messages import (
     UserPromptPart,
     VideoUrl,
 )
-from ..profiles import ModelProfile
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -121,6 +121,8 @@ class HuggingFaceModel(Model):
         model_name: str,
         *,
         provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+        profile: ModelProfileSpec | None = None,
+        settings: ModelSettings | None = None,
     ):
         """Initialize a Hugging Face model.
 
@@ -128,6 +130,8 @@ class HuggingFaceModel(Model):
             model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
             provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                 instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
+            settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
         self._provider = provider
@@ -135,6 +139,8 @@ class HuggingFaceModel(Model):
             provider = infer_provider(provider)
         self.client = provider.client
 
+        super().__init__(settings=settings, profile=profile or provider.model_profile)
+
     async def request(
         self,
         messages: list[ModelMessage],
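A brief sketch (illustrative model id and settings) of the new constructor arguments: `profile` overrides the provider-selected profile, and `settings` become per-model defaults.

```python
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.profiles import ModelProfile

# Illustrative model id; assumes Hugging Face credentials (e.g. HF_TOKEN) are configured.
model = HuggingFaceModel(
    'Qwen/Qwen3-235B-A22B',
    profile=ModelProfile(ignore_streamed_leading_whitespace=True),  # explicit profile override
    settings={'temperature': 0.2},  # defaults applied to every request made with this model
)
```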
@@ -444,11 +450,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content:
+            if content is not None:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -59,6 +59,11 @@ try:
     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
     from openai.types.chat.chat_completion_content_part_param import File, FileFile
+    from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
+    from openai.types.chat.chat_completion_message_function_tool_call import ChatCompletionMessageFunctionToolCall
+    from openai.types.chat.chat_completion_message_function_tool_call_param import (
+        ChatCompletionMessageFunctionToolCallParam,
+    )
     from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
     from openai.types.chat.completion_create_params import (
         WebSearchOptions,
@@ -172,6 +177,14 @@ class OpenAIResponsesModelSettings(OpenAIModelSettings, total=False):
     middle of the conversation.
     """
 
+    openai_text_verbosity: Literal['low', 'medium', 'high']
+    """Constrains the verbosity of the model's text response.
+
+    Lower values will result in more concise responses, while higher values will
+    result in more verbose responses. Currently supported values are `low`,
+    `medium`, and `high`.
+    """
+
 
 @dataclass(init=False)
 class OpenAIModel(Model):
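A short usage sketch (not from the diff) of the new setting; the model name is illustrative, and the setting is only read by the Responses API model.

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

# Illustrative model name; requires OPENAI_API_KEY in the environment.
agent = Agent(OpenAIResponsesModel('gpt-5'))

result = agent.run_sync(
    'Summarise the plot of Hamlet.',
    model_settings=OpenAIResponsesModelSettings(openai_text_verbosity='low'),
)
print(result.output)
```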
@@ -204,6 +217,7 @@ class OpenAIModel(Model):
             'together',
             'heroku',
             'github',
+            'ollama',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -416,7 +430,14 @@ class OpenAIModel(Model):
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                if isinstance(c, ChatCompletionMessageFunctionToolCall):
+                    part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
+                elif isinstance(c, ChatCompletionMessageCustomToolCall):  # pragma: no cover
+                    # NOTE: Custom tool calls are not supported.
+                    # See <https://github.com/pydantic/pydantic-ai/issues/2513> for more details.
+                    raise RuntimeError('Custom tool calls are not supported')
+                else:
+                    assert_never(c)
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
         return ModelResponse(
@@ -476,7 +497,7 @@ class OpenAIModel(Model):
                     openai_messages.append(item)
             elif isinstance(message, ModelResponse):
                 texts: list[str] = []
-                tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
+                tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = []
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         texts.append(item.content)
@@ -507,8 +528,8 @@ class OpenAIModel(Model):
         return openai_messages
 
     @staticmethod
-    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
-        return chat.ChatCompletionMessageToolCallParam(
+    def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallParam:
+        return ChatCompletionMessageFunctionToolCallParam(
             id=_guard_tool_call_id(t=t),
             type='function',
             function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
@@ -807,6 +828,10 @@ class OpenAIResponsesModel(Model):
             openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions))
             instructions = NOT_GIVEN
 
+        if verbosity := model_settings.get('openai_text_verbosity'):
+            text = text or {}
+            text['verbosity'] = verbosity
+
         sampling_settings = (
             model_settings
             if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
@@ -1070,11 +1095,12 @@ class OpenAIStreamedResponse(StreamedResponse):
 
             # Handle the text part of the response
             content = choice.delta.content
-            if content:
+            if content is not None:
                 maybe_event = self._parts_manager.handle_text_delta(
                     vendor_part_id='content',
                     content=content,
                     thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                 )
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
@@ -20,7 +20,7 @@ __all__ = [
 
 @dataclass
 class ModelProfile:
-    """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used."""
+    """Describes how requests to and responses from specific models or families of models need to be constructed and processed to get the best results, independent of the model and provider classes used."""
 
     supports_tools: bool = True
     """Whether the model supports tools."""
@@ -46,6 +46,15 @@ class ModelProfile:
     thinking_tags: tuple[str, str] = ('<think>', '</think>')
     """The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""
 
+    ignore_streamed_leading_whitespace: bool = False
+    """Whether to ignore leading whitespace when streaming a response.
+
+    This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
+    which we don't want to end up treating as a final result when using `run_stream` with `str` as a valid `output_type`.
+
+    This is currently only used by `OpenAIModel`, `HuggingFaceModel`, and `GroqModel`.
+    """
+
     @classmethod
     def from_profile(cls, profile: ModelProfile | None) -> Self:
         """Build a ModelProfile subclass instance from a ModelProfile instance."""
@@ -5,4 +5,4 @@ from . import ModelProfile
 
 def deepseek_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a DeepSeek model."""
-    return None
+    return ModelProfile(ignore_streamed_leading_whitespace='r1' in model_name)
@@ -5,4 +5,4 @@ from . import ModelProfile
 
 def moonshotai_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a MoonshotAI model."""
-    return None
+    return ModelProfile(ignore_streamed_leading_whitespace=True)
@@ -5,4 +5,7 @@ from . import InlineDefsJsonSchemaTransformer, ModelProfile
 
 def qwen_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Qwen model."""
-    return ModelProfile(json_schema_transformer=InlineDefsJsonSchemaTransformer)
+    return ModelProfile(
+        json_schema_transformer=InlineDefsJsonSchemaTransformer,
+        ignore_streamed_leading_whitespace=True,
+    )
@@ -123,6 +123,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .huggingface import HuggingFaceProvider
 
         return HuggingFaceProvider
+    elif provider == 'ollama':
+        from .ollama import OllamaProvider
+
+        return OllamaProvider
     elif provider == 'github':
         from .github import GitHubProvider
 
@@ -6,6 +6,13 @@ from typing import overload
 from httpx import AsyncClient
 
 from pydantic_ai.exceptions import UserError
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.qwen import qwen_model_profile
 
 try:
     from huggingface_hub import AsyncInferenceClient
@@ -33,6 +40,26 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]):
     def client(self) -> AsyncInferenceClient:
         return self._client
 
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'deepseek-ai': deepseek_model_profile,
+            'google': google_model_profile,
+            'qwen': qwen_model_profile,
+            'meta-llama': meta_model_profile,
+            'mistralai': mistral_model_profile,
+            'moonshotai': moonshotai_model_profile,
+        }
+
+        if '/' not in model_name:
+            return None
+
+        model_name = model_name.lower()
+        provider, model_name = model_name.split('/', 1)
+        if provider in provider_to_profile:
+            return provider_to_profile[provider](model_name)
+
+        return None
+
     @overload
     def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
     @overload
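For context (illustrative model id; Hugging Face credentials assumed): the org prefix before the `/` now selects a profile, so a DeepSeek R1 model picks up the DeepSeek profile automatically.

```python
from pydantic_ai.models.huggingface import HuggingFaceModel

# 'deepseek-ai/DeepSeek-R1' -> deepseek_model_profile('deepseek-r1'),
# which enables ignore_streamed_leading_whitespace for R1 models.
model = HuggingFaceModel('deepseek-ai/DeepSeek-R1')
```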
@@ -0,0 +1,105 @@
+from __future__ import annotations as _annotations
+
+import os
+
+import httpx
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.meta import meta_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.profiles.qwen import qwen_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Ollama provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class OllamaProvider(Provider[AsyncOpenAI]):
+    """Provider for local or remote Ollama API."""
+
+    @property
+    def name(self) -> str:
+        return 'ollama'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        prefix_to_profile = {
+            'llama': meta_model_profile,
+            'gemma': google_model_profile,
+            'qwen': qwen_model_profile,
+            'qwq': qwen_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'command': cohere_model_profile,
+        }
+
+        profile = None
+        for prefix, profile_func in prefix_to_profile.items():
+            model_name = model_name.lower()
+            if model_name.startswith(prefix):
+                profile = profile_func(model_name)
+
+        # As OllamaProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: httpx.AsyncClient | None = None,
+    ) -> None:
+        """Create a new Ollama provider.
+
+        Args:
+            base_url: The base url for the Ollama requests. If not provided, the `OLLAMA_BASE_URL` environment variable
+                will be used if available.
+            api_key: The API key to use for authentication, if not provided, the `OLLAMA_API_KEY` environment variable
+                will be used if available.
+            openai_client: An existing
+                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
+                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        if openai_client is not None:
+            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
+            self._client = openai_client
+        else:
+            base_url = base_url or os.getenv('OLLAMA_BASE_URL')
+            if not base_url:
+                raise UserError(
+                    'Set the `OLLAMA_BASE_URL` environment variable or pass it via `OllamaProvider(base_url=...)`'
+                    'to use the Ollama provider.'
+                )
+
+            # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
+            # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
+            api_key = api_key or os.getenv('OLLAMA_API_KEY') or 'api-key-not-set'
+
+            if http_client is not None:
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            else:
+                http_client = cached_async_http_client(provider='ollama')
+                self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
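A usage sketch for the new provider (the model name and local URL are illustrative; Ollama's OpenAI-compatible endpoint is typically served under `/v1`):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.ollama import OllamaProvider

model = OpenAIModel(
    'qwen3',
    provider=OllamaProvider(base_url='http://localhost:11434/v1'),
)
agent = Agent(model)
print(agent.run_sync('What is 2 + 2?').output)

# Equivalently, provider='ollama' works when OLLAMA_BASE_URL (and optionally OLLAMA_API_KEY) is set.
```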
@@ -17,6 +17,7 @@ from pydantic_ai.profiles.google import google_model_profile
 from pydantic_ai.profiles.grok import grok_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
 from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
@@ -57,6 +58,7 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
             'amazon': amazon_model_profile,
             'deepseek': deepseek_model_profile,
             'meta-llama': meta_model_profile,
+            'moonshotai': moonshotai_model_profile,
         }
 
         profile = None
pydantic_ai/result.py CHANGED
@@ -196,7 +196,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
                 and isinstance(event.part, _messages.TextPart)
                 and event.part.content
             ):
-                yield event.part.content, event.index
+                yield event.part.content, event.index  # pragma: no cover
             elif (  # pragma: no branch
                 isinstance(event, _messages.PartDeltaEvent)
                 and isinstance(event.delta, _messages.TextPartDelta)