pydantic-ai-slim 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_parts_manager.py +31 -5
- pydantic_ai/ag_ui.py +67 -77
- pydantic_ai/agent.py +5 -3
- pydantic_ai/mcp.py +97 -37
- pydantic_ai/messages.py +84 -21
- pydantic_ai/models/__init__.py +11 -0
- pydantic_ai/models/anthropic.py +11 -3
- pydantic_ai/models/bedrock.py +4 -2
- pydantic_ai/models/cohere.py +6 -6
- pydantic_ai/models/function.py +4 -2
- pydantic_ai/models/gemini.py +5 -1
- pydantic_ai/models/google.py +9 -2
- pydantic_ai/models/groq.py +6 -2
- pydantic_ai/models/huggingface.py +6 -2
- pydantic_ai/models/mistral.py +3 -1
- pydantic_ai/models/openai.py +34 -7
- pydantic_ai/models/test.py +6 -2
- pydantic_ai/profiles/openai.py +8 -0
- pydantic_ai/providers/__init__.py +8 -0
- pydantic_ai/providers/moonshotai.py +97 -0
- pydantic_ai/providers/vercel.py +107 -0
- pydantic_ai/retries.py +249 -0
- pydantic_ai/toolsets/combined.py +4 -3
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.8.dist-info}/METADATA +9 -6
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.8.dist-info}/RECORD +28 -25
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.8.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.8.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.8.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/messages.py
CHANGED
|
@@ -85,7 +85,7 @@ class SystemPromptPart:
|
|
|
85
85
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
86
86
|
|
|
87
87
|
|
|
88
|
-
@dataclass(repr=False)
|
|
88
|
+
@dataclass(init=False, repr=False)
|
|
89
89
|
class FileUrl(ABC):
|
|
90
90
|
"""Abstract base class for any URL-based file."""
|
|
91
91
|
|
|
@@ -106,11 +106,29 @@ class FileUrl(ABC):
|
|
|
106
106
|
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
|
|
107
107
|
"""
|
|
108
108
|
|
|
109
|
-
|
|
109
|
+
_media_type: str | None = field(init=False, repr=False)
|
|
110
|
+
|
|
111
|
+
def __init__(
|
|
112
|
+
self,
|
|
113
|
+
url: str,
|
|
114
|
+
force_download: bool = False,
|
|
115
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
116
|
+
media_type: str | None = None,
|
|
117
|
+
) -> None:
|
|
118
|
+
self.url = url
|
|
119
|
+
self.vendor_metadata = vendor_metadata
|
|
120
|
+
self.force_download = force_download
|
|
121
|
+
self._media_type = media_type
|
|
122
|
+
|
|
110
123
|
@abstractmethod
|
|
111
|
-
def
|
|
124
|
+
def _infer_media_type(self) -> str:
|
|
112
125
|
"""Return the media type of the file, based on the url."""
|
|
113
126
|
|
|
127
|
+
@property
|
|
128
|
+
def media_type(self) -> str:
|
|
129
|
+
"""Return the media type of the file, based on the url or the provided `_media_type`."""
|
|
130
|
+
return self._media_type or self._infer_media_type()
|
|
131
|
+
|
|
114
132
|
@property
|
|
115
133
|
@abstractmethod
|
|
116
134
|
def format(self) -> str:
|
|
@@ -119,7 +137,7 @@ class FileUrl(ABC):
|
|
|
119
137
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
120
138
|
|
|
121
139
|
|
|
122
|
-
@dataclass(repr=False)
|
|
140
|
+
@dataclass(init=False, repr=False)
|
|
123
141
|
class VideoUrl(FileUrl):
|
|
124
142
|
"""A URL to a video."""
|
|
125
143
|
|
|
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
|
|
|
129
147
|
kind: Literal['video-url'] = 'video-url'
|
|
130
148
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
131
149
|
|
|
132
|
-
|
|
133
|
-
|
|
150
|
+
def __init__(
|
|
151
|
+
self,
|
|
152
|
+
url: str,
|
|
153
|
+
force_download: bool = False,
|
|
154
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
155
|
+
media_type: str | None = None,
|
|
156
|
+
kind: Literal['video-url'] = 'video-url',
|
|
157
|
+
) -> None:
|
|
158
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
159
|
+
self.kind = kind
|
|
160
|
+
|
|
161
|
+
def _infer_media_type(self) -> VideoMediaType:
|
|
134
162
|
"""Return the media type of the video, based on the url."""
|
|
135
163
|
if self.url.endswith('.mkv'):
|
|
136
164
|
return 'video/x-matroska'
|
|
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
|
|
|
170
198
|
return _video_format_lookup[self.media_type]
|
|
171
199
|
|
|
172
200
|
|
|
173
|
-
@dataclass(repr=False)
|
|
201
|
+
@dataclass(init=False, repr=False)
|
|
174
202
|
class AudioUrl(FileUrl):
|
|
175
203
|
"""A URL to an audio file."""
|
|
176
204
|
|
|
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
|
|
|
180
208
|
kind: Literal['audio-url'] = 'audio-url'
|
|
181
209
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
182
210
|
|
|
183
|
-
|
|
184
|
-
|
|
211
|
+
def __init__(
|
|
212
|
+
self,
|
|
213
|
+
url: str,
|
|
214
|
+
force_download: bool = False,
|
|
215
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
216
|
+
media_type: str | None = None,
|
|
217
|
+
kind: Literal['audio-url'] = 'audio-url',
|
|
218
|
+
) -> None:
|
|
219
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
220
|
+
self.kind = kind
|
|
221
|
+
|
|
222
|
+
def _infer_media_type(self) -> AudioMediaType:
|
|
185
223
|
"""Return the media type of the audio file, based on the url.
|
|
186
224
|
|
|
187
225
|
References:
|
|
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
|
|
|
208
246
|
return _audio_format_lookup[self.media_type]
|
|
209
247
|
|
|
210
248
|
|
|
211
|
-
@dataclass(repr=False)
|
|
249
|
+
@dataclass(init=False, repr=False)
|
|
212
250
|
class ImageUrl(FileUrl):
|
|
213
251
|
"""A URL to an image."""
|
|
214
252
|
|
|
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
|
|
|
218
256
|
kind: Literal['image-url'] = 'image-url'
|
|
219
257
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
220
258
|
|
|
221
|
-
|
|
222
|
-
|
|
259
|
+
def __init__(
|
|
260
|
+
self,
|
|
261
|
+
url: str,
|
|
262
|
+
force_download: bool = False,
|
|
263
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
264
|
+
media_type: str | None = None,
|
|
265
|
+
kind: Literal['image-url'] = 'image-url',
|
|
266
|
+
) -> None:
|
|
267
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
268
|
+
self.kind = kind
|
|
269
|
+
|
|
270
|
+
def _infer_media_type(self) -> ImageMediaType:
|
|
223
271
|
"""Return the media type of the image, based on the url."""
|
|
224
272
|
if self.url.endswith(('.jpg', '.jpeg')):
|
|
225
273
|
return 'image/jpeg'
|
|
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
|
|
|
241
289
|
return _image_format_lookup[self.media_type]
|
|
242
290
|
|
|
243
291
|
|
|
244
|
-
@dataclass(repr=False)
|
|
292
|
+
@dataclass(init=False, repr=False)
|
|
245
293
|
class DocumentUrl(FileUrl):
|
|
246
294
|
"""The URL of the document."""
|
|
247
295
|
|
|
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
|
|
|
251
299
|
kind: Literal['document-url'] = 'document-url'
|
|
252
300
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
253
301
|
|
|
254
|
-
|
|
255
|
-
|
|
302
|
+
def __init__(
|
|
303
|
+
self,
|
|
304
|
+
url: str,
|
|
305
|
+
force_download: bool = False,
|
|
306
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
307
|
+
media_type: str | None = None,
|
|
308
|
+
kind: Literal['document-url'] = 'document-url',
|
|
309
|
+
) -> None:
|
|
310
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
311
|
+
self.kind = kind
|
|
312
|
+
|
|
313
|
+
def _infer_media_type(self) -> str:
|
|
256
314
|
"""Return the media type of the document, based on the url."""
|
|
257
315
|
type_, _ = guess_type(self.url)
|
|
258
316
|
if type_ is None:
|
|
@@ -632,7 +690,7 @@ class ThinkingPart:
|
|
|
632
690
|
|
|
633
691
|
def has_content(self) -> bool:
|
|
634
692
|
"""Return `True` if the thinking content is non-empty."""
|
|
635
|
-
return bool(self.content)
|
|
693
|
+
return bool(self.content)
|
|
636
694
|
|
|
637
695
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
638
696
|
|
|
@@ -757,11 +815,16 @@ class ModelResponse:
|
|
|
757
815
|
},
|
|
758
816
|
}
|
|
759
817
|
)
|
|
760
|
-
elif isinstance(part, TextPart):
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
818
|
+
elif isinstance(part, (TextPart, ThinkingPart)):
|
|
819
|
+
kind = part.part_kind
|
|
820
|
+
body.setdefault('content', []).append(
|
|
821
|
+
{'kind': kind, **({'text': part.content} if settings.include_content else {})}
|
|
822
|
+
)
|
|
823
|
+
|
|
824
|
+
if content := body.get('content'):
|
|
825
|
+
text_content = content[0].get('text')
|
|
826
|
+
if content == [{'kind': 'text', 'text': text_content}]:
|
|
827
|
+
body['content'] = text_content
|
|
765
828
|
|
|
766
829
|
return result
|
|
767
830
|
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -233,6 +233,15 @@ KnownModelName = TypeAliasType(
|
|
|
233
233
|
'mistral:mistral-large-latest',
|
|
234
234
|
'mistral:mistral-moderation-latest',
|
|
235
235
|
'mistral:mistral-small-latest',
|
|
236
|
+
'moonshotai:moonshot-v1-8k',
|
|
237
|
+
'moonshotai:moonshot-v1-32k',
|
|
238
|
+
'moonshotai:moonshot-v1-128k',
|
|
239
|
+
'moonshotai:moonshot-v1-8k-vision-preview',
|
|
240
|
+
'moonshotai:moonshot-v1-32k-vision-preview',
|
|
241
|
+
'moonshotai:moonshot-v1-128k-vision-preview',
|
|
242
|
+
'moonshotai:kimi-latest',
|
|
243
|
+
'moonshotai:kimi-thinking-preview',
|
|
244
|
+
'moonshotai:kimi-k2-0711-preview',
|
|
236
245
|
'o1',
|
|
237
246
|
'o1-2024-12-17',
|
|
238
247
|
'o1-mini',
|
|
@@ -615,7 +624,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
615
624
|
'deepseek',
|
|
616
625
|
'azure',
|
|
617
626
|
'openrouter',
|
|
627
|
+
'vercel',
|
|
618
628
|
'grok',
|
|
629
|
+
'moonshotai',
|
|
619
630
|
'fireworks',
|
|
620
631
|
'together',
|
|
621
632
|
'heroku',
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -470,7 +470,7 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
470
470
|
_response: AsyncIterable[BetaRawMessageStreamEvent]
|
|
471
471
|
_timestamp: datetime
|
|
472
472
|
|
|
473
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
473
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
474
474
|
current_block: BetaContentBlock | None = None
|
|
475
475
|
|
|
476
476
|
async for event in self._response:
|
|
@@ -479,7 +479,11 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
479
479
|
if isinstance(event, BetaRawContentBlockStartEvent):
|
|
480
480
|
current_block = event.content_block
|
|
481
481
|
if isinstance(current_block, BetaTextBlock) and current_block.text:
|
|
482
|
-
|
|
482
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
483
|
+
vendor_part_id='content', content=current_block.text
|
|
484
|
+
)
|
|
485
|
+
if maybe_event is not None: # pragma: no branch
|
|
486
|
+
yield maybe_event
|
|
483
487
|
elif isinstance(current_block, BetaThinkingBlock):
|
|
484
488
|
yield self._parts_manager.handle_thinking_delta(
|
|
485
489
|
vendor_part_id='thinking',
|
|
@@ -498,7 +502,11 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
498
502
|
|
|
499
503
|
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
|
500
504
|
if isinstance(event.delta, BetaTextDelta):
|
|
501
|
-
|
|
505
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
506
|
+
vendor_part_id='content', content=event.delta.text
|
|
507
|
+
)
|
|
508
|
+
if maybe_event is not None: # pragma: no branch
|
|
509
|
+
yield maybe_event
|
|
502
510
|
elif isinstance(event.delta, BetaThinkingDelta):
|
|
503
511
|
yield self._parts_manager.handle_thinking_delta(
|
|
504
512
|
vendor_part_id='thinking', content=event.delta.thinking
|
pydantic_ai/models/bedrock.py
CHANGED
|
@@ -572,7 +572,7 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
572
572
|
_event_stream: EventStream[ConverseStreamOutputTypeDef]
|
|
573
573
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
574
574
|
|
|
575
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
575
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
576
576
|
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
577
577
|
|
|
578
578
|
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
@@ -618,7 +618,9 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
618
618
|
UserWarning,
|
|
619
619
|
)
|
|
620
620
|
if 'text' in delta:
|
|
621
|
-
|
|
621
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
|
|
622
|
+
if maybe_event is not None:
|
|
623
|
+
yield maybe_event
|
|
622
624
|
if 'toolUse' in delta:
|
|
623
625
|
tool_use = delta['toolUse']
|
|
624
626
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
pydantic_ai/models/cohere.py
CHANGED
|
@@ -38,15 +38,15 @@ try:
|
|
|
38
38
|
AssistantChatMessageV2,
|
|
39
39
|
AsyncClientV2,
|
|
40
40
|
ChatMessageV2,
|
|
41
|
-
ChatResponse,
|
|
42
41
|
SystemChatMessageV2,
|
|
43
|
-
|
|
42
|
+
TextAssistantMessageV2ContentItem,
|
|
44
43
|
ToolCallV2,
|
|
45
44
|
ToolCallV2Function,
|
|
46
45
|
ToolChatMessageV2,
|
|
47
46
|
ToolV2,
|
|
48
47
|
ToolV2Function,
|
|
49
48
|
UserChatMessageV2,
|
|
49
|
+
V2ChatResponse,
|
|
50
50
|
)
|
|
51
51
|
from cohere.core.api_error import ApiError
|
|
52
52
|
from cohere.v2.client import OMIT
|
|
@@ -164,7 +164,7 @@ class CohereModel(Model):
|
|
|
164
164
|
messages: list[ModelMessage],
|
|
165
165
|
model_settings: CohereModelSettings,
|
|
166
166
|
model_request_parameters: ModelRequestParameters,
|
|
167
|
-
) ->
|
|
167
|
+
) -> V2ChatResponse:
|
|
168
168
|
tools = self._get_tools(model_request_parameters)
|
|
169
169
|
cohere_messages = self._map_messages(messages)
|
|
170
170
|
try:
|
|
@@ -185,7 +185,7 @@ class CohereModel(Model):
|
|
|
185
185
|
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
|
|
186
186
|
raise # pragma: no cover
|
|
187
187
|
|
|
188
|
-
def _process_response(self, response:
|
|
188
|
+
def _process_response(self, response: V2ChatResponse) -> ModelResponse:
|
|
189
189
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
190
190
|
parts: list[ModelResponsePart] = []
|
|
191
191
|
if response.message.content is not None and len(response.message.content) > 0:
|
|
@@ -227,7 +227,7 @@ class CohereModel(Model):
|
|
|
227
227
|
assert_never(item)
|
|
228
228
|
message_param = AssistantChatMessageV2(role='assistant')
|
|
229
229
|
if texts:
|
|
230
|
-
message_param.content = [
|
|
230
|
+
message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
|
|
231
231
|
if tool_calls:
|
|
232
232
|
message_param.tool_calls = tool_calls
|
|
233
233
|
cohere_messages.append(message_param)
|
|
@@ -294,7 +294,7 @@ class CohereModel(Model):
|
|
|
294
294
|
assert_never(part)
|
|
295
295
|
|
|
296
296
|
|
|
297
|
-
def _map_usage(response:
|
|
297
|
+
def _map_usage(response: V2ChatResponse) -> usage.Usage:
|
|
298
298
|
u = response.usage
|
|
299
299
|
if u is None:
|
|
300
300
|
return usage.Usage()
|
pydantic_ai/models/function.py
CHANGED
|
@@ -264,7 +264,9 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
264
264
|
if isinstance(item, str):
|
|
265
265
|
response_tokens = _estimate_string_tokens(item)
|
|
266
266
|
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
|
|
267
|
-
|
|
267
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
|
|
268
|
+
if maybe_event is not None: # pragma: no branch
|
|
269
|
+
yield maybe_event
|
|
268
270
|
elif isinstance(item, dict) and item:
|
|
269
271
|
for dtc_index, delta in item.items():
|
|
270
272
|
if isinstance(delta, DeltaThinkingPart):
|
|
@@ -286,7 +288,7 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
286
288
|
args=delta.json_args,
|
|
287
289
|
tool_call_id=delta.tool_call_id,
|
|
288
290
|
)
|
|
289
|
-
if maybe_event is not None:
|
|
291
|
+
if maybe_event is not None: # pragma: no branch
|
|
290
292
|
yield maybe_event
|
|
291
293
|
else:
|
|
292
294
|
assert_never(delta)
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -438,7 +438,11 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
438
438
|
if 'text' in gemini_part:
|
|
439
439
|
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
|
|
440
440
|
# amongst the tool call deltas
|
|
441
|
-
|
|
441
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
442
|
+
vendor_part_id=None, content=gemini_part['text']
|
|
443
|
+
)
|
|
444
|
+
if maybe_event is not None: # pragma: no branch
|
|
445
|
+
yield maybe_event
|
|
442
446
|
|
|
443
447
|
elif 'function_call' in gemini_part:
|
|
444
448
|
# Here, we assume all function_call parts are complete and don't have deltas.
|
pydantic_ai/models/google.py
CHANGED
|
@@ -411,7 +411,12 @@ class GoogleModel(Model):
|
|
|
411
411
|
file_data_dict['video_metadata'] = item.vendor_metadata
|
|
412
412
|
content.append(file_data_dict) # type: ignore
|
|
413
413
|
elif isinstance(item, FileUrl):
|
|
414
|
-
if
|
|
414
|
+
if item.force_download or (
|
|
415
|
+
# google-gla does not support passing file urls directly, except for youtube videos
|
|
416
|
+
# (see above) and files uploaded to the file API (which cannot be downloaded anyway)
|
|
417
|
+
self.system == 'google-gla'
|
|
418
|
+
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
|
|
419
|
+
):
|
|
415
420
|
downloaded_item = await download_item(item, data_format='base64')
|
|
416
421
|
inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
|
|
417
422
|
content.append({'inline_data': inline_data}) # type: ignore
|
|
@@ -453,7 +458,9 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
453
458
|
if part.thought:
|
|
454
459
|
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
|
|
455
460
|
else:
|
|
456
|
-
|
|
461
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
|
|
462
|
+
if maybe_event is not None: # pragma: no branch
|
|
463
|
+
yield maybe_event
|
|
457
464
|
elif part.function_call:
|
|
458
465
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
459
466
|
vendor_part_id=uuid4(),
|
pydantic_ai/models/groq.py
CHANGED
|
@@ -415,7 +415,11 @@ class GroqStreamedResponse(StreamedResponse):
|
|
|
415
415
|
# Handle the text part of the response
|
|
416
416
|
content = choice.delta.content
|
|
417
417
|
if content is not None:
|
|
418
|
-
|
|
418
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
419
|
+
vendor_part_id='content', content=content, extract_think_tags=True
|
|
420
|
+
)
|
|
421
|
+
if maybe_event is not None: # pragma: no branch
|
|
422
|
+
yield maybe_event
|
|
419
423
|
|
|
420
424
|
# Handle the tool calls
|
|
421
425
|
for dtc in choice.delta.tool_calls or []:
|
|
@@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
|
|
|
444
448
|
if isinstance(completion, chat.ChatCompletion):
|
|
445
449
|
response_usage = completion.usage
|
|
446
450
|
elif completion.x_groq is not None:
|
|
447
|
-
response_usage = completion.x_groq.usage
|
|
451
|
+
response_usage = completion.x_groq.usage
|
|
448
452
|
|
|
449
453
|
if response_usage is None:
|
|
450
454
|
return usage.Usage()
|
|
@@ -426,8 +426,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
|
|
|
426
426
|
|
|
427
427
|
# Handle the text part of the response
|
|
428
428
|
content = choice.delta.content
|
|
429
|
-
if content
|
|
430
|
-
|
|
429
|
+
if content:
|
|
430
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
431
|
+
vendor_part_id='content', content=content, extract_think_tags=True
|
|
432
|
+
)
|
|
433
|
+
if maybe_event is not None: # pragma: no branch
|
|
434
|
+
yield maybe_event
|
|
431
435
|
|
|
432
436
|
for dtc in choice.delta.tool_calls or []:
|
|
433
437
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
pydantic_ai/models/mistral.py
CHANGED
|
@@ -601,7 +601,9 @@ class MistralStreamedResponse(StreamedResponse):
|
|
|
601
601
|
tool_call_id=maybe_tool_call_part.tool_call_id,
|
|
602
602
|
)
|
|
603
603
|
else:
|
|
604
|
-
|
|
604
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
|
|
605
|
+
if maybe_event is not None: # pragma: no branch
|
|
606
|
+
yield maybe_event
|
|
605
607
|
|
|
606
608
|
# Handle the explicit tool calls
|
|
607
609
|
for index, dtc in enumerate(choice.delta.tool_calls or []):
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -17,7 +17,7 @@ from pydantic_ai.providers import Provider, infer_provider
|
|
|
17
17
|
|
|
18
18
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
19
19
|
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
|
|
20
|
-
from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
|
|
20
|
+
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
|
|
21
21
|
from ..messages import (
|
|
22
22
|
AudioUrl,
|
|
23
23
|
BinaryContent,
|
|
@@ -191,7 +191,17 @@ class OpenAIModel(Model):
|
|
|
191
191
|
model_name: OpenAIModelName,
|
|
192
192
|
*,
|
|
193
193
|
provider: Literal[
|
|
194
|
-
'openai',
|
|
194
|
+
'openai',
|
|
195
|
+
'deepseek',
|
|
196
|
+
'azure',
|
|
197
|
+
'openrouter',
|
|
198
|
+
'moonshotai',
|
|
199
|
+
'vercel',
|
|
200
|
+
'grok',
|
|
201
|
+
'fireworks',
|
|
202
|
+
'together',
|
|
203
|
+
'heroku',
|
|
204
|
+
'github',
|
|
195
205
|
]
|
|
196
206
|
| Provider[AsyncOpenAI] = 'openai',
|
|
197
207
|
profile: ModelProfileSpec | None = None,
|
|
@@ -290,7 +300,10 @@ class OpenAIModel(Model):
|
|
|
290
300
|
tools = self._get_tools(model_request_parameters)
|
|
291
301
|
if not tools:
|
|
292
302
|
tool_choice: Literal['none', 'required', 'auto'] | None = None
|
|
293
|
-
elif
|
|
303
|
+
elif (
|
|
304
|
+
not model_request_parameters.allow_text_output
|
|
305
|
+
and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required
|
|
306
|
+
):
|
|
294
307
|
tool_choice = 'required'
|
|
295
308
|
else:
|
|
296
309
|
tool_choice = 'auto'
|
|
@@ -357,11 +370,17 @@ class OpenAIModel(Model):
|
|
|
357
370
|
if not isinstance(response, chat.ChatCompletion):
|
|
358
371
|
raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
|
|
359
372
|
|
|
373
|
+
if response.created:
|
|
374
|
+
timestamp = number_to_datetime(response.created)
|
|
375
|
+
else:
|
|
376
|
+
timestamp = _now_utc()
|
|
377
|
+
response.created = int(timestamp.timestamp())
|
|
378
|
+
|
|
360
379
|
try:
|
|
361
380
|
response = chat.ChatCompletion.model_validate(response.model_dump())
|
|
362
381
|
except ValidationError as e:
|
|
363
382
|
raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e
|
|
364
|
-
|
|
383
|
+
|
|
365
384
|
choice = response.choices[0]
|
|
366
385
|
items: list[ModelResponsePart] = []
|
|
367
386
|
# The `reasoning_content` is only present in DeepSeek models.
|
|
@@ -1003,8 +1022,12 @@ class OpenAIStreamedResponse(StreamedResponse):
|
|
|
1003
1022
|
|
|
1004
1023
|
# Handle the text part of the response
|
|
1005
1024
|
content = choice.delta.content
|
|
1006
|
-
if content
|
|
1007
|
-
|
|
1025
|
+
if content:
|
|
1026
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
1027
|
+
vendor_part_id='content', content=content, extract_think_tags=True
|
|
1028
|
+
)
|
|
1029
|
+
if maybe_event is not None: # pragma: no branch
|
|
1030
|
+
yield maybe_event
|
|
1008
1031
|
|
|
1009
1032
|
# Handle reasoning part of the response, present in DeepSeek models
|
|
1010
1033
|
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
|
|
@@ -1121,7 +1144,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
|
|
|
1121
1144
|
)
|
|
1122
1145
|
|
|
1123
1146
|
elif isinstance(chunk, responses.ResponseTextDeltaEvent):
|
|
1124
|
-
|
|
1147
|
+
maybe_event = self._parts_manager.handle_text_delta(
|
|
1148
|
+
vendor_part_id=chunk.content_index, content=chunk.delta
|
|
1149
|
+
)
|
|
1150
|
+
if maybe_event is not None: # pragma: no branch
|
|
1151
|
+
yield maybe_event
|
|
1125
1152
|
|
|
1126
1153
|
elif isinstance(chunk, responses.ResponseTextDoneEvent):
|
|
1127
1154
|
pass # there's nothing we need to do here
|
pydantic_ai/models/test.py
CHANGED
|
@@ -269,10 +269,14 @@ class TestStreamedResponse(StreamedResponse):
|
|
|
269
269
|
mid = len(text) // 2
|
|
270
270
|
words = [text[:mid], text[mid:]]
|
|
271
271
|
self._usage += _get_string_usage('')
|
|
272
|
-
|
|
272
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='')
|
|
273
|
+
if maybe_event is not None: # pragma: no branch
|
|
274
|
+
yield maybe_event
|
|
273
275
|
for word in words:
|
|
274
276
|
self._usage += _get_string_usage(word)
|
|
275
|
-
|
|
277
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
|
|
278
|
+
if maybe_event is not None: # pragma: no branch
|
|
279
|
+
yield maybe_event
|
|
276
280
|
elif isinstance(part, ToolCallPart):
|
|
277
281
|
yield self._parts_manager.handle_tool_call_part(
|
|
278
282
|
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
|
pydantic_ai/profiles/openai.py
CHANGED
|
@@ -21,6 +21,14 @@ class OpenAIModelProfile(ModelProfile):
|
|
|
21
21
|
openai_supports_sampling_settings: bool = True
|
|
22
22
|
"""Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models."""
|
|
23
23
|
|
|
24
|
+
# Some OpenAI-compatible providers (e.g. MoonshotAI) currently do **not** accept
|
|
25
|
+
# `tool_choice="required"`. This flag lets the calling model know whether it's
|
|
26
|
+
# safe to pass that value along. Default is `True` to preserve existing
|
|
27
|
+
# behaviour for OpenAI itself and most providers.
|
|
28
|
+
openai_supports_tool_choice_required: bool = True
|
|
29
|
+
"""Whether the provider accepts the value ``tool_choice='required'`` in the
|
|
30
|
+
request payload."""
|
|
31
|
+
|
|
24
32
|
|
|
25
33
|
def openai_model_profile(model_name: str) -> ModelProfile:
|
|
26
34
|
"""Get the model profile for an OpenAI model."""
|
|
@@ -62,6 +62,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
|
|
|
62
62
|
from .openrouter import OpenRouterProvider
|
|
63
63
|
|
|
64
64
|
return OpenRouterProvider
|
|
65
|
+
elif provider == 'vercel':
|
|
66
|
+
from .vercel import VercelProvider
|
|
67
|
+
|
|
68
|
+
return VercelProvider
|
|
65
69
|
elif provider == 'azure':
|
|
66
70
|
from .azure import AzureProvider
|
|
67
71
|
|
|
@@ -99,6 +103,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
|
|
|
99
103
|
from .grok import GrokProvider
|
|
100
104
|
|
|
101
105
|
return GrokProvider
|
|
106
|
+
elif provider == 'moonshotai':
|
|
107
|
+
from .moonshotai import MoonshotAIProvider
|
|
108
|
+
|
|
109
|
+
return MoonshotAIProvider
|
|
102
110
|
elif provider == 'fireworks':
|
|
103
111
|
from .fireworks import FireworksProvider
|
|
104
112
|
|