pydantic-ai-slim 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/models/mistral.py

@@ -52,6 +52,7 @@ try:
         CompletionChunk as MistralCompletionChunk,
         Content as MistralContent,
         ContentChunk as MistralContentChunk,
+        DocumentURLChunk as MistralDocumentURLChunk,
         FunctionCall as MistralFunctionCall,
         ImageURL as MistralImageURL,
         ImageURLChunk as MistralImageURLChunk,

@@ -539,10 +540,19 @@ class MistralModel(Model):
                     if item.is_image:
                         image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                         content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
+                    elif item.media_type == 'application/pdf':
+                        content.append(
+                            MistralDocumentURLChunk(
+                                document_url=f'data:application/pdf;base64,{base64_encoded}', type='document_url'
+                            )
+                        )
                     else:
-                        raise RuntimeError('Only image binary content is supported for Mistral.')
+                        raise RuntimeError('BinaryContent other than image or PDF is not supported in Mistral.')
                 elif isinstance(item, DocumentUrl):
-                    raise RuntimeError('DocumentUrl is not supported in Mistral.')  # pragma: no cover
+                    if item.media_type == 'application/pdf':
+                        content.append(MistralDocumentURLChunk(document_url=item.url, type='document_url'))
+                    else:
+                        raise RuntimeError('DocumentUrl other than PDF is not supported in Mistral.')
                 elif isinstance(item, VideoUrl):
                     raise RuntimeError('VideoUrl is not supported in Mistral.')
                 else:  # pragma: no cover

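In practice this means Mistral models can now be sent PDFs through the regular pydantic-ai multimodal input types. A minimal sketch, assuming a Mistral model with document understanding enabled (the model name is illustrative):

from pydantic_ai import Agent
from pydantic_ai.messages import BinaryContent, DocumentUrl

agent = Agent('mistral:mistral-small-latest')

# A PDF by URL is passed straight through as a MistralDocumentURLChunk.
result = agent.run_sync(
    ['Summarize this document.', DocumentUrl(url='https://example.com/report.pdf')]
)

# PDF bytes are inlined as a 'data:application/pdf;base64,...' document URL.
pdf_bytes = open('report.pdf', 'rb').read()
result = agent.run_sync(
    ['Summarize this document.', BinaryContent(data=pdf_bytes, media_type='application/pdf')]
)
print(result.output)
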
@@ -591,7 +601,9 @@ class MistralStreamedResponse(StreamedResponse):
                             tool_call_id=maybe_tool_call_part.tool_call_id,
                         )
                 else:
-                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event

             # Handle the explicit tool calls
             for index, dtc in enumerate(choice.delta.tool_calls or []):

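This is part of a cross-cutting change that recurs in the OpenAI and test models below: `handle_text_delta` may now return `None` instead of always producing an event (for example while it buffers text that could turn out to open a `<think>` tag), so every streaming call site switches from an unconditional `yield` to a guarded one:

# New call pattern at every text-delta site (sketch of the diff above):
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
if maybe_event is not None:
    yield maybe_event
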
pydantic_ai/models/openai.py

@@ -17,7 +17,7 @@ from pydantic_ai.providers import Provider, infer_provider

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
-from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
     AudioUrl,
     BinaryContent,

@@ -191,7 +191,17 @@ class OpenAIModel(Model):
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'moonshotai',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

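Both new entries work as plain shortcut strings, resolved through `infer_provider` (see the `providers/__init__.py` hunks below). A hedged sketch; the model names are illustrative and the matching API-key environment variables must be set:

from pydantic_ai.models.openai import OpenAIModel

kimi = OpenAIModel('kimi-k2-0711-preview', provider='moonshotai')  # needs MOONSHOTAI_API_KEY
gateway = OpenAIModel('openai/gpt-4o', provider='vercel')          # needs VERCEL_AI_GATEWAY_API_KEY
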
@@ -290,7 +300,10 @@ class OpenAIModel(Model):
         tools = self._get_tools(model_request_parameters)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
-        elif not model_request_parameters.allow_text_output:
+        elif (
+            not model_request_parameters.allow_text_output
+            and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
+        ):
             tool_choice = 'required'
         else:
             tool_choice = 'auto'

@@ -357,11 +370,17 @@ class OpenAIModel(Model):
         if not isinstance(response, chat.ChatCompletion):
             raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')

+        if response.created:
+            timestamp = number_to_datetime(response.created)
+        else:
+            timestamp = _now_utc()
+            response.created = int(timestamp.timestamp())
+
         try:
             response = chat.ChatCompletion.model_validate(response.model_dump())
         except ValidationError as e:
             raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
-        timestamp = number_to_datetime(response.created)
+
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
         # The `reasoning_content` is only present in DeepSeek models.

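The new branch guards against OpenAI-compatible servers that omit `created` or send `0`: rather than reporting an epoch-zero timestamp, the response is stamped with the current UTC time and `created` is back-filled before validation. Standalone, the logic is roughly this (illustrative, not library code):

from datetime import datetime, timezone

def response_timestamp(created: int | None) -> datetime:
    if created:
        return datetime.fromtimestamp(created, tz=timezone.utc)
    return datetime.now(tz=timezone.utc)  # fallback when the field is missing or zero
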
@@ -1003,8 +1022,12 @@ class OpenAIStreamedResponse(StreamedResponse):

             # Handle the text part of the response
             content = choice.delta.content
-            if content is not None:
-                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content', content=content, extract_think_tags=True
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event

             # Handle reasoning part of the response, present in DeepSeek models
             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):

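`extract_think_tags=True` makes the parts manager watch chat-completions text for inline `<think>...</think>` blocks (emitted by some reasoning models served over OpenAI-compatible APIs) and surface them as thinking parts rather than literal text. From the agent side the visible effect should be that think-tag content no longer leaks into streamed output; a sketch, assuming such a model (the name is a placeholder):

import asyncio
from pydantic_ai import Agent

agent = Agent('openai:some-reasoning-model')  # placeholder model name

async def main():
    async with agent.run_stream('Why is the sky blue?') as result:
        async for delta in result.stream_text(delta=True):
            print(delta, end='')  # '<think>...' content is surfaced separately, not here

asyncio.run(main())
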
@@ -1121,7 +1144,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 )

             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id=chunk.content_index, content=chunk.delta
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event

             elif isinstance(chunk, responses.ResponseTextDoneEvent):
                 pass  # there's nothing we need to do here

pydantic_ai/models/test.py

@@ -269,10 +269,14 @@ class TestStreamedResponse(StreamedResponse):
                     mid = len(text) // 2
                     words = [text[:mid], text[mid:]]
                 self._usage += _get_string_usage('')
-                yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
                 for word in words:
                     self._usage += _get_string_usage(word)
-                    yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+                    maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
             elif isinstance(part, ToolCallPart):
                 yield self._parts_manager.handle_tool_call_part(
                     vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id

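`TestStreamedResponse` backs `TestModel`, which streams each text part word by word (or in halves for single-word text) after an initial empty delta; with the guard, that empty delta presumably now produces no event at all. A quick way to observe the streaming behaviour (sketch):

import asyncio
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(custom_output_text='hello world'))

async def main():
    async with agent.run_stream('hi') as result:
        async for delta in result.stream_text(delta=True):
            print(repr(delta))  # 'hello ', 'world' — no empty leading chunk

asyncio.run(main())
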
pydantic_ai/profiles/openai.py

@@ -21,6 +21,14 @@ class OpenAIModelProfile(ModelProfile):
     openai_supports_sampling_settings: bool = True
     """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""

+    # Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
+    # `tool_choice="required"`. This flag lets the calling model know whether it's
+    # safe to pass that value along. Default is `True` to preserve existing
+    # behaviour for OpenAI itself and most providers.
+    openai_supports_tool_choice_required: bool = True
+    """Whether the provider accepts the value ``tool_choice='required'`` in the
+    request payload."""
+

 def openai_model_profile(model_name: str) -> ModelProfile:
     """Get the model profile for an OpenAI model."""

pydantic_ai/providers/__init__.py

@@ -62,6 +62,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .openrouter import OpenRouterProvider

         return OpenRouterProvider
+    elif provider == 'vercel':
+        from .vercel import VercelProvider
+
+        return VercelProvider
     elif provider == 'azure':
         from .azure import AzureProvider

@@ -99,6 +103,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .grok import GrokProvider

         return GrokProvider
+    elif provider == 'moonshotai':
+        from .moonshotai import MoonshotAIProvider
+
+        return MoonshotAIProvider
     elif provider == 'fireworks':
         from .fireworks import FireworksProvider

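With both registrations in place, the string shortcuts resolve to provider instances (each constructor raises `UserError` unless the corresponding API key is available):

from pydantic_ai.providers import infer_provider

moonshot = infer_provider('moonshotai')  # MoonshotAIProvider, needs MOONSHOTAI_API_KEY
vercel = infer_provider('vercel')        # VercelProvider, needs VERCEL_AI_GATEWAY_API_KEY or VERCEL_OIDC_TOKEN
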
pydantic_ai/providers/moonshotai.py (new file)

@@ -0,0 +1,97 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import Literal, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import (
+    OpenAIJsonSchemaTransformer,
+    OpenAIModelProfile,
+)
+from pydantic_ai.providers import Provider
+
+MoonshotAIModelName = Literal[
+    'moonshot-v1-8k',
+    'moonshot-v1-32k',
+    'moonshot-v1-128k',
+    'moonshot-v1-8k-vision-preview',
+    'moonshot-v1-32k-vision-preview',
+    'moonshot-v1-128k-vision-preview',
+    'kimi-latest',
+    'kimi-thinking-preview',
+    'kimi-k2-0711-preview',
+]
+
+
+class MoonshotAIProvider(Provider[AsyncOpenAI]):
+    """Provider for MoonshotAI platform (Kimi models)."""
+
+    @property
+    def name(self) -> str:
+        return 'moonshotai'
+
+    @property
+    def base_url(self) -> str:
+        # OpenAI-compatible endpoint, see MoonshotAI docs
+        return 'https://api.moonshot.ai/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        profile = moonshotai_model_profile(model_name)
+
+        # As the MoonshotAI API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
+        # unless json_schema_transformer is set explicitly.
+        # Also, MoonshotAI does not support strict tool definitions
+        # https://platform.moonshot.ai/docs/guide/migrating-from-openai-to-kimi#about-tool_choice
+        # "Please note that the current version of Kimi API does not support the tool_choice=required parameter."
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+            openai_supports_tool_choice_required=False,
+            supports_json_object_output=True,
+        ).update(profile)
+
+    # ---------------------------------------------------------------------
+    # Construction helpers
+    # ---------------------------------------------------------------------
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        api_key = api_key or os.getenv('MOONSHOTAI_API_KEY')
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `MOONSHOTAI_API_KEY` environment variable or pass it via '
+                '`MoonshotAIProvider(api_key=...)` to use the MoonshotAI provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='moonshotai')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

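Usage mirrors the other OpenAI-compatible providers; a sketch with an explicit key (the model name comes from the `MoonshotAIModelName` literal above):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.moonshotai import MoonshotAIProvider

model = OpenAIModel(
    'kimi-k2-0711-preview',
    provider=MoonshotAIProvider(api_key='your-moonshot-api-key'),
)
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)
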
pydantic_ai/providers/vercel.py (new file)

@@ -0,0 +1,107 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.amazon import amazon_model_profile
+from pydantic_ai.profiles.anthropic import anthropic_model_profile
+from pydantic_ai.profiles.cohere import cohere_model_profile
+from pydantic_ai.profiles.deepseek import deepseek_model_profile
+from pydantic_ai.profiles.google import google_model_profile
+from pydantic_ai.profiles.grok import grok_model_profile
+from pydantic_ai.profiles.mistral import mistral_model_profile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Vercel provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class VercelProvider(Provider[AsyncOpenAI]):
+    """Provider for Vercel AI Gateway API."""
+
+    @property
+    def name(self) -> str:
+        return 'vercel'
+
+    @property
+    def base_url(self) -> str:
+        return 'https://ai-gateway.vercel.sh/v1'
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        provider_to_profile = {
+            'anthropic': anthropic_model_profile,
+            'bedrock': amazon_model_profile,
+            'cohere': cohere_model_profile,
+            'deepseek': deepseek_model_profile,
+            'mistral': mistral_model_profile,
+            'openai': openai_model_profile,
+            'vertex': google_model_profile,
+            'xai': grok_model_profile,
+        }
+
+        profile = None
+
+        try:
+            provider, model_name = model_name.split('/', 1)
+        except ValueError:
+            raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}")
+
+        if provider in provider_to_profile:
+            profile = provider_to_profile[provider](model_name)
+
+        # As VercelProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
+        # we need to maintain that behavior unless json_schema_transformer is set explicitly
+        return OpenAIModelProfile(
+            json_schema_transformer=OpenAIJsonSchemaTransformer,
+        ).update(profile)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        # Support Vercel AI Gateway's standard environment variables
+        api_key = api_key or os.getenv('VERCEL_AI_GATEWAY_API_KEY') or os.getenv('VERCEL_OIDC_TOKEN')
+
+        if not api_key and openai_client is None:
+            raise UserError(
+                'Set the `VERCEL_AI_GATEWAY_API_KEY` or `VERCEL_OIDC_TOKEN` environment variable '
+                'or pass the API key via `VercelProvider(api_key=...)` to use the Vercel provider.'
+            )
+
+        if openai_client is not None:
+            self._client = openai_client
+        elif http_client is not None:
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
+        else:
+            http_client = cached_async_http_client(provider='vercel')
+            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

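Because the gateway fronts many upstream vendors, model names must use the 'provider/model' format that `model_profile` parses above; anything else raises `UserError`. A sketch (the model name is illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.vercel import VercelProvider

model = OpenAIModel(
    'anthropic/claude-3-5-sonnet-20241022',
    provider=VercelProvider(api_key='your-gateway-api-key'),
)
agent = Agent(model)
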