pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +65 -49
- pydantic_ai/_parts_manager.py +3 -1
- pydantic_ai/_tool_manager.py +33 -6
- pydantic_ai/ag_ui.py +75 -43
- pydantic_ai/agent/__init__.py +10 -7
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +718 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/durable_exec/temporal/_agent.py +71 -10
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/mcp.py +14 -26
- pydantic_ai/messages.py +90 -19
- pydantic_ai/models/__init__.py +9 -0
- pydantic_ai/models/anthropic.py +28 -11
- pydantic_ai/models/bedrock.py +6 -14
- pydantic_ai/models/gemini.py +3 -1
- pydantic_ai/models/google.py +58 -5
- pydantic_ai/models/groq.py +122 -34
- pydantic_ai/models/instrumented.py +29 -11
- pydantic_ai/models/openai.py +84 -29
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/bedrock.py +11 -3
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/retries.py +42 -2
- pydantic_ai/tools.py +18 -7
- pydantic_ai/toolsets/combined.py +2 -2
- pydantic_ai/toolsets/function.py +54 -19
- pydantic_ai/usage.py +37 -3
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/METADATA +9 -8
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/RECORD +38 -32
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/messages.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
+
import hashlib
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from collections.abc import Sequence
|
|
6
7
|
from dataclasses import KW_ONLY, dataclass, field, replace
|
|
@@ -51,6 +52,15 @@ ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
|
|
|
51
52
|
DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
|
|
52
53
|
VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
|
|
53
54
|
|
|
55
|
+
FinishReason: TypeAlias = Literal[
|
|
56
|
+
'stop',
|
|
57
|
+
'length',
|
|
58
|
+
'content_filter',
|
|
59
|
+
'tool_call',
|
|
60
|
+
'error',
|
|
61
|
+
]
|
|
62
|
+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
|
|
63
|
+
|
|
54
64
|
|
|
55
65
|
@dataclass(repr=False)
|
|
56
66
|
class SystemPromptPart:
|
|
@@ -88,6 +98,13 @@ class SystemPromptPart:
|
|
|
88
98
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
89
99
|
|
|
90
100
|
|
|
101
|
+
def _multi_modal_content_identifier(identifier: str | bytes) -> str:
|
|
102
|
+
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
|
|
103
|
+
if isinstance(identifier, str):
|
|
104
|
+
identifier = identifier.encode('utf-8')
|
|
105
|
+
return hashlib.sha1(identifier).hexdigest()[:6]
|
|
106
|
+
|
|
107
|
+
|
|
91
108
|
@dataclass(init=False, repr=False)
|
|
92
109
|
class FileUrl(ABC):
|
|
93
110
|
"""Abstract base class for any URL-based file."""
|
|
@@ -115,17 +132,31 @@ class FileUrl(ABC):
|
|
|
115
132
|
compare=False, default=None
|
|
116
133
|
)
|
|
117
134
|
|
|
135
|
+
identifier: str | None = None
|
|
136
|
+
"""The identifier of the file, such as a unique ID. generating one from the url if not explicitly set
|
|
137
|
+
|
|
138
|
+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
|
|
139
|
+
and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`.
|
|
140
|
+
|
|
141
|
+
This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool.
|
|
142
|
+
If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier,
|
|
143
|
+
e.g. "This is file <identifier>:" preceding the `FileUrl`.
|
|
144
|
+
"""
|
|
145
|
+
|
|
118
146
|
def __init__(
|
|
119
147
|
self,
|
|
120
148
|
url: str,
|
|
149
|
+
*,
|
|
121
150
|
force_download: bool = False,
|
|
122
151
|
vendor_metadata: dict[str, Any] | None = None,
|
|
123
152
|
media_type: str | None = None,
|
|
153
|
+
identifier: str | None = None,
|
|
124
154
|
) -> None:
|
|
125
155
|
self.url = url
|
|
126
|
-
self.vendor_metadata = vendor_metadata
|
|
127
156
|
self.force_download = force_download
|
|
157
|
+
self.vendor_metadata = vendor_metadata
|
|
128
158
|
self._media_type = media_type
|
|
159
|
+
self.identifier = identifier or _multi_modal_content_identifier(url)
|
|
129
160
|
|
|
130
161
|
@pydantic.computed_field
|
|
131
162
|
@property
|
|
@@ -162,11 +193,12 @@ class VideoUrl(FileUrl):
|
|
|
162
193
|
def __init__(
|
|
163
194
|
self,
|
|
164
195
|
url: str,
|
|
196
|
+
*,
|
|
165
197
|
force_download: bool = False,
|
|
166
198
|
vendor_metadata: dict[str, Any] | None = None,
|
|
167
199
|
media_type: str | None = None,
|
|
168
200
|
kind: Literal['video-url'] = 'video-url',
|
|
169
|
-
|
|
201
|
+
identifier: str | None = None,
|
|
170
202
|
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
|
|
171
203
|
_media_type: str | None = None,
|
|
172
204
|
) -> None:
|
|
@@ -175,6 +207,7 @@ class VideoUrl(FileUrl):
|
|
|
175
207
|
force_download=force_download,
|
|
176
208
|
vendor_metadata=vendor_metadata,
|
|
177
209
|
media_type=media_type or _media_type,
|
|
210
|
+
identifier=identifier,
|
|
178
211
|
)
|
|
179
212
|
self.kind = kind
|
|
180
213
|
|
|
@@ -235,11 +268,12 @@ class AudioUrl(FileUrl):
|
|
|
235
268
|
def __init__(
|
|
236
269
|
self,
|
|
237
270
|
url: str,
|
|
271
|
+
*,
|
|
238
272
|
force_download: bool = False,
|
|
239
273
|
vendor_metadata: dict[str, Any] | None = None,
|
|
240
274
|
media_type: str | None = None,
|
|
241
275
|
kind: Literal['audio-url'] = 'audio-url',
|
|
242
|
-
|
|
276
|
+
identifier: str | None = None,
|
|
243
277
|
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
|
|
244
278
|
_media_type: str | None = None,
|
|
245
279
|
) -> None:
|
|
@@ -248,6 +282,7 @@ class AudioUrl(FileUrl):
|
|
|
248
282
|
force_download=force_download,
|
|
249
283
|
vendor_metadata=vendor_metadata,
|
|
250
284
|
media_type=media_type or _media_type,
|
|
285
|
+
identifier=identifier,
|
|
251
286
|
)
|
|
252
287
|
self.kind = kind
|
|
253
288
|
|
|
@@ -295,11 +330,12 @@ class ImageUrl(FileUrl):
|
|
|
295
330
|
def __init__(
|
|
296
331
|
self,
|
|
297
332
|
url: str,
|
|
333
|
+
*,
|
|
298
334
|
force_download: bool = False,
|
|
299
335
|
vendor_metadata: dict[str, Any] | None = None,
|
|
300
336
|
media_type: str | None = None,
|
|
301
337
|
kind: Literal['image-url'] = 'image-url',
|
|
302
|
-
|
|
338
|
+
identifier: str | None = None,
|
|
303
339
|
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
|
|
304
340
|
_media_type: str | None = None,
|
|
305
341
|
) -> None:
|
|
@@ -308,6 +344,7 @@ class ImageUrl(FileUrl):
|
|
|
308
344
|
force_download=force_download,
|
|
309
345
|
vendor_metadata=vendor_metadata,
|
|
310
346
|
media_type=media_type or _media_type,
|
|
347
|
+
identifier=identifier,
|
|
311
348
|
)
|
|
312
349
|
self.kind = kind
|
|
313
350
|
|
|
@@ -350,11 +387,12 @@ class DocumentUrl(FileUrl):
|
|
|
350
387
|
def __init__(
|
|
351
388
|
self,
|
|
352
389
|
url: str,
|
|
390
|
+
*,
|
|
353
391
|
force_download: bool = False,
|
|
354
392
|
vendor_metadata: dict[str, Any] | None = None,
|
|
355
393
|
media_type: str | None = None,
|
|
356
394
|
kind: Literal['document-url'] = 'document-url',
|
|
357
|
-
|
|
395
|
+
identifier: str | None = None,
|
|
358
396
|
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
|
|
359
397
|
_media_type: str | None = None,
|
|
360
398
|
) -> None:
|
|
@@ -363,6 +401,7 @@ class DocumentUrl(FileUrl):
|
|
|
363
401
|
force_download=force_download,
|
|
364
402
|
vendor_metadata=vendor_metadata,
|
|
365
403
|
media_type=media_type or _media_type,
|
|
404
|
+
identifier=identifier,
|
|
366
405
|
)
|
|
367
406
|
self.kind = kind
|
|
368
407
|
|
|
@@ -405,24 +444,26 @@ class DocumentUrl(FileUrl):
|
|
|
405
444
|
raise ValueError(f'Unknown document media type: {media_type}') from e
|
|
406
445
|
|
|
407
446
|
|
|
408
|
-
@dataclass(repr=False)
|
|
447
|
+
@dataclass(init=False, repr=False)
|
|
409
448
|
class BinaryContent:
|
|
410
449
|
"""Binary content, e.g. an audio or image file."""
|
|
411
450
|
|
|
412
451
|
data: bytes
|
|
413
452
|
"""The binary data."""
|
|
414
453
|
|
|
415
|
-
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
|
|
416
|
-
"""The media type of the binary data."""
|
|
417
|
-
|
|
418
454
|
_: KW_ONLY
|
|
419
455
|
|
|
420
|
-
|
|
421
|
-
"""
|
|
456
|
+
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
|
|
457
|
+
"""The media type of the binary data."""
|
|
422
458
|
|
|
423
|
-
|
|
459
|
+
identifier: str
|
|
460
|
+
"""Identifier for the binary content, such as a unique ID. generating one from the data if not explicitly set
|
|
461
|
+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
|
|
462
|
+
and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
|
|
424
463
|
|
|
425
|
-
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
|
|
464
|
+
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
|
|
465
|
+
If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier,
|
|
466
|
+
e.g. "This is file <identifier>:" preceding the `BinaryContent`.
|
|
426
467
|
"""
|
|
427
468
|
|
|
428
469
|
vendor_metadata: dict[str, Any] | None = None
|
|
@@ -435,6 +476,21 @@ class BinaryContent:
|
|
|
435
476
|
kind: Literal['binary'] = 'binary'
|
|
436
477
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
437
478
|
|
|
479
|
+
def __init__(
|
|
480
|
+
self,
|
|
481
|
+
data: bytes,
|
|
482
|
+
*,
|
|
483
|
+
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
|
|
484
|
+
identifier: str | None = None,
|
|
485
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
486
|
+
kind: Literal['binary'] = 'binary',
|
|
487
|
+
) -> None:
|
|
488
|
+
self.data = data
|
|
489
|
+
self.media_type = media_type
|
|
490
|
+
self.identifier = identifier or _multi_modal_content_identifier(data)
|
|
491
|
+
self.vendor_metadata = vendor_metadata
|
|
492
|
+
self.kind = kind
|
|
493
|
+
|
|
438
494
|
@property
|
|
439
495
|
def is_audio(self) -> bool:
|
|
440
496
|
"""Return `True` if the media type is an audio type."""
|
|
@@ -786,7 +842,7 @@ ModelRequestPart = Annotated[
|
|
|
786
842
|
class ModelRequest:
|
|
787
843
|
"""A request generated by Pydantic AI and sent to a model, e.g. a message from the Pydantic AI app to the model."""
|
|
788
844
|
|
|
789
|
-
parts:
|
|
845
|
+
parts: Sequence[ModelRequestPart]
|
|
790
846
|
"""The parts of the user message."""
|
|
791
847
|
|
|
792
848
|
_: KW_ONLY
|
|
@@ -941,7 +997,7 @@ ModelResponsePart = Annotated[
|
|
|
941
997
|
class ModelResponse:
|
|
942
998
|
"""A response from a model, e.g. a message from the model to the Pydantic AI app."""
|
|
943
999
|
|
|
944
|
-
parts:
|
|
1000
|
+
parts: Sequence[ModelResponsePart]
|
|
945
1001
|
"""The parts of the model message."""
|
|
946
1002
|
|
|
947
1003
|
_: KW_ONLY
|
|
@@ -967,18 +1023,33 @@ class ModelResponse:
|
|
|
967
1023
|
provider_name: str | None = None
|
|
968
1024
|
"""The name of the LLM provider that generated the response."""
|
|
969
1025
|
|
|
970
|
-
provider_details:
|
|
1026
|
+
provider_details: Annotated[
|
|
1027
|
+
dict[str, Any] | None,
|
|
1028
|
+
# `vendor_details` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
|
|
1029
|
+
pydantic.Field(validation_alias=pydantic.AliasChoices('provider_details', 'vendor_details')),
|
|
1030
|
+
] = None
|
|
971
1031
|
"""Additional provider-specific details in a serializable format.
|
|
972
1032
|
|
|
973
1033
|
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
|
|
974
1034
|
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
|
|
975
1035
|
"""
|
|
976
1036
|
|
|
977
|
-
provider_response_id:
|
|
1037
|
+
provider_response_id: Annotated[
|
|
1038
|
+
str | None,
|
|
1039
|
+
# `vendor_id` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
|
|
1040
|
+
pydantic.Field(validation_alias=pydantic.AliasChoices('provider_response_id', 'vendor_id')),
|
|
1041
|
+
] = None
|
|
978
1042
|
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""
|
|
979
1043
|
|
|
980
|
-
|
|
981
|
-
|
|
1044
|
+
finish_reason: FinishReason | None = None
|
|
1045
|
+
"""Reason the model finished generating the response, normalized to OpenTelemetry values."""
|
|
1046
|
+
|
|
1047
|
+
@deprecated('`price` is deprecated, use `cost` instead')
|
|
1048
|
+
def price(self) -> genai_types.PriceCalculation: # pragma: no cover
|
|
1049
|
+
return self.cost()
|
|
1050
|
+
|
|
1051
|
+
def cost(self) -> genai_types.PriceCalculation:
|
|
1052
|
+
"""Calculate the cost of the usage.
|
|
982
1053
|
|
|
983
1054
|
Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
|
|
984
1055
|
"""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -28,6 +28,7 @@ from ..exceptions import UserError
|
|
|
28
28
|
from ..messages import (
|
|
29
29
|
FileUrl,
|
|
30
30
|
FinalResultEvent,
|
|
31
|
+
FinishReason,
|
|
31
32
|
ModelMessage,
|
|
32
33
|
ModelRequest,
|
|
33
34
|
ModelResponse,
|
|
@@ -555,6 +556,10 @@ class StreamedResponse(ABC):
|
|
|
555
556
|
|
|
556
557
|
final_result_event: FinalResultEvent | None = field(default=None, init=False)
|
|
557
558
|
|
|
559
|
+
provider_response_id: str | None = field(default=None, init=False)
|
|
560
|
+
provider_details: dict[str, Any] | None = field(default=None, init=False)
|
|
561
|
+
finish_reason: FinishReason | None = field(default=None, init=False)
|
|
562
|
+
|
|
558
563
|
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
|
|
559
564
|
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
|
|
560
565
|
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
|
|
@@ -609,6 +614,9 @@ class StreamedResponse(ABC):
|
|
|
609
614
|
timestamp=self.timestamp,
|
|
610
615
|
usage=self.usage(),
|
|
611
616
|
provider_name=self.provider_name,
|
|
617
|
+
provider_response_id=self.provider_response_id,
|
|
618
|
+
provider_details=self.provider_details,
|
|
619
|
+
finish_reason=self.finish_reason,
|
|
612
620
|
)
|
|
613
621
|
|
|
614
622
|
def usage(self) -> RequestUsage:
|
|
@@ -728,6 +736,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
|
|
|
728
736
|
'openrouter',
|
|
729
737
|
'together',
|
|
730
738
|
'vercel',
|
|
739
|
+
'litellm',
|
|
731
740
|
):
|
|
732
741
|
from .openai import OpenAIChatModel
|
|
733
742
|
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -21,6 +21,7 @@ from ..messages import (
|
|
|
21
21
|
BuiltinToolCallPart,
|
|
22
22
|
BuiltinToolReturnPart,
|
|
23
23
|
DocumentUrl,
|
|
24
|
+
FinishReason,
|
|
24
25
|
ImageUrl,
|
|
25
26
|
ModelMessage,
|
|
26
27
|
ModelRequest,
|
|
@@ -42,6 +43,16 @@ from ..settings import ModelSettings
|
|
|
42
43
|
from ..tools import ToolDefinition
|
|
43
44
|
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
|
|
44
45
|
|
|
46
|
+
_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
|
|
47
|
+
'end_turn': 'stop',
|
|
48
|
+
'max_tokens': 'length',
|
|
49
|
+
'stop_sequence': 'stop',
|
|
50
|
+
'tool_use': 'tool_call',
|
|
51
|
+
'pause_turn': 'stop',
|
|
52
|
+
'refusal': 'content_filter',
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
|
|
45
56
|
try:
|
|
46
57
|
from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
|
|
47
58
|
from anthropic.types.beta import (
|
|
@@ -70,6 +81,7 @@ try:
|
|
|
70
81
|
BetaServerToolUseBlock,
|
|
71
82
|
BetaServerToolUseBlockParam,
|
|
72
83
|
BetaSignatureDelta,
|
|
84
|
+
BetaStopReason,
|
|
73
85
|
BetaTextBlock,
|
|
74
86
|
BetaTextBlockParam,
|
|
75
87
|
BetaTextDelta,
|
|
@@ -326,12 +338,20 @@ class AnthropicModel(Model):
|
|
|
326
338
|
)
|
|
327
339
|
)
|
|
328
340
|
|
|
341
|
+
finish_reason: FinishReason | None = None
|
|
342
|
+
provider_details: dict[str, Any] | None = None
|
|
343
|
+
if raw_finish_reason := response.stop_reason: # pragma: no branch
|
|
344
|
+
provider_details = {'finish_reason': raw_finish_reason}
|
|
345
|
+
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
346
|
+
|
|
329
347
|
return ModelResponse(
|
|
330
348
|
parts=items,
|
|
331
349
|
usage=_map_usage(response),
|
|
332
350
|
model_name=response.model,
|
|
333
351
|
provider_response_id=response.id,
|
|
334
352
|
provider_name=self._provider.name,
|
|
353
|
+
finish_reason=finish_reason,
|
|
354
|
+
provider_details=provider_details,
|
|
335
355
|
)
|
|
336
356
|
|
|
337
357
|
async def _process_streamed_response(
|
|
@@ -536,7 +556,7 @@ class AnthropicModel(Model):
|
|
|
536
556
|
}
|
|
537
557
|
|
|
538
558
|
|
|
539
|
-
def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
|
|
559
|
+
def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
|
|
540
560
|
if isinstance(message, BetaMessage):
|
|
541
561
|
response_usage = message.usage
|
|
542
562
|
elif isinstance(message, BetaRawMessageStartEvent):
|
|
@@ -544,12 +564,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
|
|
|
544
564
|
elif isinstance(message, BetaRawMessageDeltaEvent):
|
|
545
565
|
response_usage = message.usage
|
|
546
566
|
else:
|
|
547
|
-
|
|
548
|
-
# - RawMessageStopEvent
|
|
549
|
-
# - RawContentBlockStartEvent
|
|
550
|
-
# - RawContentBlockDeltaEvent
|
|
551
|
-
# - RawContentBlockStopEvent
|
|
552
|
-
return usage.RequestUsage()
|
|
567
|
+
assert_never(message)
|
|
553
568
|
|
|
554
569
|
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
|
|
555
570
|
# `response_tokens`
|
|
@@ -586,10 +601,9 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
586
601
|
current_block: BetaContentBlock | None = None
|
|
587
602
|
|
|
588
603
|
async for event in self._response:
|
|
589
|
-
self._usage += _map_usage(event)
|
|
590
|
-
|
|
591
604
|
if isinstance(event, BetaRawMessageStartEvent):
|
|
592
|
-
|
|
605
|
+
self._usage = _map_usage(event)
|
|
606
|
+
self.provider_response_id = event.message.id
|
|
593
607
|
|
|
594
608
|
elif isinstance(event, BetaRawContentBlockStartEvent):
|
|
595
609
|
current_block = event.content_block
|
|
@@ -652,7 +666,10 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
652
666
|
pass
|
|
653
667
|
|
|
654
668
|
elif isinstance(event, BetaRawMessageDeltaEvent):
|
|
655
|
-
|
|
669
|
+
self._usage = _map_usage(event)
|
|
670
|
+
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
|
|
671
|
+
self.provider_details = {'finish_reason': raw_finish_reason}
|
|
672
|
+
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
656
673
|
|
|
657
674
|
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
|
|
658
675
|
current_block = None
|
pydantic_ai/models/bedrock.py
CHANGED
|
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
4
|
import typing
|
|
5
|
-
import warnings
|
|
6
5
|
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
|
|
7
6
|
from contextlib import asynccontextmanager
|
|
8
7
|
from dataclasses import dataclass, field
|
|
@@ -601,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
601
600
|
_provider_name: str
|
|
602
601
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
603
602
|
|
|
604
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
603
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
605
604
|
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
606
605
|
|
|
607
606
|
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
@@ -638,18 +637,11 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
638
637
|
index = content_block_delta['contentBlockIndex']
|
|
639
638
|
delta = content_block_delta['delta']
|
|
640
639
|
if 'reasoningContent' in delta:
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
)
|
|
647
|
-
else: # pragma: no cover
|
|
648
|
-
warnings.warn(
|
|
649
|
-
f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
|
|
650
|
-
'Please report this to the maintainers.',
|
|
651
|
-
UserWarning,
|
|
652
|
-
)
|
|
640
|
+
yield self._parts_manager.handle_thinking_delta(
|
|
641
|
+
vendor_part_id=index,
|
|
642
|
+
content=delta['reasoningContent'].get('text'),
|
|
643
|
+
signature=delta['reasoningContent'].get('signature'),
|
|
644
|
+
)
|
|
653
645
|
if 'text' in delta:
|
|
654
646
|
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
|
|
655
647
|
if maybe_event is not None: # pragma: no branch
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -211,7 +211,9 @@ class GeminiModel(Model):
|
|
|
211
211
|
generation_config = _settings_to_generation_config(model_settings)
|
|
212
212
|
if model_request_parameters.output_mode == 'native':
|
|
213
213
|
if tools:
|
|
214
|
-
raise UserError(
|
|
214
|
+
raise UserError(
|
|
215
|
+
'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
|
|
216
|
+
)
|
|
215
217
|
|
|
216
218
|
generation_config['response_mime_type'] = 'application/json'
|
|
217
219
|
|
pydantic_ai/models/google.py
CHANGED
|
@@ -20,6 +20,7 @@ from ..messages import (
|
|
|
20
20
|
BuiltinToolCallPart,
|
|
21
21
|
BuiltinToolReturnPart,
|
|
22
22
|
FileUrl,
|
|
23
|
+
FinishReason,
|
|
23
24
|
ModelMessage,
|
|
24
25
|
ModelRequest,
|
|
25
26
|
ModelResponse,
|
|
@@ -54,6 +55,7 @@ try:
|
|
|
54
55
|
ContentUnionDict,
|
|
55
56
|
CountTokensConfigDict,
|
|
56
57
|
ExecutableCodeDict,
|
|
58
|
+
FinishReason as GoogleFinishReason,
|
|
57
59
|
FunctionCallDict,
|
|
58
60
|
FunctionCallingConfigDict,
|
|
59
61
|
FunctionCallingConfigMode,
|
|
@@ -99,6 +101,22 @@ allow any name in the type hints.
|
|
|
99
101
|
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
|
|
100
102
|
"""
|
|
101
103
|
|
|
104
|
+
_FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
|
|
105
|
+
GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
|
|
106
|
+
GoogleFinishReason.STOP: 'stop',
|
|
107
|
+
GoogleFinishReason.MAX_TOKENS: 'length',
|
|
108
|
+
GoogleFinishReason.SAFETY: 'content_filter',
|
|
109
|
+
GoogleFinishReason.RECITATION: 'content_filter',
|
|
110
|
+
GoogleFinishReason.LANGUAGE: 'error',
|
|
111
|
+
GoogleFinishReason.OTHER: None,
|
|
112
|
+
GoogleFinishReason.BLOCKLIST: 'content_filter',
|
|
113
|
+
GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
|
|
114
|
+
GoogleFinishReason.SPII: 'content_filter',
|
|
115
|
+
GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
|
|
116
|
+
GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
|
|
117
|
+
GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
|
|
118
|
+
}
|
|
119
|
+
|
|
102
120
|
|
|
103
121
|
class GoogleModelSettings(ModelSettings, total=False):
|
|
104
122
|
"""Settings used for a Gemini model request."""
|
|
@@ -129,6 +147,12 @@ class GoogleModelSettings(ModelSettings, total=False):
|
|
|
129
147
|
See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
|
|
130
148
|
"""
|
|
131
149
|
|
|
150
|
+
google_cached_content: str
|
|
151
|
+
"""The name of the cached content to use for the model.
|
|
152
|
+
|
|
153
|
+
See <https://ai.google.dev/gemini-api/docs/caching> for more information.
|
|
154
|
+
"""
|
|
155
|
+
|
|
132
156
|
|
|
133
157
|
@dataclass(init=False)
|
|
134
158
|
class GoogleModel(Model):
|
|
@@ -264,6 +288,14 @@ class GoogleModel(Model):
|
|
|
264
288
|
yield await self._process_streamed_response(response, model_request_parameters) # type: ignore
|
|
265
289
|
|
|
266
290
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
|
|
291
|
+
if model_request_parameters.builtin_tools:
|
|
292
|
+
if model_request_parameters.output_tools:
|
|
293
|
+
raise UserError(
|
|
294
|
+
'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
|
|
295
|
+
)
|
|
296
|
+
if model_request_parameters.function_tools:
|
|
297
|
+
raise UserError('Gemini does not support user tools and built-in tools at the same time.')
|
|
298
|
+
|
|
267
299
|
tools: list[ToolDict] = [
|
|
268
300
|
ToolDict(function_declarations=[_function_declaration_from_tool(t)])
|
|
269
301
|
for t in model_request_parameters.tool_defs.values()
|
|
@@ -334,7 +366,9 @@ class GoogleModel(Model):
|
|
|
334
366
|
response_schema = None
|
|
335
367
|
if model_request_parameters.output_mode == 'native':
|
|
336
368
|
if tools:
|
|
337
|
-
raise UserError(
|
|
369
|
+
raise UserError(
|
|
370
|
+
'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
|
|
371
|
+
)
|
|
338
372
|
response_mime_type = 'application/json'
|
|
339
373
|
output_object = model_request_parameters.output_object
|
|
340
374
|
assert output_object is not None
|
|
@@ -367,6 +401,7 @@ class GoogleModel(Model):
|
|
|
367
401
|
thinking_config=model_settings.get('google_thinking_config'),
|
|
368
402
|
labels=model_settings.get('google_labels'),
|
|
369
403
|
media_resolution=model_settings.get('google_video_resolution'),
|
|
404
|
+
cached_content=model_settings.get('google_cached_content'),
|
|
370
405
|
tools=cast(ToolListUnionDict, tools),
|
|
371
406
|
tool_config=tool_config,
|
|
372
407
|
response_mime_type=response_mime_type,
|
|
@@ -386,11 +421,14 @@ class GoogleModel(Model):
|
|
|
386
421
|
'Content field missing from Gemini response', str(response)
|
|
387
422
|
) # pragma: no cover
|
|
388
423
|
parts = candidate.content.parts or []
|
|
389
|
-
|
|
424
|
+
|
|
425
|
+
vendor_id = response.response_id
|
|
390
426
|
vendor_details: dict[str, Any] | None = None
|
|
391
|
-
finish_reason =
|
|
392
|
-
if finish_reason: # pragma: no branch
|
|
393
|
-
vendor_details = {'finish_reason':
|
|
427
|
+
finish_reason: FinishReason | None = None
|
|
428
|
+
if raw_finish_reason := candidate.finish_reason: # pragma: no branch
|
|
429
|
+
vendor_details = {'finish_reason': raw_finish_reason.value}
|
|
430
|
+
finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
431
|
+
|
|
394
432
|
usage = _metadata_as_usage(response)
|
|
395
433
|
return _process_response_from_parts(
|
|
396
434
|
parts,
|
|
@@ -399,6 +437,7 @@ class GoogleModel(Model):
|
|
|
399
437
|
usage,
|
|
400
438
|
vendor_id=vendor_id,
|
|
401
439
|
vendor_details=vendor_details,
|
|
440
|
+
finish_reason=finish_reason,
|
|
402
441
|
)
|
|
403
442
|
|
|
404
443
|
async def _process_streamed_response(
|
|
@@ -533,6 +572,14 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
533
572
|
|
|
534
573
|
assert chunk.candidates is not None
|
|
535
574
|
candidate = chunk.candidates[0]
|
|
575
|
+
|
|
576
|
+
if chunk.response_id: # pragma: no branch
|
|
577
|
+
self.provider_response_id = chunk.response_id
|
|
578
|
+
|
|
579
|
+
if raw_finish_reason := candidate.finish_reason:
|
|
580
|
+
self.provider_details = {'finish_reason': raw_finish_reason.value}
|
|
581
|
+
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
582
|
+
|
|
536
583
|
if candidate.content is None or candidate.content.parts is None:
|
|
537
584
|
if candidate.finish_reason == 'STOP': # pragma: no cover
|
|
538
585
|
# Normal completion - skip this chunk
|
|
@@ -559,6 +606,10 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
559
606
|
)
|
|
560
607
|
if maybe_event is not None: # pragma: no branch
|
|
561
608
|
yield maybe_event
|
|
609
|
+
elif part.executable_code is not None:
|
|
610
|
+
pass
|
|
611
|
+
elif part.code_execution_result is not None:
|
|
612
|
+
pass
|
|
562
613
|
else:
|
|
563
614
|
assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
|
|
564
615
|
|
|
@@ -611,6 +662,7 @@ def _process_response_from_parts(
|
|
|
611
662
|
usage: usage.RequestUsage,
|
|
612
663
|
vendor_id: str | None,
|
|
613
664
|
vendor_details: dict[str, Any] | None = None,
|
|
665
|
+
finish_reason: FinishReason | None = None,
|
|
614
666
|
) -> ModelResponse:
|
|
615
667
|
items: list[ModelResponsePart] = []
|
|
616
668
|
for part in parts:
|
|
@@ -651,6 +703,7 @@ def _process_response_from_parts(
|
|
|
651
703
|
provider_response_id=vendor_id,
|
|
652
704
|
provider_details=vendor_details,
|
|
653
705
|
provider_name=provider_name,
|
|
706
|
+
finish_reason=finish_reason,
|
|
654
707
|
)
|
|
655
708
|
|
|
656
709
|
|