pydantic-ai-slim 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. pydantic_ai/_agent_graph.py +50 -31
  2. pydantic_ai/_output.py +19 -7
  3. pydantic_ai/_parts_manager.py +8 -10
  4. pydantic_ai/_tool_manager.py +21 -0
  5. pydantic_ai/ag_ui.py +32 -17
  6. pydantic_ai/agent/__init__.py +3 -0
  7. pydantic_ai/agent/abstract.py +8 -0
  8. pydantic_ai/durable_exec/dbos/__init__.py +6 -0
  9. pydantic_ai/durable_exec/dbos/_agent.py +721 -0
  10. pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
  11. pydantic_ai/durable_exec/dbos/_model.py +137 -0
  12. pydantic_ai/durable_exec/dbos/_utils.py +10 -0
  13. pydantic_ai/durable_exec/temporal/_agent.py +1 -1
  14. pydantic_ai/mcp.py +1 -1
  15. pydantic_ai/messages.py +42 -6
  16. pydantic_ai/models/__init__.py +8 -0
  17. pydantic_ai/models/anthropic.py +79 -25
  18. pydantic_ai/models/bedrock.py +82 -31
  19. pydantic_ai/models/cohere.py +39 -13
  20. pydantic_ai/models/function.py +8 -1
  21. pydantic_ai/models/google.py +105 -37
  22. pydantic_ai/models/groq.py +35 -7
  23. pydantic_ai/models/huggingface.py +27 -5
  24. pydantic_ai/models/instrumented.py +27 -14
  25. pydantic_ai/models/mistral.py +54 -20
  26. pydantic_ai/models/openai.py +151 -57
  27. pydantic_ai/profiles/openai.py +7 -0
  28. pydantic_ai/providers/bedrock.py +20 -4
  29. pydantic_ai/settings.py +1 -0
  30. pydantic_ai/tools.py +11 -0
  31. pydantic_ai/toolsets/function.py +7 -0
  32. {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/METADATA +8 -6
  33. {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/RECORD +36 -31
  34. {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/WHEEL +0 -0
  35. {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/entry_points.txt +0 -0
  36. {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/bedrock.py

@@ -22,6 +22,7 @@ from pydantic_ai.messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
     from botocore.client import BaseClient
     from botocore.eventstream import EventStream
     from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
+    from mypy_boto3_bedrock_runtime.literals import StopReasonType
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
@@ -55,6 +57,7 @@ if TYPE_CHECKING:
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
+        ConverseStreamResponseTypeDef,
         DocumentBlockTypeDef,
         GuardrailConfigurationTypeDef,
         ImageBlockTypeDef,
@@ -63,7 +66,6 @@ if TYPE_CHECKING:
         PerformanceConfigurationTypeDef,
         PromptVariableValuesTypeDef,
         ReasoningContentBlockOutputTypeDef,
-        ReasoningTextBlockTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolConfigurationTypeDef,
@@ -135,6 +137,15 @@ See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/mode
 P = ParamSpec('P')
 T = typing.TypeVar('T')
 
+_FINISH_REASON_MAP: dict[StopReasonType, FinishReason] = {
+    'content_filtered': 'content_filter',
+    'end_turn': 'stop',
+    'guardrail_intervened': 'content_filter',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+}
+
 
 class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
@@ -270,8 +281,9 @@ class BedrockConverseModel(Model):
             yield BedrockStreamedResponse(
                 model_request_parameters=model_request_parameters,
                 _model_name=self.model_name,
-                _event_stream=response,
+                _event_stream=response['stream'],
                 _provider_name=self._provider.name,
+                _provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
             )
 
     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -279,13 +291,24 @@ class BedrockConverseModel(Model):
         if message := response['output'].get('message'):  # pragma: no branch
             for item in message['content']:
                 if reasoning_content := item.get('reasoningContent'):
-                    reasoning_text = reasoning_content.get('reasoningText')
-                    if reasoning_text:  # pragma: no branch
-                        thinking_part = ThinkingPart(
-                            content=reasoning_text['text'],
-                            signature=reasoning_text.get('signature'),
+                    if redacted_content := reasoning_content.get('redactedContent'):
+                        items.append(
+                            ThinkingPart(
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.system,
+                            )
+                        )
+                    elif reasoning_text := reasoning_content.get('reasoningText'):  # pragma: no branch
+                        signature = reasoning_text.get('signature')
+                        items.append(
+                            ThinkingPart(
+                                content=reasoning_text['text'],
+                                signature=signature,
+                                provider_name=self.system if signature else None,
+                            )
                         )
-                        items.append(thinking_part)
                 if text := item.get('text'):
                     items.append(TextPart(content=text))
                 elif tool_use := item.get('toolUse'):
@@ -301,12 +324,18 @@ class BedrockConverseModel(Model):
             output_tokens=response['usage']['outputTokens'],
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        raw_finish_reason = response['stopReason']
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=u,
             model_name=self.model_name,
             provider_response_id=response_id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     @overload
@@ -316,7 +345,7 @@ class BedrockConverseModel(Model):
         stream: Literal[True],
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseStreamResponseTypeDef:
         pass
 
     @overload
@@ -335,7 +364,7 @@ class BedrockConverseModel(Model):
         stream: bool,
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
 
@@ -372,7 +401,6 @@ class BedrockConverseModel(Model):
 
         if stream:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
-            model_response = model_response['stream']
         else:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response
@@ -476,19 +504,26 @@ class BedrockConverseModel(Model):
                     if isinstance(item, TextPart):
                         content.append({'text': item.content})
                     elif isinstance(item, ThinkingPart):
-                        if BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts:
-                            reasoning_text: ReasoningTextBlockTypeDef = {
-                                'text': item.content,
-                            }
-                            if item.signature:
-                                reasoning_text['signature'] = item.signature
-                            reasoning_content: ReasoningContentBlockOutputTypeDef = {
-                                'reasoningText': reasoning_text,
-                            }
+                        if (
+                            item.provider_name == self.system
+                            and item.signature
+                            and BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts
+                        ):
+                            if item.id == 'redacted_content':
+                                reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                                    'redactedContent': item.signature.encode('utf-8'),
+                                }
+                            else:
+                                reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                                    'reasoningText': {
+                                        'text': item.content,
+                                        'signature': item.signature,
+                                    }
+                                }
                             content.append({'reasoningContent': reasoning_content})
                         else:
-                            # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
-                            pass
+                            start_tag, end_tag = self.profile.thinking_tags
+                            content.append({'text': '\n'.join([start_tag, item.content, end_tag])})
                     elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                         pass
                     else:
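Thinking parts that did not originate from Bedrock (or that carry no signature) are no longer dropped; they are sent back as plain text wrapped in the profile's thinking tags. A rough sketch of that fallback, assuming the default ('<think>', '</think>') tags:

start_tag, end_tag = '<think>', '</think>'  # assumed default ModelProfile.thinking_tags
thinking_content = 'intermediate reasoning produced by another provider'
text_block = '\n'.join([start_tag, thinking_content, end_tag])
assert text_block == '<think>\nintermediate reasoning produced by another provider\n</think>'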
@@ -599,25 +634,30 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
+    _provider_response_id: str | None = None
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
         """
+        if self._provider_response_id is not None:  # pragma: no cover
+            self.provider_response_id = self._provider_response_id
+
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
             match chunk:
                 case {'messageStart': _}:
                     continue
-                case {'messageStop': _}:
-                    continue
+                case {'messageStop': message_stop}:
+                    raw_finish_reason = message_stop['stopReason']
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                 case {'metadata': metadata}:
                     if 'usage' in metadata:  # pragma: no branch
                         self._usage += self._map_usage(metadata)
-                    continue
                 case {'contentBlockStart': content_block_start}:
                     index = content_block_start['contentBlockIndex']
                     start = content_block_start['start']
@@ -637,11 +677,22 @@ class BedrockStreamedResponse(StreamedResponse):
                     index = content_block_delta['contentBlockIndex']
                     delta = content_block_delta['delta']
                     if 'reasoningContent' in delta:
-                        yield self._parts_manager.handle_thinking_delta(
-                            vendor_part_id=index,
-                            content=delta['reasoningContent'].get('text'),
-                            signature=delta['reasoningContent'].get('signature'),
-                        )
+                        if redacted_content := delta['reasoningContent'].get('redactedContent'):
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.provider_name,
+                            )
+                        else:
+                            signature = delta['reasoningContent'].get('signature')
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                content=delta['reasoningContent'].get('text'),
+                                signature=signature,
+                                provider_name=self.provider_name if signature else None,
+                            )
                     if 'text' in delta:
                         maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                         if maybe_event is not None:  # pragma: no branch
pydantic_ai/models/cohere.py

@@ -6,7 +6,6 @@ from typing import Literal, cast
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
 from pydantic_ai.exceptions import UserError
 
 from .. import ModelHTTPError, usage
@@ -14,6 +13,7 @@ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool
 from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -35,10 +35,13 @@ from . import Model, ModelRequestParameters, check_allow_model_requests
 try:
     from cohere import (
         AssistantChatMessageV2,
+        AssistantMessageV2ContentItem,
         AsyncClientV2,
+        ChatFinishReason,
         ChatMessageV2,
         SystemChatMessageV2,
         TextAssistantMessageV2ContentItem,
+        ThinkingAssistantMessageV2ContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
@@ -80,6 +83,14 @@ allow any name in the type hints.
 See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
 """
 
+_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
+    'COMPLETE': 'stop',
+    'STOP_SEQUENCE': 'stop',
+    'MAX_TOKENS': 'length',
+    'TOOL_CALL': 'tool_call',
+    'ERROR': 'error',
+}
+
 
 class CohereModelSettings(ModelSettings, total=False):
     """Settings used for a Cohere model request."""
@@ -191,11 +202,12 @@ class CohereModel(Model):
     def _process_response(self, response: V2ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         parts: list[ModelResponsePart] = []
-        if response.message.content is not None and len(response.message.content) > 0:
-            # While Cohere's API returns a list, it only does that for future proofing
-            # and currently only one item is being returned.
-            choice = response.message.content[0]
-            parts.extend(split_content_into_text_and_thinking(choice.text, self.profile.thinking_tags))
+        if response.message.content is not None:
+            for content in response.message.content:
+                if content.type == 'text':
+                    parts.append(TextPart(content=content.text))
+                elif content.type == 'thinking':  # pragma: no branch
+                    parts.append(ThinkingPart(content=cast(str, content.thinking)))  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - https://github.com/cohere-ai/cohere-python/issues/692
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:  # pragma: no branch
                 parts.append(
@@ -205,8 +217,18 @@ class CohereModel(Model):
                         tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
+
+        raw_finish_reason = response.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
-            parts=parts, usage=_map_usage(response), model_name=self._model_name, provider_name=self._provider.name
+            parts=parts,
+            usage=_map_usage(response),
+            model_name=self._model_name,
+            provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
@@ -217,15 +239,13 @@ class CohereModel(Model):
                 cohere_messages.extend(self._map_user_message(message))
             elif isinstance(message, ModelResponse):
                 texts: list[str] = []
+                thinking: list[str] = []
                 tool_calls: list[ToolCallV2] = []
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         texts.append(item.content)
                     elif isinstance(item, ThinkingPart):
-                        # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
-                        # please open an issue. The below code is the code to send thinking to the provider.
-                        # texts.append(f'<think>\n{item.content}\n</think>')
-                        pass
+                        thinking.append(item.content)
                     elif isinstance(item, ToolCallPart):
                         tool_calls.append(self._map_tool_call(item))
                     elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -233,9 +253,15 @@ class CohereModel(Model):
                         pass
                     else:
                         assert_never(item)
+
                 message_param = AssistantChatMessageV2(role='assistant')
-                if texts:
-                    message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
+                if texts or thinking:
+                    contents: list[AssistantMessageV2ContentItem] = []
+                    if thinking:
+                        contents.append(ThinkingAssistantMessageV2ContentItem(thinking='\n\n'.join(thinking)))  # pyright: ignore[reportCallIssue] - https://github.com/cohere-ai/cohere-python/issues/692
+                    if texts:  # pragma: no branch
+                        contents.append(TextAssistantMessageV2ContentItem(text='\n\n'.join(texts)))
+                    message_param.content = contents
                 if tool_calls:
                     message_param.tool_calls = tool_calls
                 cohere_messages.append(message_param)
pydantic_ai/models/function.py

@@ -31,7 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
-from ..profiles import ModelProfileSpec
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -111,6 +111,12 @@ class FunctionModel(Model):
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
+        # Use a default profile that supports JSON schema and object output if none provided
+        if profile is None:
+            profile = ModelProfile(
+                supports_json_schema_output=True,
+                supports_json_object_output=True,
+            )
         super().__init__(settings=settings, profile=profile)
 
     async def request(
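With this change, a FunctionModel built without an explicit profile now advertises JSON schema and JSON object output support. A minimal sketch; the fake model function and its canned reply are illustrative only:

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def fake_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A real test double would inspect `messages` and the requested output tools.
    return ModelResponse(parts=[TextPart(content='{"city": "Paris"}')])


model = FunctionModel(fake_model)
print(model.profile.supports_json_schema_output)  # True under the new default profile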
@@ -285,6 +291,7 @@ class FunctionStreamedResponse(StreamedResponse):
                             vendor_part_id=dtc_index,
                             content=delta.content,
                             signature=delta.signature,
+                            provider_name='function' if delta.signature else None,
                         )
                     elif isinstance(delta, DeltaToolCall):
                         if delta.json_args:
pydantic_ai/models/google.py

@@ -20,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     FileUrl,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -54,6 +55,7 @@ try:
         ContentUnionDict,
         CountTokensConfigDict,
         ExecutableCodeDict,
+        FinishReason as GoogleFinishReason,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
@@ -99,6 +101,22 @@ allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
 
+_FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
+    GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
+    GoogleFinishReason.STOP: 'stop',
+    GoogleFinishReason.MAX_TOKENS: 'length',
+    GoogleFinishReason.SAFETY: 'content_filter',
+    GoogleFinishReason.RECITATION: 'content_filter',
+    GoogleFinishReason.LANGUAGE: 'error',
+    GoogleFinishReason.OTHER: None,
+    GoogleFinishReason.BLOCKLIST: 'content_filter',
+    GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
+    GoogleFinishReason.SPII: 'content_filter',
+    GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
+    GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
+    GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
+}
+
 
 class GoogleModelSettings(ModelSettings, total=False):
     """Settings used for a Gemini model request."""
@@ -129,6 +147,12 @@ class GoogleModelSettings(ModelSettings, total=False):
     See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
     """
 
+    google_cached_content: str
+    """The name of the cached content to use for the model.
+
+    See <https://ai.google.dev/gemini-api/docs/caching> for more information.
+    """
+
 
 @dataclass(init=False)
 class GoogleModel(Model):
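The new google_cached_content setting points a request at a previously created Gemini context cache. A minimal sketch, assuming a cache was already created with the google-genai client; the cache name and model id are placeholders:

from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModelSettings

settings = GoogleModelSettings(google_cached_content='cachedContents/example-cache-id')
agent = Agent('google-gla:gemini-2.0-flash', model_settings=settings)

result = agent.run_sync('Answer using the cached context')
print(result.output)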
@@ -230,6 +254,7 @@ class GoogleModel(Model):
            stop_sequences=generation_config.get('stop_sequences'),
            presence_penalty=generation_config.get('presence_penalty'),
            frequency_penalty=generation_config.get('frequency_penalty'),
+           seed=generation_config.get('seed'),
            thinking_config=generation_config.get('thinking_config'),
            media_resolution=generation_config.get('media_resolution'),
            response_mime_type=generation_config.get('response_mime_type'),
@@ -373,10 +398,12 @@ class GoogleModel(Model):
            stop_sequences=model_settings.get('stop_sequences'),
            presence_penalty=model_settings.get('presence_penalty'),
            frequency_penalty=model_settings.get('frequency_penalty'),
+           seed=model_settings.get('seed'),
            safety_settings=model_settings.get('google_safety_settings'),
            thinking_config=model_settings.get('google_thinking_config'),
            labels=model_settings.get('google_labels'),
            media_resolution=model_settings.get('google_video_resolution'),
+           cached_content=model_settings.get('google_cached_content'),
            tools=cast(ToolListUnionDict, tools),
            tool_config=tool_config,
            response_mime_type=response_mime_type,
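The base seed setting from ModelSettings is now forwarded to the Gemini generation config as well. A short sketch of passing it per run (the model id is a placeholder):

from pydantic_ai import Agent

agent = Agent('google-gla:gemini-2.0-flash')
# `seed` is part of the base ModelSettings TypedDict, so a plain dict works here.
result = agent.run_sync('Pick a number between 1 and 10', model_settings={'seed': 42})
print(result.output)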
@@ -396,11 +423,14 @@ class GoogleModel(Model):
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = candidate.content.parts or []
-        vendor_id = response.response_id or None
+
+        vendor_id = response.response_id
         vendor_details: dict[str, Any] | None = None
-        finish_reason = candidate.finish_reason
-        if finish_reason:  # pragma: no branch
-            vendor_details = {'finish_reason': finish_reason.value}
+        finish_reason: FinishReason | None = None
+        if raw_finish_reason := candidate.finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': raw_finish_reason.value}
+            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,
@@ -409,6 +439,7 @@ class GoogleModel(Model):
             usage,
             vendor_id=vendor_id,
             vendor_details=vendor_details,
+            finish_reason=finish_reason,
         )
 
     async def _process_streamed_response(
@@ -422,7 +453,7 @@ class GoogleModel(Model):
 
         return GeminiStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model_version or self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
             _provider_name=self._provider.name,
@@ -472,7 +503,7 @@ class GoogleModel(Model):
                     message_parts = [{'text': ''}]
                 contents.append({'role': 'user', 'parts': message_parts})
             elif isinstance(m, ModelResponse):
-                contents.append(_content_model_response(m))
+                contents.append(_content_model_response(m, self.system))
             else:
                 assert_never(m)
         if instructions := self._get_instructions(messages):
@@ -537,12 +568,20 @@ class GeminiStreamedResponse(StreamedResponse):
     _timestamp: datetime
     _provider_name: str
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
             self._usage = _metadata_as_usage(chunk)
 
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
+
+            if chunk.response_id:  # pragma: no branch
+                self.provider_response_id = chunk.response_id
+
+            if raw_finish_reason := candidate.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason.value}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             if candidate.content is None or candidate.content.parts is None:
                 if candidate.finish_reason == 'STOP':  # pragma: no cover
                     # Normal completion - skip this chunk
@@ -553,6 +592,15 @@ class GeminiStreamedResponse(StreamedResponse):
                 raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
             parts = candidate.content.parts or []
             for part in parts:
+                if part.thought_signature:
+                    signature = base64.b64encode(part.thought_signature).decode('utf-8')
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id='thinking',
+                        content='',  # A thought signature may occur without a preceding thinking part, so we add an empty delta so that a new part can be created
+                        signature=signature,
+                        provider_name=self.provider_name,
+                    )
+
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -592,29 +640,41 @@ class GeminiStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _content_model_response(m: ModelResponse) -> ContentDict:
+def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
     parts: list[PartDict] = []
+    thought_signature: bytes | None = None
     for item in m.parts:
+        part: PartDict = {}
+        if thought_signature:
+            part['thought_signature'] = thought_signature
+            thought_signature = None
+
         if isinstance(item, ToolCallPart):
             function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
-            parts.append({'function_call': function_call})
+            part['function_call'] = function_call
         elif isinstance(item, TextPart):
-            parts.append({'text': item.content})
-        elif isinstance(item, ThinkingPart):  # pragma: no cover
-            # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
-            # please open an issue. The below code is the code to send thinking to the provider.
-            # parts.append({'text': item.content, 'thought': True})
-            pass
+            part['text'] = item.content
+        elif isinstance(item, ThinkingPart):
+            if item.provider_name == provider_name and item.signature:
+                # The thought signature is to be included on the _next_ part, not the thought part itself
+                thought_signature = base64.b64decode(item.signature)
+
+            if item.content:
+                part['text'] = item.content
+                part['thought'] = True
         elif isinstance(item, BuiltinToolCallPart):
-            if item.provider_name == 'google':
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    parts.append({'executable_code': cast(ExecutableCodeDict, item.args)})
+                    part['executable_code'] = cast(ExecutableCodeDict, item.args)
         elif isinstance(item, BuiltinToolReturnPart):
-            if item.provider_name == 'google':
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    parts.append({'code_execution_result': item.content})
+                    part['code_execution_result'] = item.content
         else:
             assert_never(item)
+
+        if part:
+            parts.append(part)
     return ContentDict(role='model', parts=parts)
 
 
@@ -625,39 +685,46 @@ def _process_response_from_parts(
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
+    finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
+    item: ModelResponsePart | None = None
     for part in parts:
+        if part.thought_signature:
+            signature = base64.b64encode(part.thought_signature).decode('utf-8')
+            if not isinstance(item, ThinkingPart):
+                item = ThinkingPart(content='')
+                items.append(item)
+            item.signature = signature
+            item.provider_name = provider_name
+
         if part.executable_code is not None:
-            items.append(
-                BuiltinToolCallPart(
-                    provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
-                )
+            item = BuiltinToolCallPart(
+                provider_name=provider_name, args=part.executable_code.model_dump(), tool_name='code_execution'
             )
         elif part.code_execution_result is not None:
-            items.append(
-                BuiltinToolReturnPart(
-                    provider_name='google',
-                    tool_name='code_execution',
-                    content=part.code_execution_result,
-                    tool_call_id='not_provided',
-                )
+            item = BuiltinToolReturnPart(
+                provider_name=provider_name,
+                tool_name='code_execution',
+                content=part.code_execution_result,
+                tool_call_id='not_provided',
             )
         elif part.text is not None:
             if part.thought:
-                items.append(ThinkingPart(content=part.text))
+                item = ThinkingPart(content=part.text)
             else:
-                items.append(TextPart(content=part.text))
+                item = TextPart(content=part.text)
         elif part.function_call:
             assert part.function_call.name is not None
-            tool_call_part = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
+            item = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
             if part.function_call.id is not None:
-                tool_call_part.tool_call_id = part.function_call.id  # pragma: no cover
-            items.append(tool_call_part)
-        elif part.function_response:  # pragma: no cover
+                item.tool_call_id = part.function_call.id  # pragma: no cover
+        else:  # pragma: no cover
             raise UnexpectedModelBehavior(
-                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+                f'Unsupported response from Gemini, expected all parts to be function calls, text, or thoughts, got: {part!r}'
             )
+
+        items.append(item)
     return ModelResponse(
         parts=items,
         model_name=model_name,
@@ -665,6 +732,7 @@ def _process_response_from_parts(
         provider_response_id=vendor_id,
         provider_details=vendor_details,
         provider_name=provider_name,
+        finish_reason=finish_reason,
     )
 
 