pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (38)
  1. pydantic_ai/_a2a.py +1 -1
  2. pydantic_ai/_agent_graph.py +65 -49
  3. pydantic_ai/_parts_manager.py +3 -1
  4. pydantic_ai/_tool_manager.py +33 -6
  5. pydantic_ai/ag_ui.py +75 -43
  6. pydantic_ai/agent/__init__.py +10 -7
  7. pydantic_ai/durable_exec/dbos/__init__.py +6 -0
  8. pydantic_ai/durable_exec/dbos/_agent.py +718 -0
  9. pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
  10. pydantic_ai/durable_exec/dbos/_model.py +137 -0
  11. pydantic_ai/durable_exec/dbos/_utils.py +10 -0
  12. pydantic_ai/durable_exec/temporal/_agent.py +71 -10
  13. pydantic_ai/exceptions.py +2 -2
  14. pydantic_ai/mcp.py +14 -26
  15. pydantic_ai/messages.py +90 -19
  16. pydantic_ai/models/__init__.py +9 -0
  17. pydantic_ai/models/anthropic.py +28 -11
  18. pydantic_ai/models/bedrock.py +6 -14
  19. pydantic_ai/models/gemini.py +3 -1
  20. pydantic_ai/models/google.py +58 -5
  21. pydantic_ai/models/groq.py +122 -34
  22. pydantic_ai/models/instrumented.py +29 -11
  23. pydantic_ai/models/openai.py +84 -29
  24. pydantic_ai/providers/__init__.py +4 -0
  25. pydantic_ai/providers/bedrock.py +11 -3
  26. pydantic_ai/providers/google_vertex.py +2 -1
  27. pydantic_ai/providers/groq.py +21 -2
  28. pydantic_ai/providers/litellm.py +134 -0
  29. pydantic_ai/retries.py +42 -2
  30. pydantic_ai/tools.py +18 -7
  31. pydantic_ai/toolsets/combined.py +2 -2
  32. pydantic_ai/toolsets/function.py +54 -19
  33. pydantic_ai/usage.py +37 -3
  34. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/METADATA +9 -8
  35. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/RECORD +38 -32
  36. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/WHEEL +0 -0
  37. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/entry_points.txt +0 -0
  38. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/messages.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import base64
4
+ import hashlib
4
5
  from abc import ABC, abstractmethod
5
6
  from collections.abc import Sequence
6
7
  from dataclasses import KW_ONLY, dataclass, field, replace
@@ -51,6 +52,15 @@ ImageFormat: TypeAlias = Literal['jpeg', 'png', 'gif', 'webp']
51
52
  DocumentFormat: TypeAlias = Literal['csv', 'doc', 'docx', 'html', 'md', 'pdf', 'txt', 'xls', 'xlsx']
52
53
  VideoFormat: TypeAlias = Literal['mkv', 'mov', 'mp4', 'webm', 'flv', 'mpeg', 'mpg', 'wmv', 'three_gp']
53
54
 
55
+ FinishReason: TypeAlias = Literal[
56
+ 'stop',
57
+ 'length',
58
+ 'content_filter',
59
+ 'tool_call',
60
+ 'error',
61
+ ]
62
+ """Reason the model finished generating the response, normalized to OpenTelemetry values."""
63
+
54
64
 
55
65
  @dataclass(repr=False)
56
66
  class SystemPromptPart:
@@ -88,6 +98,13 @@ class SystemPromptPart:
88
98
  __repr__ = _utils.dataclasses_no_defaults_repr
89
99
 
90
100
 
101
+ def _multi_modal_content_identifier(identifier: str | bytes) -> str:
102
+ """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
103
+ if isinstance(identifier, str):
104
+ identifier = identifier.encode('utf-8')
105
+ return hashlib.sha1(identifier).hexdigest()[:6]
106
+
107
+
91
108
  @dataclass(init=False, repr=False)
92
109
  class FileUrl(ABC):
93
110
  """Abstract base class for any URL-based file."""
@@ -115,17 +132,31 @@ class FileUrl(ABC):
115
132
  compare=False, default=None
116
133
  )
117
134
 
135
+ identifier: str | None = None
136
+ """The identifier of the file, such as a unique ID; one is generated from the URL if not explicitly set.
137
+
138
+ This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
139
+ and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`.
140
+
141
+ This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool.
142
+ If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier,
143
+ e.g. "This is file <identifier>:" preceding the `FileUrl`.
144
+ """
145
+
118
146
  def __init__(
119
147
  self,
120
148
  url: str,
149
+ *,
121
150
  force_download: bool = False,
122
151
  vendor_metadata: dict[str, Any] | None = None,
123
152
  media_type: str | None = None,
153
+ identifier: str | None = None,
124
154
  ) -> None:
125
155
  self.url = url
126
- self.vendor_metadata = vendor_metadata
127
156
  self.force_download = force_download
157
+ self.vendor_metadata = vendor_metadata
128
158
  self._media_type = media_type
159
+ self.identifier = identifier or _multi_modal_content_identifier(url)
129
160
 
130
161
  @pydantic.computed_field
131
162
  @property
@@ -162,11 +193,12 @@ class VideoUrl(FileUrl):
162
193
  def __init__(
163
194
  self,
164
195
  url: str,
196
+ *,
165
197
  force_download: bool = False,
166
198
  vendor_metadata: dict[str, Any] | None = None,
167
199
  media_type: str | None = None,
168
200
  kind: Literal['video-url'] = 'video-url',
169
- *,
201
+ identifier: str | None = None,
170
202
  # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
171
203
  _media_type: str | None = None,
172
204
  ) -> None:
@@ -175,6 +207,7 @@ class VideoUrl(FileUrl):
175
207
  force_download=force_download,
176
208
  vendor_metadata=vendor_metadata,
177
209
  media_type=media_type or _media_type,
210
+ identifier=identifier,
178
211
  )
179
212
  self.kind = kind
180
213
 
@@ -235,11 +268,12 @@ class AudioUrl(FileUrl):
235
268
  def __init__(
236
269
  self,
237
270
  url: str,
271
+ *,
238
272
  force_download: bool = False,
239
273
  vendor_metadata: dict[str, Any] | None = None,
240
274
  media_type: str | None = None,
241
275
  kind: Literal['audio-url'] = 'audio-url',
242
- *,
276
+ identifier: str | None = None,
243
277
  # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
244
278
  _media_type: str | None = None,
245
279
  ) -> None:
@@ -248,6 +282,7 @@ class AudioUrl(FileUrl):
248
282
  force_download=force_download,
249
283
  vendor_metadata=vendor_metadata,
250
284
  media_type=media_type or _media_type,
285
+ identifier=identifier,
251
286
  )
252
287
  self.kind = kind
253
288
 
@@ -295,11 +330,12 @@ class ImageUrl(FileUrl):
295
330
  def __init__(
296
331
  self,
297
332
  url: str,
333
+ *,
298
334
  force_download: bool = False,
299
335
  vendor_metadata: dict[str, Any] | None = None,
300
336
  media_type: str | None = None,
301
337
  kind: Literal['image-url'] = 'image-url',
302
- *,
338
+ identifier: str | None = None,
303
339
  # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
304
340
  _media_type: str | None = None,
305
341
  ) -> None:
@@ -308,6 +344,7 @@ class ImageUrl(FileUrl):
308
344
  force_download=force_download,
309
345
  vendor_metadata=vendor_metadata,
310
346
  media_type=media_type or _media_type,
347
+ identifier=identifier,
311
348
  )
312
349
  self.kind = kind
313
350
 
@@ -350,11 +387,12 @@ class DocumentUrl(FileUrl):
350
387
  def __init__(
351
388
  self,
352
389
  url: str,
390
+ *,
353
391
  force_download: bool = False,
354
392
  vendor_metadata: dict[str, Any] | None = None,
355
393
  media_type: str | None = None,
356
394
  kind: Literal['document-url'] = 'document-url',
357
- *,
395
+ identifier: str | None = None,
358
396
  # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
359
397
  _media_type: str | None = None,
360
398
  ) -> None:
@@ -363,6 +401,7 @@ class DocumentUrl(FileUrl):
363
401
  force_download=force_download,
364
402
  vendor_metadata=vendor_metadata,
365
403
  media_type=media_type or _media_type,
404
+ identifier=identifier,
366
405
  )
367
406
  self.kind = kind
368
407
 
@@ -405,24 +444,26 @@ class DocumentUrl(FileUrl):
405
444
  raise ValueError(f'Unknown document media type: {media_type}') from e
406
445
 
407
446
 
408
- @dataclass(repr=False)
447
+ @dataclass(init=False, repr=False)
409
448
  class BinaryContent:
410
449
  """Binary content, e.g. an audio or image file."""
411
450
 
412
451
  data: bytes
413
452
  """The binary data."""
414
453
 
415
- media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
416
- """The media type of the binary data."""
417
-
418
454
  _: KW_ONLY
419
455
 
420
- identifier: str | None = None
421
- """Identifier for the binary content, such as a URL or unique ID.
456
+ media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
457
+ """The media type of the binary data."""
422
458
 
423
- This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
459
+ identifier: str
460
+ """Identifier for the binary content, such as a unique ID; one is generated from the data if not explicitly set.
461
+ This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
462
+ and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
424
463
 
425
- This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
464
+ This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
465
+ If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier,
466
+ e.g. "This is file <identifier>:" preceding the `BinaryContent`.
426
467
  """
427
468
 
428
469
  vendor_metadata: dict[str, Any] | None = None
@@ -435,6 +476,21 @@ class BinaryContent:
435
476
  kind: Literal['binary'] = 'binary'
436
477
  """Type identifier, this is available on all parts as a discriminator."""
437
478
 
479
+ def __init__(
480
+ self,
481
+ data: bytes,
482
+ *,
483
+ media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
484
+ identifier: str | None = None,
485
+ vendor_metadata: dict[str, Any] | None = None,
486
+ kind: Literal['binary'] = 'binary',
487
+ ) -> None:
488
+ self.data = data
489
+ self.media_type = media_type
490
+ self.identifier = identifier or _multi_modal_content_identifier(data)
491
+ self.vendor_metadata = vendor_metadata
492
+ self.kind = kind
493
+
438
494
  @property
439
495
  def is_audio(self) -> bool:
440
496
  """Return `True` if the media type is an audio type."""
@@ -786,7 +842,7 @@ ModelRequestPart = Annotated[
786
842
  class ModelRequest:
787
843
  """A request generated by Pydantic AI and sent to a model, e.g. a message from the Pydantic AI app to the model."""
788
844
 
789
- parts: list[ModelRequestPart]
845
+ parts: Sequence[ModelRequestPart]
790
846
  """The parts of the user message."""
791
847
 
792
848
  _: KW_ONLY
@@ -941,7 +997,7 @@ ModelResponsePart = Annotated[
941
997
  class ModelResponse:
942
998
  """A response from a model, e.g. a message from the model to the Pydantic AI app."""
943
999
 
944
- parts: list[ModelResponsePart]
1000
+ parts: Sequence[ModelResponsePart]
945
1001
  """The parts of the model message."""
946
1002
 
947
1003
  _: KW_ONLY
@@ -967,18 +1023,33 @@ class ModelResponse:
967
1023
  provider_name: str | None = None
968
1024
  """The name of the LLM provider that generated the response."""
969
1025
 
970
- provider_details: dict[str, Any] | None = field(default=None)
1026
+ provider_details: Annotated[
1027
+ dict[str, Any] | None,
1028
+ # `vendor_details` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
1029
+ pydantic.Field(validation_alias=pydantic.AliasChoices('provider_details', 'vendor_details')),
1030
+ ] = None
971
1031
  """Additional provider-specific details in a serializable format.
972
1032
 
973
1033
  This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
974
1034
  For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
975
1035
  """
976
1036
 
977
- provider_response_id: str | None = None
1037
+ provider_response_id: Annotated[
1038
+ str | None,
1039
+ # `vendor_id` is deprecated, but we still want to support deserializing model responses stored in a DB before the name was changed
1040
+ pydantic.Field(validation_alias=pydantic.AliasChoices('provider_response_id', 'vendor_id')),
1041
+ ] = None
978
1042
  """request ID as specified by the model provider. This can be used to track the specific request to the model."""
979
1043
 
980
- def price(self) -> genai_types.PriceCalculation:
981
- """Calculate the price of the usage.
1044
+ finish_reason: FinishReason | None = None
1045
+ """Reason the model finished generating the response, normalized to OpenTelemetry values."""
1046
+
1047
+ @deprecated('`price` is deprecated, use `cost` instead')
1048
+ def price(self) -> genai_types.PriceCalculation: # pragma: no cover
1049
+ return self.cost()
1050
+
1051
+ def cost(self) -> genai_types.PriceCalculation:
1052
+ """Calculate the cost of the usage.
982
1053
 
983
1054
  Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
984
1055
  """
@@ -28,6 +28,7 @@ from ..exceptions import UserError
28
28
  from ..messages import (
29
29
  FileUrl,
30
30
  FinalResultEvent,
31
+ FinishReason,
31
32
  ModelMessage,
32
33
  ModelRequest,
33
34
  ModelResponse,
@@ -555,6 +556,10 @@ class StreamedResponse(ABC):
555
556
 
556
557
  final_result_event: FinalResultEvent | None = field(default=None, init=False)
557
558
 
559
+ provider_response_id: str | None = field(default=None, init=False)
560
+ provider_details: dict[str, Any] | None = field(default=None, init=False)
561
+ finish_reason: FinishReason | None = field(default=None, init=False)
562
+
558
563
  _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
559
564
  _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
560
565
  _usage: RequestUsage = field(default_factory=RequestUsage, init=False)
@@ -609,6 +614,9 @@ class StreamedResponse(ABC):
609
614
  timestamp=self.timestamp,
610
615
  usage=self.usage(),
611
616
  provider_name=self.provider_name,
617
+ provider_response_id=self.provider_response_id,
618
+ provider_details=self.provider_details,
619
+ finish_reason=self.finish_reason,
612
620
  )
613
621
 
614
622
  def usage(self) -> RequestUsage:
@@ -728,6 +736,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
728
736
  'openrouter',
729
737
  'together',
730
738
  'vercel',
739
+ 'litellm',
731
740
  ):
732
741
  from .openai import OpenAIChatModel
733
742
 
@@ -21,6 +21,7 @@ from ..messages import (
21
21
  BuiltinToolCallPart,
22
22
  BuiltinToolReturnPart,
23
23
  DocumentUrl,
24
+ FinishReason,
24
25
  ImageUrl,
25
26
  ModelMessage,
26
27
  ModelRequest,
@@ -42,6 +43,16 @@ from ..settings import ModelSettings
42
43
  from ..tools import ToolDefinition
43
44
  from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
44
45
 
46
+ _FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
47
+ 'end_turn': 'stop',
48
+ 'max_tokens': 'length',
49
+ 'stop_sequence': 'stop',
50
+ 'tool_use': 'tool_call',
51
+ 'pause_turn': 'stop',
52
+ 'refusal': 'content_filter',
53
+ }
54
+
55
+
45
56
  try:
46
57
  from anthropic import NOT_GIVEN, APIStatusError, AsyncStream
47
58
  from anthropic.types.beta import (
@@ -70,6 +81,7 @@ try:
70
81
  BetaServerToolUseBlock,
71
82
  BetaServerToolUseBlockParam,
72
83
  BetaSignatureDelta,
84
+ BetaStopReason,
73
85
  BetaTextBlock,
74
86
  BetaTextBlockParam,
75
87
  BetaTextDelta,
@@ -326,12 +338,20 @@ class AnthropicModel(Model):
326
338
  )
327
339
  )
328
340
 
341
+ finish_reason: FinishReason | None = None
342
+ provider_details: dict[str, Any] | None = None
343
+ if raw_finish_reason := response.stop_reason: # pragma: no branch
344
+ provider_details = {'finish_reason': raw_finish_reason}
345
+ finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
346
+
329
347
  return ModelResponse(
330
348
  parts=items,
331
349
  usage=_map_usage(response),
332
350
  model_name=response.model,
333
351
  provider_response_id=response.id,
334
352
  provider_name=self._provider.name,
353
+ finish_reason=finish_reason,
354
+ provider_details=provider_details,
335
355
  )
336
356
 
337
357
  async def _process_streamed_response(
@@ -536,7 +556,7 @@ class AnthropicModel(Model):
536
556
  }
537
557
 
538
558
 
539
- def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
559
+ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
540
560
  if isinstance(message, BetaMessage):
541
561
  response_usage = message.usage
542
562
  elif isinstance(message, BetaRawMessageStartEvent):
@@ -544,12 +564,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
544
564
  elif isinstance(message, BetaRawMessageDeltaEvent):
545
565
  response_usage = message.usage
546
566
  else:
547
- # No usage information provided in:
548
- # - RawMessageStopEvent
549
- # - RawContentBlockStartEvent
550
- # - RawContentBlockDeltaEvent
551
- # - RawContentBlockStopEvent
552
- return usage.RequestUsage()
567
+ assert_never(message)
553
568
 
554
569
  # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
555
570
  # `response_tokens`
@@ -586,10 +601,9 @@ class AnthropicStreamedResponse(StreamedResponse):
586
601
  current_block: BetaContentBlock | None = None
587
602
 
588
603
  async for event in self._response:
589
- self._usage += _map_usage(event)
590
-
591
604
  if isinstance(event, BetaRawMessageStartEvent):
592
- pass
605
+ self._usage = _map_usage(event)
606
+ self.provider_response_id = event.message.id
593
607
 
594
608
  elif isinstance(event, BetaRawContentBlockStartEvent):
595
609
  current_block = event.content_block
@@ -652,7 +666,10 @@ class AnthropicStreamedResponse(StreamedResponse):
652
666
  pass
653
667
 
654
668
  elif isinstance(event, BetaRawMessageDeltaEvent):
655
- pass
669
+ self._usage = _map_usage(event)
670
+ if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
671
+ self.provider_details = {'finish_reason': raw_finish_reason}
672
+ self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
656
673
 
657
674
  elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658
675
  current_block = None
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import typing
5
- import warnings
6
5
  from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
7
6
  from contextlib import asynccontextmanager
8
7
  from dataclasses import dataclass, field
@@ -601,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
601
600
  _provider_name: str
602
601
  _timestamp: datetime = field(default_factory=_utils.now_utc)
603
602
 
604
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
603
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
605
604
  """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
606
605
 
607
606
  This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -638,18 +637,11 @@ class BedrockStreamedResponse(StreamedResponse):
638
637
  index = content_block_delta['contentBlockIndex']
639
638
  delta = content_block_delta['delta']
640
639
  if 'reasoningContent' in delta:
641
- if text := delta['reasoningContent'].get('text'):
642
- yield self._parts_manager.handle_thinking_delta(
643
- vendor_part_id=index,
644
- content=text,
645
- signature=delta['reasoningContent'].get('signature'),
646
- )
647
- else: # pragma: no cover
648
- warnings.warn(
649
- f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
650
- 'Please report this to the maintainers.',
651
- UserWarning,
652
- )
640
+ yield self._parts_manager.handle_thinking_delta(
641
+ vendor_part_id=index,
642
+ content=delta['reasoningContent'].get('text'),
643
+ signature=delta['reasoningContent'].get('signature'),
644
+ )
653
645
  if 'text' in delta:
654
646
  maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
655
647
  if maybe_event is not None: # pragma: no branch
@@ -211,7 +211,9 @@ class GeminiModel(Model):
211
211
  generation_config = _settings_to_generation_config(model_settings)
212
212
  if model_request_parameters.output_mode == 'native':
213
213
  if tools:
214
- raise UserError('Gemini does not support structured output and tools at the same time.')
214
+ raise UserError(
215
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
216
+ )
215
217
 
216
218
  generation_config['response_mime_type'] = 'application/json'
217
219
 
@@ -20,6 +20,7 @@ from ..messages import (
20
20
  BuiltinToolCallPart,
21
21
  BuiltinToolReturnPart,
22
22
  FileUrl,
23
+ FinishReason,
23
24
  ModelMessage,
24
25
  ModelRequest,
25
26
  ModelResponse,
@@ -54,6 +55,7 @@ try:
54
55
  ContentUnionDict,
55
56
  CountTokensConfigDict,
56
57
  ExecutableCodeDict,
58
+ FinishReason as GoogleFinishReason,
57
59
  FunctionCallDict,
58
60
  FunctionCallingConfigDict,
59
61
  FunctionCallingConfigMode,
@@ -99,6 +101,22 @@ allow any name in the type hints.
99
101
  See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
100
102
  """
101
103
 
104
+ _FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
105
+ GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
106
+ GoogleFinishReason.STOP: 'stop',
107
+ GoogleFinishReason.MAX_TOKENS: 'length',
108
+ GoogleFinishReason.SAFETY: 'content_filter',
109
+ GoogleFinishReason.RECITATION: 'content_filter',
110
+ GoogleFinishReason.LANGUAGE: 'error',
111
+ GoogleFinishReason.OTHER: None,
112
+ GoogleFinishReason.BLOCKLIST: 'content_filter',
113
+ GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
114
+ GoogleFinishReason.SPII: 'content_filter',
115
+ GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
116
+ GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
117
+ GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
118
+ }
119
+
102
120
 
103
121
  class GoogleModelSettings(ModelSettings, total=False):
104
122
  """Settings used for a Gemini model request."""
@@ -129,6 +147,12 @@ class GoogleModelSettings(ModelSettings, total=False):
129
147
  See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
130
148
  """
131
149
 
150
+ google_cached_content: str
151
+ """The name of the cached content to use for the model.
152
+
153
+ See <https://ai.google.dev/gemini-api/docs/caching> for more information.
154
+ """
155
+
132
156
 
133
157
  @dataclass(init=False)
134
158
  class GoogleModel(Model):
@@ -264,6 +288,14 @@ class GoogleModel(Model):
264
288
  yield await self._process_streamed_response(response, model_request_parameters) # type: ignore
265
289
 
266
290
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
291
+ if model_request_parameters.builtin_tools:
292
+ if model_request_parameters.output_tools:
293
+ raise UserError(
294
+ 'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
295
+ )
296
+ if model_request_parameters.function_tools:
297
+ raise UserError('Gemini does not support user tools and built-in tools at the same time.')
298
+
267
299
  tools: list[ToolDict] = [
268
300
  ToolDict(function_declarations=[_function_declaration_from_tool(t)])
269
301
  for t in model_request_parameters.tool_defs.values()
@@ -334,7 +366,9 @@ class GoogleModel(Model):
334
366
  response_schema = None
335
367
  if model_request_parameters.output_mode == 'native':
336
368
  if tools:
337
- raise UserError('Gemini does not support structured output and tools at the same time.')
369
+ raise UserError(
370
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
371
+ )
338
372
  response_mime_type = 'application/json'
339
373
  output_object = model_request_parameters.output_object
340
374
  assert output_object is not None
@@ -367,6 +401,7 @@ class GoogleModel(Model):
367
401
  thinking_config=model_settings.get('google_thinking_config'),
368
402
  labels=model_settings.get('google_labels'),
369
403
  media_resolution=model_settings.get('google_video_resolution'),
404
+ cached_content=model_settings.get('google_cached_content'),
370
405
  tools=cast(ToolListUnionDict, tools),
371
406
  tool_config=tool_config,
372
407
  response_mime_type=response_mime_type,
@@ -386,11 +421,14 @@ class GoogleModel(Model):
386
421
  'Content field missing from Gemini response', str(response)
387
422
  ) # pragma: no cover
388
423
  parts = candidate.content.parts or []
389
- vendor_id = response.response_id or None
424
+
425
+ vendor_id = response.response_id
390
426
  vendor_details: dict[str, Any] | None = None
391
- finish_reason = candidate.finish_reason
392
- if finish_reason: # pragma: no branch
393
- vendor_details = {'finish_reason': finish_reason.value}
427
+ finish_reason: FinishReason | None = None
428
+ if raw_finish_reason := candidate.finish_reason: # pragma: no branch
429
+ vendor_details = {'finish_reason': raw_finish_reason.value}
430
+ finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
431
+
394
432
  usage = _metadata_as_usage(response)
395
433
  return _process_response_from_parts(
396
434
  parts,
@@ -399,6 +437,7 @@ class GoogleModel(Model):
399
437
  usage,
400
438
  vendor_id=vendor_id,
401
439
  vendor_details=vendor_details,
440
+ finish_reason=finish_reason,
402
441
  )
403
442
 
404
443
  async def _process_streamed_response(
@@ -533,6 +572,14 @@ class GeminiStreamedResponse(StreamedResponse):
533
572
 
534
573
  assert chunk.candidates is not None
535
574
  candidate = chunk.candidates[0]
575
+
576
+ if chunk.response_id: # pragma: no branch
577
+ self.provider_response_id = chunk.response_id
578
+
579
+ if raw_finish_reason := candidate.finish_reason:
580
+ self.provider_details = {'finish_reason': raw_finish_reason.value}
581
+ self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
582
+
536
583
  if candidate.content is None or candidate.content.parts is None:
537
584
  if candidate.finish_reason == 'STOP': # pragma: no cover
538
585
  # Normal completion - skip this chunk
@@ -559,6 +606,10 @@ class GeminiStreamedResponse(StreamedResponse):
559
606
  )
560
607
  if maybe_event is not None: # pragma: no branch
561
608
  yield maybe_event
609
+ elif part.executable_code is not None:
610
+ pass
611
+ elif part.code_execution_result is not None:
612
+ pass
562
613
  else:
563
614
  assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
564
615
 
@@ -611,6 +662,7 @@ def _process_response_from_parts(
611
662
  usage: usage.RequestUsage,
612
663
  vendor_id: str | None,
613
664
  vendor_details: dict[str, Any] | None = None,
665
+ finish_reason: FinishReason | None = None,
614
666
  ) -> ModelResponse:
615
667
  items: list[ModelResponsePart] = []
616
668
  for part in parts:
@@ -651,6 +703,7 @@ def _process_response_from_parts(
651
703
  provider_response_id=vendor_id,
652
704
  provider_details=vendor_details,
653
705
  provider_name=provider_name,
706
+ finish_reason=finish_reason,
654
707
  )
655
708
 
656
709