pydantic-ai-slim 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (37)
  1. pydantic_ai/_output.py +19 -7
  2. pydantic_ai/_parts_manager.py +10 -12
  3. pydantic_ai/_tool_manager.py +18 -1
  4. pydantic_ai/ag_ui.py +32 -17
  5. pydantic_ai/agent/abstract.py +8 -0
  6. pydantic_ai/durable_exec/dbos/_agent.py +5 -2
  7. pydantic_ai/durable_exec/temporal/_agent.py +1 -1
  8. pydantic_ai/messages.py +30 -6
  9. pydantic_ai/models/__init__.py +5 -1
  10. pydantic_ai/models/anthropic.py +54 -25
  11. pydantic_ai/models/bedrock.py +81 -31
  12. pydantic_ai/models/cohere.py +39 -13
  13. pydantic_ai/models/function.py +8 -1
  14. pydantic_ai/models/google.py +61 -33
  15. pydantic_ai/models/groq.py +35 -7
  16. pydantic_ai/models/huggingface.py +27 -5
  17. pydantic_ai/models/mistral.py +55 -21
  18. pydantic_ai/models/openai.py +135 -63
  19. pydantic_ai/profiles/openai.py +11 -0
  20. pydantic_ai/providers/__init__.py +3 -0
  21. pydantic_ai/providers/anthropic.py +8 -4
  22. pydantic_ai/providers/bedrock.py +9 -1
  23. pydantic_ai/providers/cohere.py +2 -2
  24. pydantic_ai/providers/gateway.py +187 -0
  25. pydantic_ai/providers/google.py +2 -2
  26. pydantic_ai/providers/google_gla.py +1 -1
  27. pydantic_ai/providers/groq.py +12 -5
  28. pydantic_ai/providers/heroku.py +2 -2
  29. pydantic_ai/providers/huggingface.py +1 -1
  30. pydantic_ai/providers/mistral.py +1 -1
  31. pydantic_ai/providers/openai.py +13 -0
  32. pydantic_ai/settings.py +1 -0
  33. {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/METADATA +5 -5
  34. {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/RECORD +37 -36
  35. {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/WHEEL +0 -0
  36. {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/entry_points.txt +0 -0
  37. {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/bedrock.py

@@ -22,6 +22,7 @@ from pydantic_ai.messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
    from botocore.client import BaseClient
    from botocore.eventstream import EventStream
    from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
+   from mypy_boto3_bedrock_runtime.literals import StopReasonType
    from mypy_boto3_bedrock_runtime.type_defs import (
        ContentBlockOutputTypeDef,
        ContentBlockUnionTypeDef,
@@ -55,6 +57,7 @@ if TYPE_CHECKING:
        ConverseResponseTypeDef,
        ConverseStreamMetadataEventTypeDef,
        ConverseStreamOutputTypeDef,
+       ConverseStreamResponseTypeDef,
        DocumentBlockTypeDef,
        GuardrailConfigurationTypeDef,
        ImageBlockTypeDef,
@@ -63,7 +66,6 @@ if TYPE_CHECKING:
        PerformanceConfigurationTypeDef,
        PromptVariableValuesTypeDef,
        ReasoningContentBlockOutputTypeDef,
-       ReasoningTextBlockTypeDef,
        SystemContentBlockTypeDef,
        ToolChoiceTypeDef,
        ToolConfigurationTypeDef,
@@ -135,6 +137,15 @@ See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/mode
 P = ParamSpec('P')
 T = typing.TypeVar('T')

+_FINISH_REASON_MAP: dict[StopReasonType, FinishReason] = {
+    'content_filtered': 'content_filter',
+    'end_turn': 'stop',
+    'guardrail_intervened': 'content_filter',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+}
+

 class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
@@ -270,8 +281,9 @@ class BedrockConverseModel(Model):
             yield BedrockStreamedResponse(
                 model_request_parameters=model_request_parameters,
                 _model_name=self.model_name,
-                _event_stream=response,
+                _event_stream=response['stream'],
                 _provider_name=self._provider.name,
+                _provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
             )

     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -279,13 +291,24 @@ class BedrockConverseModel(Model):
         if message := response['output'].get('message'):  # pragma: no branch
             for item in message['content']:
                 if reasoning_content := item.get('reasoningContent'):
-                    reasoning_text = reasoning_content.get('reasoningText')
-                    if reasoning_text:  # pragma: no branch
-                        thinking_part = ThinkingPart(
-                            content=reasoning_text['text'],
-                            signature=reasoning_text.get('signature'),
+                    if redacted_content := reasoning_content.get('redactedContent'):
+                        items.append(
+                            ThinkingPart(
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.system,
+                            )
+                        )
+                    elif reasoning_text := reasoning_content.get('reasoningText'):  # pragma: no branch
+                        signature = reasoning_text.get('signature')
+                        items.append(
+                            ThinkingPart(
+                                content=reasoning_text['text'],
+                                signature=signature,
+                                provider_name=self.system if signature else None,
+                            )
                         )
-                        items.append(thinking_part)
                 if text := item.get('text'):
                     items.append(TextPart(content=text))
                 elif tool_use := item.get('toolUse'):
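
Redacted reasoning arrives from Bedrock as opaque bytes rather than text, and the new branch stores it on a `ThinkingPart` with `id='redacted_content'`, an empty `content`, and the decoded bytes in `signature`. A hedged sketch of that encode/decode round-trip (the payload is invented):

```python
# The bytes Bedrock returns for redactedContent round-trip through a str field.
redacted_bytes = b'opaque-redacted-reasoning-blob'  # hypothetical payload

signature = redacted_bytes.decode('utf-8')  # stored on ThinkingPart(id='redacted_content')
assert signature.encode('utf-8') == redacted_bytes  # what _map_messages sends back later
```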
@@ -301,12 +324,18 @@ class BedrockConverseModel(Model):
             output_tokens=response['usage']['outputTokens'],
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        raw_finish_reason = response['stopReason']
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=u,
             model_name=self.model_name,
             provider_response_id=response_id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     @overload
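
Both the normalized and the raw value are now surfaced on the response: `finish_reason` holds the mapped `FinishReason`, while `provider_details['finish_reason']` keeps Bedrock's original `stopReason`. A hedged inspection sketch (the helper is illustrative, not part of the library):

```python
# Hypothetical inspection of a ModelResponse produced by _process_response.
def describe_finish(response) -> str:
    """response: a pydantic_ai.messages.ModelResponse."""
    raw = (response.provider_details or {}).get('finish_reason')
    return f'normalized={response.finish_reason!r}, raw bedrock stopReason={raw!r}'
```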
@@ -316,7 +345,7 @@ class BedrockConverseModel(Model):
         stream: Literal[True],
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseStreamResponseTypeDef:
         pass

     @overload
@@ -335,7 +364,7 @@ class BedrockConverseModel(Model):
         stream: bool,
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
+    ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)

@@ -372,7 +401,6 @@ class BedrockConverseModel(Model):

         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
-            model_response = model_response['stream']
         else:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response
@@ -476,19 +504,26 @@ class BedrockConverseModel(Model):
                 if isinstance(item, TextPart):
                     content.append({'text': item.content})
                 elif isinstance(item, ThinkingPart):
-                    if BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts:
-                        reasoning_text: ReasoningTextBlockTypeDef = {
-                            'text': item.content,
-                        }
-                        if item.signature:
-                            reasoning_text['signature'] = item.signature
-                        reasoning_content: ReasoningContentBlockOutputTypeDef = {
-                            'reasoningText': reasoning_text,
-                        }
+                    if (
+                        item.provider_name == self.system
+                        and item.signature
+                        and BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts
+                    ):
+                        if item.id == 'redacted_content':
+                            reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                                'redactedContent': item.signature.encode('utf-8'),
+                            }
+                        else:
+                            reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                                'reasoningText': {
+                                    'text': item.content,
+                                    'signature': item.signature,
+                                }
+                            }
                         content.append({'reasoningContent': reasoning_content})
                     else:
-                        # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
-                        pass
+                        start_tag, end_tag = self.profile.thinking_tags
+                        content.append({'text': '\n'.join([start_tag, item.content, end_tag])})
                 elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                     pass
                 else:
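
The rewritten branch only round-trips a `ThinkingPart` as native `reasoningContent` when it came from this same provider, carries a signature, and the model profile opts in; anything else is downgraded to plain text wrapped in the profile's thinking tags instead of being silently dropped. A self-contained sketch of just that gate (the stand-in class mirrors the `ThinkingPart` fields the diff inspects):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeThinkingPart:
    """Stand-in mirroring the ThinkingPart fields used by the gate."""
    content: str
    signature: str | None = None
    provider_name: str | None = None
    id: str | None = None


def map_thinking(item: FakeThinkingPart, system: str, send_back: bool, tags=('<think>', '</think>')):
    if item.provider_name == system and item.signature and send_back:
        if item.id == 'redacted_content':
            return {'reasoningContent': {'redactedContent': item.signature.encode('utf-8')}}
        return {'reasoningContent': {'reasoningText': {'text': item.content, 'signature': item.signature}}}
    # Foreign or unsigned thinking is downgraded to tagged plain text.
    start_tag, end_tag = tags
    return {'text': '\n'.join([start_tag, item.content, end_tag])}


print(map_thinking(FakeThinkingPart('plan', 'sig', 'bedrock'), 'bedrock', True))  # native reasoningContent
print(map_thinking(FakeThinkingPart('plan'), 'bedrock', True))                    # tagged-text fallback
```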
@@ -599,25 +634,30 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
+    _provider_response_id: str | None = None

-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
         """
+        if self._provider_response_id is not None:  # pragma: no cover
+            self.provider_response_id = self._provider_response_id
+
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
             match chunk:
                 case {'messageStart': _}:
                     continue
-                case {'messageStop': _}:
-                    continue
+                case {'messageStop': message_stop}:
+                    raw_finish_reason = message_stop['stopReason']
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                 case {'metadata': metadata}:
                     if 'usage' in metadata:  # pragma: no branch
                         self._usage += self._map_usage(metadata)
-                    continue
                 case {'contentBlockStart': content_block_start}:
                     index = content_block_start['contentBlockIndex']
                     start = content_block_start['start']
@@ -637,11 +677,21 @@ class BedrockStreamedResponse(StreamedResponse):
                     index = content_block_delta['contentBlockIndex']
                     delta = content_block_delta['delta']
                     if 'reasoningContent' in delta:
-                        yield self._parts_manager.handle_thinking_delta(
-                            vendor_part_id=index,
-                            content=delta['reasoningContent'].get('text'),
-                            signature=delta['reasoningContent'].get('signature'),
-                        )
+                        if redacted_content := delta['reasoningContent'].get('redactedContent'):
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                id='redacted_content',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.provider_name,
+                            )
+                        else:
+                            signature = delta['reasoningContent'].get('signature')
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                content=delta['reasoningContent'].get('text'),
+                                signature=signature,
+                                provider_name=self.provider_name if signature else None,
+                            )
                     if 'text' in delta:
                         maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                         if maybe_event is not None:  # pragma: no branch
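
Because `messageStop` is no longer skipped, a streamed Bedrock run ends with the same normalized `finish_reason` and raw `provider_details` as the non-streamed path. A hedged end-to-end sketch (the model id is a placeholder and AWS credentials are assumed to be configured):

```python
import asyncio

from pydantic_ai import Agent


async def main():
    # Placeholder Bedrock model id; any Converse-capable model should behave the same.
    agent = Agent('bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0')
    async with agent.run_stream('Say hello') as run:
        async for chunk in run.stream_text(delta=True):
            print(chunk, end='', flush=True)
    # Once the stream is exhausted, the finish reason captured from the
    # messageStop event is recorded on the streamed model response.


asyncio.run(main())
```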
pydantic_ai/models/cohere.py

@@ -6,7 +6,6 @@ from typing import Literal, cast

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
 from pydantic_ai.exceptions import UserError

 from .. import ModelHTTPError, usage
@@ -14,6 +13,7 @@ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool
 from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -35,10 +35,13 @@ from . import Model, ModelRequestParameters, check_allow_model_requests
 try:
     from cohere import (
         AssistantChatMessageV2,
+        AssistantMessageV2ContentItem,
         AsyncClientV2,
+        ChatFinishReason,
         ChatMessageV2,
         SystemChatMessageV2,
         TextAssistantMessageV2ContentItem,
+        ThinkingAssistantMessageV2ContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
@@ -80,6 +83,14 @@ allow any name in the type hints.
 See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
 """

+_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
+    'COMPLETE': 'stop',
+    'STOP_SEQUENCE': 'stop',
+    'MAX_TOKENS': 'length',
+    'TOOL_CALL': 'tool_call',
+    'ERROR': 'error',
+}
+

 class CohereModelSettings(ModelSettings, total=False):
     """Settings used for a Cohere model request."""
@@ -191,11 +202,12 @@ class CohereModel(Model):
     def _process_response(self, response: V2ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         parts: list[ModelResponsePart] = []
-        if response.message.content is not None and len(response.message.content) > 0:
-            # While Cohere's API returns a list, it only does that for future proofing
-            # and currently only one item is being returned.
-            choice = response.message.content[0]
-            parts.extend(split_content_into_text_and_thinking(choice.text, self.profile.thinking_tags))
+        if response.message.content is not None:
+            for content in response.message.content:
+                if content.type == 'text':
+                    parts.append(TextPart(content=content.text))
+                elif content.type == 'thinking':  # pragma: no branch
+                    parts.append(ThinkingPart(content=cast(str, content.thinking)))  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - https://github.com/cohere-ai/cohere-python/issues/692
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:  # pragma: no branch
                 parts.append(
@@ -205,8 +217,18 @@ class CohereModel(Model):
                         tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
+
+        raw_finish_reason = response.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
-            parts=parts, usage=_map_usage(response), model_name=self._model_name, provider_name=self._provider.name
+            parts=parts,
+            usage=_map_usage(response),
+            model_name=self._model_name,
+            provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
@@ -217,15 +239,13 @@
                 cohere_messages.extend(self._map_user_message(message))
             elif isinstance(message, ModelResponse):
                 texts: list[str] = []
+                thinking: list[str] = []
                 tool_calls: list[ToolCallV2] = []
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         texts.append(item.content)
                     elif isinstance(item, ThinkingPart):
-                        # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
-                        # please open an issue. The below code is the code to send thinking to the provider.
-                        # texts.append(f'<think>\n{item.content}\n</think>')
-                        pass
+                        thinking.append(item.content)
                     elif isinstance(item, ToolCallPart):
                         tool_calls.append(self._map_tool_call(item))
                     elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -233,9 +253,15 @@
                         pass
                     else:
                         assert_never(item)
+
                 message_param = AssistantChatMessageV2(role='assistant')
-                if texts:
-                    message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
+                if texts or thinking:
+                    contents: list[AssistantMessageV2ContentItem] = []
+                    if thinking:
+                        contents.append(ThinkingAssistantMessageV2ContentItem(thinking='\n\n'.join(thinking)))  # pyright: ignore[reportCallIssue] - https://github.com/cohere-ai/cohere-python/issues/692
+                    if texts:  # pragma: no branch
+                        contents.append(TextAssistantMessageV2ContentItem(text='\n\n'.join(texts)))
+                    message_param.content = contents
                 if tool_calls:
                     message_param.tool_calls = tool_calls
                 cohere_messages.append(message_param)
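
With `ThinkingAssistantMessageV2ContentItem` imported, assistant history is now replayed to Cohere with thinking and text as separate content items rather than dropping thinking entirely. A hedged sketch of the message shape this produces (content strings are invented; per the linked cohere-python issue, the thinking item may need a type-checker ignore):

```python
from cohere import (
    AssistantChatMessageV2,
    TextAssistantMessageV2ContentItem,
    ThinkingAssistantMessageV2ContentItem,
)

# Roughly what _map_messages now builds for a prior response that held
# both a ThinkingPart and a TextPart.
message = AssistantChatMessageV2(role='assistant')
message.content = [
    ThinkingAssistantMessageV2ContentItem(thinking='weighed the options'),
    TextAssistantMessageV2ContentItem(text='Here is the answer.'),
]
```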
pydantic_ai/models/function.py

@@ -31,7 +31,7 @@ from ..messages import (
    UserContent,
    UserPromptPart,
 )
-from ..profiles import ModelProfileSpec
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -111,6 +111,12 @@ class FunctionModel(Model):
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'

+        # Use a default profile that supports JSON schema and object output if none provided
+        if profile is None:
+            profile = ModelProfile(
+                supports_json_schema_output=True,
+                supports_json_object_output=True,
+            )
         super().__init__(settings=settings, profile=profile)

     async def request(
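
Previously a `FunctionModel` created without an explicit profile inherited defaults that left structured output disabled; the fallback profile above turns both JSON-schema and JSON-object output on. A hedged sketch (the stub function is a placeholder and is never called here):

```python
from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel


def stub(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    raise NotImplementedError  # placeholder; never invoked in this sketch


model = FunctionModel(stub)
# With the change above, the default profile now advertises structured output.
assert model.profile.supports_json_schema_output
assert model.profile.supports_json_object_output
```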
@@ -285,6 +291,7 @@ class FunctionStreamedResponse(StreamedResponse):
                     vendor_part_id=dtc_index,
                     content=delta.content,
                     signature=delta.signature,
+                    provider_name='function' if delta.signature else None,
                 )
             elif isinstance(delta, DeltaToolCall):
                 if delta.json_args:
pydantic_ai/models/google.py

@@ -254,6 +254,7 @@ class GoogleModel(Model):
             stop_sequences=generation_config.get('stop_sequences'),
             presence_penalty=generation_config.get('presence_penalty'),
             frequency_penalty=generation_config.get('frequency_penalty'),
+            seed=generation_config.get('seed'),
             thinking_config=generation_config.get('thinking_config'),
             media_resolution=generation_config.get('media_resolution'),
             response_mime_type=generation_config.get('response_mime_type'),
@@ -397,6 +398,7 @@
             stop_sequences=model_settings.get('stop_sequences'),
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
+            seed=model_settings.get('seed'),
             safety_settings=model_settings.get('google_safety_settings'),
             thinking_config=model_settings.get('google_thinking_config'),
             labels=model_settings.get('google_labels'),
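
Both request paths now forward the generic `seed` model setting into Gemini's generation config, so deterministic sampling can be requested the same way as on other providers. A hedged usage sketch (the model id is a placeholder):

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings

# With the change above, `seed` reaches Gemini's generation config.
agent = Agent(
    'google-gla:gemini-2.0-flash',  # placeholder model id
    model_settings=ModelSettings(seed=42, temperature=0.0),
)
```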
@@ -451,7 +453,7 @@ class GoogleModel(Model):

         return GeminiStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model_version or self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
             _provider_name=self._provider.name,
@@ -501,7 +503,7 @@ class GoogleModel(Model):
                 message_parts = [{'text': ''}]
                 contents.append({'role': 'user', 'parts': message_parts})
             elif isinstance(m, ModelResponse):
-                contents.append(_content_model_response(m))
+                contents.append(_content_model_response(m, self.system))
             else:
                 assert_never(m)
         if instructions := self._get_instructions(messages):
@@ -566,7 +568,7 @@ class GeminiStreamedResponse(StreamedResponse):
     _timestamp: datetime
     _provider_name: str

-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
             self._usage = _metadata_as_usage(chunk)

@@ -590,6 +592,14 @@
                 raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
             parts = candidate.content.parts or []
             for part in parts:
+                if part.thought_signature:
+                    signature = base64.b64encode(part.thought_signature).decode('utf-8')
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id='thinking',
+                        signature=signature,
+                        provider_name=self.provider_name,
+                    )
+
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -629,29 +639,41 @@
         return self._timestamp


-def _content_model_response(m: ModelResponse) -> ContentDict:
+def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
     parts: list[PartDict] = []
+    thought_signature: bytes | None = None
     for item in m.parts:
+        part: PartDict = {}
+        if thought_signature:
+            part['thought_signature'] = thought_signature
+            thought_signature = None
+
         if isinstance(item, ToolCallPart):
             function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
-            parts.append({'function_call': function_call})
+            part['function_call'] = function_call
         elif isinstance(item, TextPart):
-            parts.append({'text': item.content})
-        elif isinstance(item, ThinkingPart):  # pragma: no cover
-            # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
-            # please open an issue. The below code is the code to send thinking to the provider.
-            # parts.append({'text': item.content, 'thought': True})
-            pass
+            part['text'] = item.content
+        elif isinstance(item, ThinkingPart):
+            if item.provider_name == provider_name and item.signature:
+                # The thought signature is to be included on the _next_ part, not the thought part itself
+                thought_signature = base64.b64decode(item.signature)
+
+            if item.content:
+                part['text'] = item.content
+                part['thought'] = True
         elif isinstance(item, BuiltinToolCallPart):
-            if item.provider_name == 'google':
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    parts.append({'executable_code': cast(ExecutableCodeDict, item.args)})
+                    part['executable_code'] = cast(ExecutableCodeDict, item.args)
         elif isinstance(item, BuiltinToolReturnPart):
-            if item.provider_name == 'google':
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    parts.append({'code_execution_result': item.content})
+                    part['code_execution_result'] = item.content
         else:
             assert_never(item)
+
+        if part:
+            parts.append(part)
     return ContentDict(role='model', parts=parts)

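Note the carry-forward rule in `_content_model_response`: a signed `ThinkingPart` does not keep its signature on its own dict; the decoded bytes are attached as `thought_signature` on the next part that gets emitted. A standalone sketch of just that rule (plain dicts stand in for the google-genai part dicts; strings are invented):

```python
import base64

history = [
    {'kind': 'thinking', 'content': 'plan the answer', 'signature': base64.b64encode(b'sig').decode()},
    {'kind': 'text', 'content': 'The answer is 42.'},
]

mapped: list[dict] = []
carried: bytes | None = None
for p in history:
    out: dict = {}
    if carried:
        out['thought_signature'] = carried  # attached to the part *after* the thought
        carried = None
    if p['kind'] == 'thinking':
        carried = base64.b64decode(p['signature'])
        out['text'], out['thought'] = p['content'], True
    else:
        out['text'] = p['content']
    if out:
        mapped.append(out)

print(mapped[1])  # the text part now carries thought_signature=b'sig'
```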
@@ -665,37 +687,43 @@ def _process_response_from_parts(
     finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
+    item: ModelResponsePart | None = None
     for part in parts:
+        if part.thought_signature:
+            signature = base64.b64encode(part.thought_signature).decode('utf-8')
+            if not isinstance(item, ThinkingPart):
+                item = ThinkingPart(content='')
+                items.append(item)
+            item.signature = signature
+            item.provider_name = provider_name
+
         if part.executable_code is not None:
-            items.append(
-                BuiltinToolCallPart(
-                    provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
-                )
+            item = BuiltinToolCallPart(
+                provider_name=provider_name, args=part.executable_code.model_dump(), tool_name='code_execution'
             )
         elif part.code_execution_result is not None:
-            items.append(
-                BuiltinToolReturnPart(
-                    provider_name='google',
-                    tool_name='code_execution',
-                    content=part.code_execution_result,
-                    tool_call_id='not_provided',
-                )
+            item = BuiltinToolReturnPart(
+                provider_name=provider_name,
+                tool_name='code_execution',
+                content=part.code_execution_result,
+                tool_call_id='not_provided',
             )
         elif part.text is not None:
             if part.thought:
-                items.append(ThinkingPart(content=part.text))
+                item = ThinkingPart(content=part.text)
             else:
-                items.append(TextPart(content=part.text))
+                item = TextPart(content=part.text)
         elif part.function_call:
             assert part.function_call.name is not None
-            tool_call_part = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
+            item = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
             if part.function_call.id is not None:
-                tool_call_part.tool_call_id = part.function_call.id  # pragma: no cover
-            items.append(tool_call_part)
-        elif part.function_response:  # pragma: no cover
+                item.tool_call_id = part.function_call.id  # pragma: no cover
+        else:  # pragma: no cover
             raise UnexpectedModelBehavior(
-                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
+                f'Unsupported response from Gemini, expected all parts to be function calls, text, or thoughts, got: {part!r}'
             )
+
+        items.append(item)
     return ModelResponse(
         parts=items,
         model_name=model_name,