pydantic-ai-slim 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- pydantic_ai/_output.py +19 -7
- pydantic_ai/_parts_manager.py +10 -12
- pydantic_ai/_tool_manager.py +18 -1
- pydantic_ai/ag_ui.py +32 -17
- pydantic_ai/agent/abstract.py +8 -0
- pydantic_ai/durable_exec/dbos/_agent.py +5 -2
- pydantic_ai/durable_exec/temporal/_agent.py +1 -1
- pydantic_ai/messages.py +30 -6
- pydantic_ai/models/__init__.py +5 -1
- pydantic_ai/models/anthropic.py +54 -25
- pydantic_ai/models/bedrock.py +81 -31
- pydantic_ai/models/cohere.py +39 -13
- pydantic_ai/models/function.py +8 -1
- pydantic_ai/models/google.py +61 -33
- pydantic_ai/models/groq.py +35 -7
- pydantic_ai/models/huggingface.py +27 -5
- pydantic_ai/models/mistral.py +55 -21
- pydantic_ai/models/openai.py +135 -63
- pydantic_ai/profiles/openai.py +11 -0
- pydantic_ai/providers/__init__.py +3 -0
- pydantic_ai/providers/anthropic.py +8 -4
- pydantic_ai/providers/bedrock.py +9 -1
- pydantic_ai/providers/cohere.py +2 -2
- pydantic_ai/providers/gateway.py +187 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_gla.py +1 -1
- pydantic_ai/providers/groq.py +12 -5
- pydantic_ai/providers/heroku.py +2 -2
- pydantic_ai/providers/huggingface.py +1 -1
- pydantic_ai/providers/mistral.py +1 -1
- pydantic_ai/providers/openai.py +13 -0
- pydantic_ai/settings.py +1 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/RECORD +37 -36
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.2.dist-info → pydantic_ai_slim-1.0.4.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/bedrock.py
CHANGED

@@ -22,6 +22,7 @@ from pydantic_ai.messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
     from botocore.client import BaseClient
     from botocore.eventstream import EventStream
     from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
+    from mypy_boto3_bedrock_runtime.literals import StopReasonType
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
@@ -55,6 +57,7 @@ if TYPE_CHECKING:
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
+        ConverseStreamResponseTypeDef,
         DocumentBlockTypeDef,
         GuardrailConfigurationTypeDef,
         ImageBlockTypeDef,
@@ -63,7 +66,6 @@ if TYPE_CHECKING:
         PerformanceConfigurationTypeDef,
         PromptVariableValuesTypeDef,
         ReasoningContentBlockOutputTypeDef,
-        ReasoningTextBlockTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolConfigurationTypeDef,
@@ -135,6 +137,15 @@ See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/mode
 P = ParamSpec('P')
 T = typing.TypeVar('T')
 
+_FINISH_REASON_MAP: dict[StopReasonType, FinishReason] = {
+    'content_filtered': 'content_filter',
+    'end_turn': 'stop',
+    'guardrail_intervened': 'content_filter',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+}
+
 
 class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
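Note: the new `_FINISH_REASON_MAP` normalizes Bedrock's provider-specific stop reasons onto pydantic-ai's shared `FinishReason` values. A minimal sketch of the lookup behavior; the map is copied from the hunk above, while the helper function is illustrative and not part of the package:

```python
from pydantic_ai.messages import FinishReason

_FINISH_REASON_MAP: dict[str, FinishReason] = {
    'content_filtered': 'content_filter',
    'end_turn': 'stop',
    'guardrail_intervened': 'content_filter',
    'max_tokens': 'length',
    'stop_sequence': 'stop',
    'tool_use': 'tool_call',
}


def normalize_stop_reason(raw: str) -> FinishReason | None:
    # Unknown stop reasons fall through to None; the raw value is still
    # surfaced via provider_details['finish_reason'] (see the hunks below).
    return _FINISH_REASON_MAP.get(raw)


assert normalize_stop_reason('max_tokens') == 'length'
assert normalize_stop_reason('unknown_future_reason') is None
```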
@@ -270,8 +281,9 @@ class BedrockConverseModel(Model):
             yield BedrockStreamedResponse(
                 model_request_parameters=model_request_parameters,
                 _model_name=self.model_name,
-                _event_stream=response,
+                _event_stream=response['stream'],
                 _provider_name=self._provider.name,
+                _provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
             )
 
     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -279,13 +291,24 @@
         if message := response['output'].get('message'):  # pragma: no branch
             for item in message['content']:
                 if reasoning_content := item.get('reasoningContent'):
-                    …
+                    if redacted_content := reasoning_content.get('redactedContent'):
+                        items.append(
+                            ThinkingPart(
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.system,
+                            )
+                        )
+                    elif reasoning_text := reasoning_content.get('reasoningText'):  # pragma: no branch
+                        signature = reasoning_text.get('signature')
+                        items.append(
+                            ThinkingPart(
+                                content=reasoning_text['text'],
+                                signature=signature,
+                                provider_name=self.system if signature else None,
+                            )
                         )
-                    items.append(thinking_part)
                 if text := item.get('text'):
                     items.append(TextPart(content=text))
                 elif tool_use := item.get('toolUse'):
@@ -301,12 +324,18 @@
             output_tokens=response['usage']['outputTokens'],
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        raw_finish_reason = response['stopReason']
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=u,
             model_name=self.model_name,
             provider_response_id=response_id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     @overload
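Note: `_process_response` now stamps both the normalized finish reason and the raw provider value onto the `ModelResponse`. A hedged sketch of reading them back after a run; the field names follow the hunk above, and the model id is hypothetical:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0')  # hypothetical model id


async def main() -> None:
    result = await agent.run('Hello!')
    response = result.all_messages()[-1]  # the final ModelResponse in the history
    assert isinstance(response, ModelResponse)
    print(response.finish_reason)  # normalized, e.g. 'stop'
    if response.provider_details:
        print(response.provider_details['finish_reason'])  # raw Bedrock value, e.g. 'end_turn'
```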
@@ -316,7 +345,7 @@
         stream: Literal[True],
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> …
+    ) -> ConverseStreamResponseTypeDef:
         pass
 
     @overload
@@ -335,7 +364,7 @@ class BedrockConverseModel(Model):
         stream: bool,
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef | …
+    ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)
 
@@ -372,7 +401,6 @@ class BedrockConverseModel(Model):
 
         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
-            model_response = model_response['stream']
         else:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response
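Note: `converse_stream` now returns the full boto3 response dict instead of the pre-unwrapped event stream, so the caller can pick up the request ID before iterating. A sketch under that assumption; the key names follow the hunks above, and the helper itself is illustrative:

```python
def start_stream(client, params: dict):
    # client: a boto3 'bedrock-runtime' client; params: the Converse API kwargs.
    response = client.converse_stream(**params)
    event_stream = response['stream']  # the vendor event iterator, as before
    request_id = response.get('ResponseMetadata', {}).get('RequestId')  # recorded as provider_response_id
    return event_stream, request_id
```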
@@ -476,19 +504,26 @@
             if isinstance(item, TextPart):
                 content.append({'text': item.content})
             elif isinstance(item, ThinkingPart):
-                if …
-                …
+                if (
+                    item.provider_name == self.system
+                    and item.signature
+                    and BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts
+                ):
+                    if item.id == 'redacted_content':
+                        reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                            'redactedContent': item.signature.encode('utf-8'),
+                        }
+                    else:
+                        reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                            'reasoningText': {
+                                'text': item.content,
+                                'signature': item.signature,
+                            }
+                        }
                     content.append({'reasoningContent': reasoning_content})
                 else:
-                    …
+                    start_tag, end_tag = self.profile.thinking_tags
+                    content.append({'text': '\n'.join([start_tag, item.content, end_tag])})
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                 pass
             else:
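Note: thinking parts are now round-tripped to Bedrock as `reasoningContent` only when they originated from Bedrock, carry a signature, and the model profile opts in via `bedrock_send_back_thinking_parts`; everything else degrades to inline thinking tags. An illustrative standalone sketch of that rule, with the profile check elided:

```python
from pydantic_ai.messages import ThinkingPart


def map_thinking(item: ThinkingPart, system: str, thinking_tags: tuple[str, str]) -> dict:
    if item.provider_name == system and item.signature:
        if item.id == 'redacted_content':
            # Opaque redacted reasoning: the raw bytes live in `signature`.
            return {'reasoningContent': {'redactedContent': item.signature.encode('utf-8')}}
        return {'reasoningContent': {'reasoningText': {'text': item.content, 'signature': item.signature}}}
    # Foreign or unsigned thinking is sent back as tagged plain text.
    start_tag, end_tag = thinking_tags
    return {'text': '\n'.join([start_tag, item.content, end_tag])}
```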
@@ -599,25 +634,30 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
+    _provider_response_id: str | None = None
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
         """
+        if self._provider_response_id is not None:  # pragma: no cover
+            self.provider_response_id = self._provider_response_id
+
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
             match chunk:
                 case {'messageStart': _}:
                     continue
-                case {'messageStop': …
-                    …
+                case {'messageStop': message_stop}:
+                    raw_finish_reason = message_stop['stopReason']
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                 case {'metadata': metadata}:
                     if 'usage' in metadata:  # pragma: no branch
                         self._usage += self._map_usage(metadata)
-                    continue
                 case {'contentBlockStart': content_block_start}:
                     index = content_block_start['contentBlockIndex']
                     start = content_block_start['start']
@@ -637,11 +677,21 @@
                     index = content_block_delta['contentBlockIndex']
                     delta = content_block_delta['delta']
                     if 'reasoningContent' in delta:
-                        …
+                        if redacted_content := delta['reasoningContent'].get('redactedContent'):
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                id='redacted_content',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.provider_name,
+                            )
+                        else:
+                            signature = delta['reasoningContent'].get('signature')
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                content=delta['reasoningContent'].get('text'),
+                                signature=signature,
+                                provider_name=self.provider_name if signature else None,
+                            )
                     if 'text' in delta:
                         maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                         if maybe_event is not None:  # pragma: no branch
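Note: the streaming path mirrors the non-streamed one: `redactedContent` deltas become signature-only thinking deltas, while `reasoningText` deltas carry text plus an optional signature. A toy sketch of the branch, with the delta shapes assumed from the hunk above:

```python
def thinking_delta_kwargs(reasoning: dict) -> dict:
    # `reasoning` is the value of delta['reasoningContent'] in the hunk above.
    if redacted := reasoning.get('redactedContent'):
        return {'id': 'redacted_content', 'signature': redacted.decode('utf-8')}
    signature = reasoning.get('signature')
    return {'content': reasoning.get('text'), 'signature': signature}


assert thinking_delta_kwargs({'text': 'hmm'}) == {'content': 'hmm', 'signature': None}
assert thinking_delta_kwargs({'redactedContent': b'abc'}) == {'id': 'redacted_content', 'signature': 'abc'}
```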
pydantic_ai/models/cohere.py
CHANGED

@@ -6,7 +6,6 @@ from typing import Literal, cast
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
 from pydantic_ai.exceptions import UserError
 
 from .. import ModelHTTPError, usage
@@ -14,6 +13,7 @@ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool
 from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -35,10 +35,13 @@ from . import Model, ModelRequestParameters, check_allow_model_requests
 try:
     from cohere import (
         AssistantChatMessageV2,
+        AssistantMessageV2ContentItem,
         AsyncClientV2,
+        ChatFinishReason,
         ChatMessageV2,
         SystemChatMessageV2,
         TextAssistantMessageV2ContentItem,
+        ThinkingAssistantMessageV2ContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
@@ -80,6 +83,14 @@ allow any name in the type hints.
 See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
 """
 
+_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
+    'COMPLETE': 'stop',
+    'STOP_SEQUENCE': 'stop',
+    'MAX_TOKENS': 'length',
+    'TOOL_CALL': 'tool_call',
+    'ERROR': 'error',
+}
+
 
 class CohereModelSettings(ModelSettings, total=False):
     """Settings used for a Cohere model request."""
@@ -191,11 +202,12 @@ class CohereModel(Model):
     def _process_response(self, response: V2ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         parts: list[ModelResponsePart] = []
-        if response.message.content is not None …
-            …
+        if response.message.content is not None:
+            for content in response.message.content:
+                if content.type == 'text':
+                    parts.append(TextPart(content=content.text))
+                elif content.type == 'thinking':  # pragma: no branch
+                    parts.append(ThinkingPart(content=cast(str, content.thinking)))  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - https://github.com/cohere-ai/cohere-python/issues/692
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:  # pragma: no branch
                 parts.append(
@@ -205,8 +217,18 @@ class CohereModel(Model):
                         tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
+
+        raw_finish_reason = response.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
-            parts=parts, …
+            parts=parts,
+            usage=_map_usage(response),
+            model_name=self._model_name,
+            provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
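Note: Cohere v2 responses are now mapped per typed content item instead of splitting `<think>` tags out of the text, which is why the `split_content_into_text_and_thinking` import is removed. An illustrative sketch, assuming the content-item shapes used in the hunk above:

```python
from pydantic_ai.messages import TextPart, ThinkingPart


def map_content_item(content) -> TextPart | ThinkingPart:
    # `content` is one item of response.message.content from the Cohere SDK.
    if content.type == 'text':
        return TextPart(content=content.text)
    assert content.type == 'thinking'
    return ThinkingPart(content=content.thinking)
```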
@@ -217,15 +239,13 @@ class CohereModel(Model):
             cohere_messages.extend(self._map_user_message(message))
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
+            thinking: list[str] = []
             tool_calls: list[ToolCallV2] = []
             for item in message.parts:
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ThinkingPart):
-                    …
-                    # please open an issue. The below code is the code to send thinking to the provider.
-                    # texts.append(f'<think>\n{item.content}\n</think>')
-                    pass
+                    thinking.append(item.content)
                 elif isinstance(item, ToolCallPart):
                     tool_calls.append(self._map_tool_call(item))
                 elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -233,9 +253,15 @@ class CohereModel(Model):
                     pass
                 else:
                     assert_never(item)
+
             message_param = AssistantChatMessageV2(role='assistant')
-            if texts:
-                …
+            if texts or thinking:
+                contents: list[AssistantMessageV2ContentItem] = []
+                if thinking:
+                    contents.append(ThinkingAssistantMessageV2ContentItem(thinking='\n\n'.join(thinking)))  # pyright: ignore[reportCallIssue] - https://github.com/cohere-ai/cohere-python/issues/692
+                if texts:  # pragma: no branch
+                    contents.append(TextAssistantMessageV2ContentItem(text='\n\n'.join(texts)))
+                message_param.content = contents
             if tool_calls:
                 message_param.tool_calls = tool_calls
             cohere_messages.append(message_param)
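Note: assistant-history messages can now carry thinking content back to Cohere as a dedicated content item. A sketch mirroring the mapping above; the item constructors are as imported in the hunks above, modulo the typing caveat tracked in cohere-ai/cohere-python#692:

```python
from cohere import (
    AssistantChatMessageV2,
    TextAssistantMessageV2ContentItem,
    ThinkingAssistantMessageV2ContentItem,
)

message = AssistantChatMessageV2(role='assistant')
message.content = [
    ThinkingAssistantMessageV2ContentItem(thinking='step-by-step reasoning'),
    TextAssistantMessageV2ContentItem(text='final answer'),
]
```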
pydantic_ai/models/function.py
CHANGED

@@ -31,7 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
-from ..profiles import ModelProfileSpec
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -111,6 +111,12 @@ class FunctionModel(Model):
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
 
+        # Use a default profile that supports JSON schema and object output if none provided
+        if profile is None:
+            profile = ModelProfile(
+                supports_json_schema_output=True,
+                supports_json_object_output=True,
+            )
         super().__init__(settings=settings, profile=profile)
 
     async def request(
@@ -285,6 +291,7 @@ class FunctionStreamedResponse(StreamedResponse):
                     vendor_part_id=dtc_index,
                     content=delta.content,
                     signature=delta.signature,
+                    provider_name='function' if delta.signature else None,
                 )
             elif isinstance(delta, DeltaToolCall):
                 if delta.json_args:
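Note: `FunctionModel` now defaults to a profile that advertises JSON schema and JSON object output support, so structured output works without passing an explicit profile. A hedged sketch; the echo function is illustrative:

```python
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # A trivial stand-in for a real model function.
    return ModelResponse(parts=[TextPart(content='{"answer": 42}')])


model = FunctionModel(echo)  # no profile argument
assert model.profile.supports_json_schema_output
assert model.profile.supports_json_object_output
```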
pydantic_ai/models/google.py
CHANGED

@@ -254,6 +254,7 @@ class GoogleModel(Model):
             stop_sequences=generation_config.get('stop_sequences'),
             presence_penalty=generation_config.get('presence_penalty'),
             frequency_penalty=generation_config.get('frequency_penalty'),
+            seed=generation_config.get('seed'),
             thinking_config=generation_config.get('thinking_config'),
             media_resolution=generation_config.get('media_resolution'),
             response_mime_type=generation_config.get('response_mime_type'),
@@ -397,6 +398,7 @@ class GoogleModel(Model):
             stop_sequences=model_settings.get('stop_sequences'),
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
+            seed=model_settings.get('seed'),
             safety_settings=model_settings.get('google_safety_settings'),
             thinking_config=model_settings.get('google_thinking_config'),
             labels=model_settings.get('google_labels'),
@@ -451,7 +453,7 @@ class GoogleModel(Model):
 
         return GeminiStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model_version or self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
             _provider_name=self._provider.name,
@@ -501,7 +503,7 @@ class GoogleModel(Model):
                 message_parts = [{'text': ''}]
             contents.append({'role': 'user', 'parts': message_parts})
         elif isinstance(m, ModelResponse):
-            contents.append(_content_model_response(m))
+            contents.append(_content_model_response(m, self.system))
         else:
             assert_never(m)
         if instructions := self._get_instructions(messages):
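Note: a `seed` setting is now threaded from model settings into Gemini's generation config (the `settings.py` entry in the file list above gains the field in this same release). A hedged usage sketch; the model name is illustrative:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings

agent = Agent('google-gla:gemini-2.0-flash')  # illustrative model name
result = agent.run_sync('Pick a random number.', model_settings=ModelSettings(seed=42))
```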
@@ -566,7 +568,7 @@ class GeminiStreamedResponse(StreamedResponse):
     _timestamp: datetime
     _provider_name: str
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
             self._usage = _metadata_as_usage(chunk)
 
@@ -590,6 +592,14 @@ class GeminiStreamedResponse(StreamedResponse):
                 raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
             parts = candidate.content.parts or []
             for part in parts:
+                if part.thought_signature:
+                    signature = base64.b64encode(part.thought_signature).decode('utf-8')
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id='thinking',
+                        signature=signature,
+                        provider_name=self.provider_name,
+                    )
+
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
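Note: Gemini's raw `thought_signature` bytes are stored base64-encoded on `ThinkingPart.signature` and decoded back to bytes when the history is replayed to the provider (see `_content_model_response` below). The round trip is plain base64:

```python
import base64

raw = b'\x01\x02opaque-signature'                  # bytes as received from the API
signature = base64.b64encode(raw).decode('utf-8')  # stored on ThinkingPart.signature
assert base64.b64decode(signature) == raw          # restored when mapping messages back
```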
@@ -629,29 +639,41 @@ class GeminiStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _content_model_response(m: ModelResponse) -> ContentDict:
+def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
     parts: list[PartDict] = []
+    thought_signature: bytes | None = None
     for item in m.parts:
+        part: PartDict = {}
+        if thought_signature:
+            part['thought_signature'] = thought_signature
+            thought_signature = None
+
         if isinstance(item, ToolCallPart):
             function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
-            …
+            part['function_call'] = function_call
         elif isinstance(item, TextPart):
-            …
-        elif isinstance(item, ThinkingPart):
-            …
+            part['text'] = item.content
+        elif isinstance(item, ThinkingPart):
+            if item.provider_name == provider_name and item.signature:
+                # The thought signature is to be included on the _next_ part, not the thought part itself
+                thought_signature = base64.b64decode(item.signature)
+
+            if item.content:
+                part['text'] = item.content
+                part['thought'] = True
         elif isinstance(item, BuiltinToolCallPart):
-            if item.provider_name == …
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    …
+                    part['executable_code'] = cast(ExecutableCodeDict, item.args)
         elif isinstance(item, BuiltinToolReturnPart):
-            if item.provider_name == …
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-                    …
+                    part['code_execution_result'] = item.content
         else:
             assert_never(item)
+
+        if part:
+            parts.append(part)
     return ContentDict(role='model', parts=parts)
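Note: per the comment in the hunk above, a thought signature harvested from a `ThinkingPart` is attached to the part dict that follows it, not to the thinking part itself. A toy sketch of that ordering rule with plain dicts:

```python
history = [('thinking', 'sig-abc'), ('text', 'answer')]

parts: list[dict] = []
pending: str | None = None
for kind, payload in history:
    part: dict = {}
    if pending is not None:
        part['thought_signature'] = pending  # emitted on the part after the thought
        pending = None
    if kind == 'thinking':
        pending = payload  # defer the signature to the next part
    else:
        part['text'] = payload
    if part:
        parts.append(part)

assert parts == [{'thought_signature': 'sig-abc', 'text': 'answer'}]
```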
@@ -665,37 +687,43 @@ def _process_response_from_parts(
     finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
+    item: ModelResponsePart | None = None
     for part in parts:
+        if part.thought_signature:
+            signature = base64.b64encode(part.thought_signature).decode('utf-8')
+            if not isinstance(item, ThinkingPart):
+                item = ThinkingPart(content='')
+                items.append(item)
+            item.signature = signature
+            item.provider_name = provider_name
+
         if part.executable_code is not None:
-            …
-                provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
-            )
+            item = BuiltinToolCallPart(
+                provider_name=provider_name, args=part.executable_code.model_dump(), tool_name='code_execution'
             )
         elif part.code_execution_result is not None:
-            …
-                tool_call_id='not_provided',
-            )
+            item = BuiltinToolReturnPart(
+                provider_name=provider_name,
+                tool_name='code_execution',
+                content=part.code_execution_result,
+                tool_call_id='not_provided',
             )
         elif part.text is not None:
             if part.thought:
-                …
+                item = ThinkingPart(content=part.text)
             else:
-                …
+                item = TextPart(content=part.text)
        elif part.function_call:
             assert part.function_call.name is not None
-            …
+            item = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
             if part.function_call.id is not None:
-                …
-            elif part.function_response:  # pragma: no cover
+                item.tool_call_id = part.function_call.id  # pragma: no cover
+        else:  # pragma: no cover
             raise UnexpectedModelBehavior(
-                f'Unsupported response from Gemini, expected all parts to be function calls or …
+                f'Unsupported response from Gemini, expected all parts to be function calls, text, or thoughts, got: {part!r}'
             )
+
+        items.append(item)
     return ModelResponse(
         parts=items,
         model_name=model_name,