pydantic-ai-slim 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +50 -31
- pydantic_ai/_output.py +19 -7
- pydantic_ai/_parts_manager.py +8 -10
- pydantic_ai/_tool_manager.py +21 -0
- pydantic_ai/ag_ui.py +32 -17
- pydantic_ai/agent/__init__.py +3 -0
- pydantic_ai/agent/abstract.py +8 -0
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +721 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/durable_exec/temporal/_agent.py +1 -1
- pydantic_ai/mcp.py +1 -1
- pydantic_ai/messages.py +42 -6
- pydantic_ai/models/__init__.py +8 -0
- pydantic_ai/models/anthropic.py +79 -25
- pydantic_ai/models/bedrock.py +82 -31
- pydantic_ai/models/cohere.py +39 -13
- pydantic_ai/models/function.py +8 -1
- pydantic_ai/models/google.py +105 -37
- pydantic_ai/models/groq.py +35 -7
- pydantic_ai/models/huggingface.py +27 -5
- pydantic_ai/models/instrumented.py +27 -14
- pydantic_ai/models/mistral.py +54 -20
- pydantic_ai/models/openai.py +151 -57
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/bedrock.py +20 -4
- pydantic_ai/settings.py +1 -0
- pydantic_ai/tools.py +11 -0
- pydantic_ai/toolsets/function.py +7 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/METADATA +8 -6
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/RECORD +36 -31
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.1.dist-info → pydantic_ai_slim-1.0.3.dist-info}/licenses/LICENSE +0 -0
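A theme across the model adapters below is that responses now carry a normalized finish_reason alongside the provider's raw value in provider_details. A minimal sketch of how that surfaces on an agent run, assuming an already-configured agent (the inspection code here is illustrative, not part of the diff):

result = agent.run_sync('What is the capital of France?')
response = result.all_messages()[-1]  # for a plain-text run, the last message is the final ModelResponse
print(response.finish_reason)     # normalized value, e.g. 'stop', 'length', 'tool_call'
print(response.provider_details)  # raw provider value, e.g. {'finish_reason': 'end_turn'}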
pydantic_ai/models/bedrock.py
CHANGED
@@ -22,6 +22,7 @@ from pydantic_ai.messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
     from botocore.client import BaseClient
     from botocore.eventstream import EventStream
     from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
+    from mypy_boto3_bedrock_runtime.literals import StopReasonType
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
@@ -55,6 +57,7 @@ if TYPE_CHECKING:
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
+        ConverseStreamResponseTypeDef,
         DocumentBlockTypeDef,
         GuardrailConfigurationTypeDef,
         ImageBlockTypeDef,
@@ -63,7 +66,6 @@ if TYPE_CHECKING:
         PerformanceConfigurationTypeDef,
         PromptVariableValuesTypeDef,
         ReasoningContentBlockOutputTypeDef,
-        ReasoningTextBlockTypeDef,
         SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolConfigurationTypeDef,
@@ -135,6 +137,15 @@ See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/mode
 P = ParamSpec('P')
 T = typing.TypeVar('T')

+_FINISH_REASON_MAP: dict[StopReasonType, FinishReason] = {
+    'content_filtered': 'content_filter',
+    'end_turn': 'stop',
+    'guardrail_intervened': 'content_filter',
+    'max_tokens': 'length',
+    'stop_sequence': 'stop',
+    'tool_use': 'tool_call',
+}
+

 class BedrockModelSettings(ModelSettings, total=False):
     """Settings for Bedrock models.
@@ -270,8 +281,9 @@ class BedrockConverseModel(Model):
             yield BedrockStreamedResponse(
                 model_request_parameters=model_request_parameters,
                 _model_name=self.model_name,
-                _event_stream=response,
+                _event_stream=response['stream'],
                 _provider_name=self._provider.name,
+                _provider_response_id=response.get('ResponseMetadata', {}).get('RequestId', None),
             )

     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -279,13 +291,24 @@ class BedrockConverseModel(Model):
         if message := response['output'].get('message'):  # pragma: no branch
             for item in message['content']:
                 if reasoning_content := item.get('reasoningContent'):
-
-
-
-
-
+                    if redacted_content := reasoning_content.get('redactedContent'):
+                        items.append(
+                            ThinkingPart(
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.system,
+                            )
+                        )
+                    elif reasoning_text := reasoning_content.get('reasoningText'):  # pragma: no branch
+                        signature = reasoning_text.get('signature')
+                        items.append(
+                            ThinkingPart(
+                                content=reasoning_text['text'],
+                                signature=signature,
+                                provider_name=self.system if signature else None,
+                            )
                         )
-                    items.append(thinking_part)
                 if text := item.get('text'):
                     items.append(TextPart(content=text))
                 elif tool_use := item.get('toolUse'):
@@ -301,12 +324,18 @@ class BedrockConverseModel(Model):
             output_tokens=response['usage']['outputTokens'],
         )
         response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        raw_finish_reason = response['stopReason']
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=u,
             model_name=self.model_name,
             provider_response_id=response_id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     @overload
@@ -316,7 +345,7 @@ class BedrockConverseModel(Model):
         stream: Literal[True],
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) ->
+    ) -> ConverseStreamResponseTypeDef:
         pass

     @overload
@@ -335,7 +364,7 @@ class BedrockConverseModel(Model):
         stream: bool,
         model_settings: BedrockModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> ConverseResponseTypeDef |
+    ) -> ConverseResponseTypeDef | ConverseStreamResponseTypeDef:
         system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)

@@ -372,7 +401,6 @@ class BedrockConverseModel(Model):

         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
-            model_response = model_response['stream']
         else:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
         return model_response
@@ -476,19 +504,26 @@ class BedrockConverseModel(Model):
             if isinstance(item, TextPart):
                 content.append({'text': item.content})
             elif isinstance(item, ThinkingPart):
-                if
-
-
-
-
-
-
-
-
+                if (
+                    item.provider_name == self.system
+                    and item.signature
+                    and BedrockModelProfile.from_profile(self.profile).bedrock_send_back_thinking_parts
+                ):
+                    if item.id == 'redacted_content':
+                        reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                            'redactedContent': item.signature.encode('utf-8'),
+                        }
+                    else:
+                        reasoning_content: ReasoningContentBlockOutputTypeDef = {
+                            'reasoningText': {
+                                'text': item.content,
+                                'signature': item.signature,
+                            }
+                        }
                     content.append({'reasoningContent': reasoning_content})
                 else:
-
-
+                    start_tag, end_tag = self.profile.thinking_tags
+                    content.append({'text': '\n'.join([start_tag, item.content, end_tag])})
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                 pass
             else:
@@ -599,25 +634,30 @@ class BedrockStreamedResponse(StreamedResponse):
     _event_stream: EventStream[ConverseStreamOutputTypeDef]
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
+    _provider_response_id: str | None = None

-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
         """
+        if self._provider_response_id is not None:  # pragma: no cover
+            self.provider_response_id = self._provider_response_id
+
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
             match chunk:
                 case {'messageStart': _}:
                     continue
-                case {'messageStop':
-
+                case {'messageStop': message_stop}:
+                    raw_finish_reason = message_stop['stopReason']
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                 case {'metadata': metadata}:
                     if 'usage' in metadata:  # pragma: no branch
                         self._usage += self._map_usage(metadata)
-                    continue
                 case {'contentBlockStart': content_block_start}:
                     index = content_block_start['contentBlockIndex']
                     start = content_block_start['start']
@@ -637,11 +677,22 @@ class BedrockStreamedResponse(StreamedResponse):
                     index = content_block_delta['contentBlockIndex']
                     delta = content_block_delta['delta']
                     if 'reasoningContent' in delta:
-
-
-
-
-
+                        if redacted_content := delta['reasoningContent'].get('redactedContent'):
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                id='redacted_content',
+                                content='',
+                                signature=redacted_content.decode('utf-8'),
+                                provider_name=self.provider_name,
+                            )
+                        else:
+                            signature = delta['reasoningContent'].get('signature')
+                            yield self._parts_manager.handle_thinking_delta(
+                                vendor_part_id=index,
+                                content=delta['reasoningContent'].get('text'),
+                                signature=signature,
+                                provider_name=self.provider_name if signature else None,
+                            )
                     if 'text' in delta:
                         maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                         if maybe_event is not None:  # pragma: no branch
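The new _FINISH_REASON_MAP above is the core of the Bedrock change: the provider's stopReason string is normalized into pydantic-ai's FinishReason while the raw value is kept in provider_details. A standalone sketch of that lookup pattern, with the Literal values taken from the mapping above (the helper name is illustrative, not part of the library):

from typing import Literal

FinishReason = Literal['stop', 'length', 'content_filter', 'tool_call', 'error']

_FINISH_REASON_MAP: dict[str, FinishReason] = {
    'content_filtered': 'content_filter',
    'end_turn': 'stop',
    'guardrail_intervened': 'content_filter',
    'max_tokens': 'length',
    'stop_sequence': 'stop',
    'tool_use': 'tool_call',
}

def normalize_stop_reason(raw: str) -> tuple[FinishReason | None, dict[str, str]]:
    # Unknown stop reasons normalize to None; the raw value is always preserved.
    return _FINISH_REASON_MAP.get(raw), {'finish_reason': raw}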
pydantic_ai/models/cohere.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Literal, cast

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
 from pydantic_ai.exceptions import UserError

 from .. import ModelHTTPError, usage
@@ -14,6 +13,7 @@ from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool
 from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -35,10 +35,13 @@ from . import Model, ModelRequestParameters, check_allow_model_requests
 try:
     from cohere import (
         AssistantChatMessageV2,
+        AssistantMessageV2ContentItem,
         AsyncClientV2,
+        ChatFinishReason,
         ChatMessageV2,
         SystemChatMessageV2,
         TextAssistantMessageV2ContentItem,
+        ThinkingAssistantMessageV2ContentItem,
         ToolCallV2,
         ToolCallV2Function,
         ToolChatMessageV2,
@@ -80,6 +83,14 @@ allow any name in the type hints.
 See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
 """

+_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
+    'COMPLETE': 'stop',
+    'STOP_SEQUENCE': 'stop',
+    'MAX_TOKENS': 'length',
+    'TOOL_CALL': 'tool_call',
+    'ERROR': 'error',
+}
+

 class CohereModelSettings(ModelSettings, total=False):
     """Settings used for a Cohere model request."""
@@ -191,11 +202,12 @@ class CohereModel(Model):
     def _process_response(self, response: V2ChatResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         parts: list[ModelResponsePart] = []
-        if response.message.content is not None
-
-
-
-
+        if response.message.content is not None:
+            for content in response.message.content:
+                if content.type == 'text':
+                    parts.append(TextPart(content=content.text))
+                elif content.type == 'thinking':  # pragma: no branch
+                    parts.append(ThinkingPart(content=cast(str, content.thinking)))  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - https://github.com/cohere-ai/cohere-python/issues/692
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:  # pragma: no branch
                 parts.append(
@@ -205,8 +217,18 @@ class CohereModel(Model):
                         tool_call_id=c.id or _generate_tool_call_id(),
                     )
                 )
+
+        raw_finish_reason = response.finish_reason
+        provider_details = {'finish_reason': raw_finish_reason}
+        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
-            parts=parts,
+            parts=parts,
+            usage=_map_usage(response),
+            model_name=self._model_name,
+            provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
@@ -217,15 +239,13 @@ class CohereModel(Model):
                 cohere_messages.extend(self._map_user_message(message))
             elif isinstance(message, ModelResponse):
                 texts: list[str] = []
+                thinking: list[str] = []
                 tool_calls: list[ToolCallV2] = []
                 for item in message.parts:
                     if isinstance(item, TextPart):
                         texts.append(item.content)
                     elif isinstance(item, ThinkingPart):
-
-                        # please open an issue. The below code is the code to send thinking to the provider.
-                        # texts.append(f'<think>\n{item.content}\n</think>')
-                        pass
+                        thinking.append(item.content)
                     elif isinstance(item, ToolCallPart):
                         tool_calls.append(self._map_tool_call(item))
                     elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
@@ -233,9 +253,15 @@ class CohereModel(Model):
                         pass
                     else:
                         assert_never(item)
+
                 message_param = AssistantChatMessageV2(role='assistant')
-                if texts:
-
+                if texts or thinking:
+                    contents: list[AssistantMessageV2ContentItem] = []
+                    if thinking:
+                        contents.append(ThinkingAssistantMessageV2ContentItem(thinking='\n\n'.join(thinking)))  # pyright: ignore[reportCallIssue] - https://github.com/cohere-ai/cohere-python/issues/692
+                    if texts:  # pragma: no branch
+                        contents.append(TextAssistantMessageV2ContentItem(text='\n\n'.join(texts)))
+                    message_param.content = contents
                 if tool_calls:
                     message_param.tool_calls = tool_calls
                 cohere_messages.append(message_param)
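In _map_messages above, thinking and text parts are no longer dropped or merged: they are accumulated separately and emitted as distinct content items on the assistant message. A rough sketch of that assembly rule in isolation, using plain dicts in place of the Cohere SDK content item classes (the helper name is illustrative):

def build_assistant_content(texts: list[str], thinking: list[str]) -> list[dict[str, str]]:
    # Thinking first, then text, each joined into a single content item,
    # mirroring the ordering used in the diff above.
    contents: list[dict[str, str]] = []
    if thinking:
        contents.append({'type': 'thinking', 'thinking': '\n\n'.join(thinking)})
    if texts:
        contents.append({'type': 'text', 'text': '\n\n'.join(texts)})
    return contents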
pydantic_ai/models/function.py
CHANGED
@@ -31,7 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
-from ..profiles import ModelProfileSpec
+from ..profiles import ModelProfile, ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -111,6 +111,12 @@ class FunctionModel(Model):
         stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
         self._model_name = model_name or f'function:{function_name}:{stream_function_name}'

+        # Use a default profile that supports JSON schema and object output if none provided
+        if profile is None:
+            profile = ModelProfile(
+                supports_json_schema_output=True,
+                supports_json_object_output=True,
+            )
         super().__init__(settings=settings, profile=profile)

     async def request(
@@ -285,6 +291,7 @@ class FunctionStreamedResponse(StreamedResponse):
                     vendor_part_id=dtc_index,
                     content=delta.content,
                     signature=delta.signature,
+                    provider_name='function' if delta.signature else None,
                 )
             elif isinstance(delta, DeltaToolCall):
                 if delta.json_args:
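The FunctionModel change means a model constructed without an explicit profile now advertises JSON schema and JSON object output support, so structured-output tests no longer need a hand-rolled profile. A minimal construction sketch, assuming the standard FunctionModel call signature from the pydantic-ai docs (the function body is illustrative):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

def fixed_json(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Always answer with a fixed JSON payload.
    return ModelResponse(parts=[TextPart(content='{"answer": 42}')])

# With 1.0.3 the default profile sets supports_json_schema_output and
# supports_json_object_output, so no explicit ModelProfile is needed here.
agent = Agent(FunctionModel(fixed_json))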
pydantic_ai/models/google.py
CHANGED
@@ -20,6 +20,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     FileUrl,
+    FinishReason,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -54,6 +55,7 @@ try:
         ContentUnionDict,
         CountTokensConfigDict,
         ExecutableCodeDict,
+        FinishReason as GoogleFinishReason,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
@@ -99,6 +101,22 @@ allow any name in the type hints.
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """

+_FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
+    GoogleFinishReason.FINISH_REASON_UNSPECIFIED: None,
+    GoogleFinishReason.STOP: 'stop',
+    GoogleFinishReason.MAX_TOKENS: 'length',
+    GoogleFinishReason.SAFETY: 'content_filter',
+    GoogleFinishReason.RECITATION: 'content_filter',
+    GoogleFinishReason.LANGUAGE: 'error',
+    GoogleFinishReason.OTHER: None,
+    GoogleFinishReason.BLOCKLIST: 'content_filter',
+    GoogleFinishReason.PROHIBITED_CONTENT: 'content_filter',
+    GoogleFinishReason.SPII: 'content_filter',
+    GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
+    GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
+    GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
+}
+

 class GoogleModelSettings(ModelSettings, total=False):
     """Settings used for a Gemini model request."""
@@ -129,6 +147,12 @@ class GoogleModelSettings(ModelSettings, total=False):
     See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
     """

+    google_cached_content: str
+    """The name of the cached content to use for the model.
+
+    See <https://ai.google.dev/gemini-api/docs/caching> for more information.
+    """
+

 @dataclass(init=False)
 class GoogleModel(Model):
@@ -230,6 +254,7 @@ class GoogleModel(Model):
             stop_sequences=generation_config.get('stop_sequences'),
             presence_penalty=generation_config.get('presence_penalty'),
             frequency_penalty=generation_config.get('frequency_penalty'),
+            seed=generation_config.get('seed'),
             thinking_config=generation_config.get('thinking_config'),
             media_resolution=generation_config.get('media_resolution'),
             response_mime_type=generation_config.get('response_mime_type'),
@@ -373,10 +398,12 @@ class GoogleModel(Model):
             stop_sequences=model_settings.get('stop_sequences'),
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
+            seed=model_settings.get('seed'),
             safety_settings=model_settings.get('google_safety_settings'),
             thinking_config=model_settings.get('google_thinking_config'),
             labels=model_settings.get('google_labels'),
             media_resolution=model_settings.get('google_video_resolution'),
+            cached_content=model_settings.get('google_cached_content'),
             tools=cast(ToolListUnionDict, tools),
             tool_config=tool_config,
             response_mime_type=response_mime_type,
@@ -396,11 +423,14 @@ class GoogleModel(Model):
                 'Content field missing from Gemini response', str(response)
             )  # pragma: no cover
         parts = candidate.content.parts or []
-
+
+        vendor_id = response.response_id
         vendor_details: dict[str, Any] | None = None
-        finish_reason =
-        if finish_reason:  # pragma: no branch
-            vendor_details = {'finish_reason':
+        finish_reason: FinishReason | None = None
+        if raw_finish_reason := candidate.finish_reason:  # pragma: no branch
+            vendor_details = {'finish_reason': raw_finish_reason.value}
+            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,
@@ -409,6 +439,7 @@ class GoogleModel(Model):
             usage,
             vendor_id=vendor_id,
             vendor_details=vendor_details,
+            finish_reason=finish_reason,
         )

     async def _process_streamed_response(
@@ -422,7 +453,7 @@ class GoogleModel(Model):

         return GeminiStreamedResponse(
             model_request_parameters=model_request_parameters,
-            _model_name=self._model_name,
+            _model_name=first_chunk.model_version or self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
             _provider_name=self._provider.name,
@@ -472,7 +503,7 @@ class GoogleModel(Model):
                     message_parts = [{'text': ''}]
                 contents.append({'role': 'user', 'parts': message_parts})
             elif isinstance(m, ModelResponse):
-                contents.append(_content_model_response(m))
+                contents.append(_content_model_response(m, self.system))
             else:
                 assert_never(m)
         if instructions := self._get_instructions(messages):
@@ -537,12 +568,20 @@ class GeminiStreamedResponse(StreamedResponse):
     _timestamp: datetime
     _provider_name: str

-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         async for chunk in self._response:
             self._usage = _metadata_as_usage(chunk)

             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
+
+            if chunk.response_id:  # pragma: no branch
+                self.provider_response_id = chunk.response_id
+
+            if raw_finish_reason := candidate.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason.value}
+                self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+
             if candidate.content is None or candidate.content.parts is None:
                 if candidate.finish_reason == 'STOP':  # pragma: no cover
                     # Normal completion - skip this chunk
@@ -553,6 +592,15 @@ class GeminiStreamedResponse(StreamedResponse):
                 raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
             parts = candidate.content.parts or []
             for part in parts:
+                if part.thought_signature:
+                    signature = base64.b64encode(part.thought_signature).decode('utf-8')
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id='thinking',
+                        content='',  # A thought signature may occur without a preceding thinking part, so we add an empty delta so that a new part can be created
+                        signature=signature,
+                        provider_name=self.provider_name,
+                    )
+
                 if part.text is not None:
                     if part.thought:
                         yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -592,29 +640,41 @@ class GeminiStreamedResponse(StreamedResponse):
         return self._timestamp


-def _content_model_response(m: ModelResponse) -> ContentDict:
+def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
     parts: list[PartDict] = []
+    thought_signature: bytes | None = None
     for item in m.parts:
+        part: PartDict = {}
+        if thought_signature:
+            part['thought_signature'] = thought_signature
+            thought_signature = None
+
         if isinstance(item, ToolCallPart):
             function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
-
+            part['function_call'] = function_call
         elif isinstance(item, TextPart):
-
-        elif isinstance(item, ThinkingPart):
-
-
-
-
+            part['text'] = item.content
+        elif isinstance(item, ThinkingPart):
+            if item.provider_name == provider_name and item.signature:
+                # The thought signature is to be included on the _next_ part, not the thought part itself
+                thought_signature = base64.b64decode(item.signature)
+
+            if item.content:
+                part['text'] = item.content
+                part['thought'] = True
         elif isinstance(item, BuiltinToolCallPart):
-            if item.provider_name ==
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-
+                    part['executable_code'] = cast(ExecutableCodeDict, item.args)
         elif isinstance(item, BuiltinToolReturnPart):
-            if item.provider_name ==
+            if item.provider_name == provider_name:
                 if item.tool_name == 'code_execution':  # pragma: no branch
-
+                    part['code_execution_result'] = item.content
         else:
             assert_never(item)
+
+        if part:
+            parts.append(part)
     return ContentDict(role='model', parts=parts)


@@ -625,39 +685,46 @@ def _process_response_from_parts(
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
+    finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
+    item: ModelResponsePart | None = None
     for part in parts:
+        if part.thought_signature:
+            signature = base64.b64encode(part.thought_signature).decode('utf-8')
+            if not isinstance(item, ThinkingPart):
+                item = ThinkingPart(content='')
+                items.append(item)
+            item.signature = signature
+            item.provider_name = provider_name
+
        if part.executable_code is not None:
-
-
-                provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
-            )
+            item = BuiltinToolCallPart(
+                provider_name=provider_name, args=part.executable_code.model_dump(), tool_name='code_execution'
             )
        elif part.code_execution_result is not None:
-
-
-
-
-
-                tool_call_id='not_provided',
-            )
+            item = BuiltinToolReturnPart(
+                provider_name=provider_name,
+                tool_name='code_execution',
+                content=part.code_execution_result,
+                tool_call_id='not_provided',
             )
        elif part.text is not None:
            if part.thought:
-
+                item = ThinkingPart(content=part.text)
            else:
-
+                item = TextPart(content=part.text)
        elif part.function_call:
            assert part.function_call.name is not None
-
+            item = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
            if part.function_call.id is not None:
-
-
-        elif part.function_response:  # pragma: no cover
+                item.tool_call_id = part.function_call.id  # pragma: no cover
+        else:  # pragma: no cover
            raise UnexpectedModelBehavior(
-                f'Unsupported response from Gemini, expected all parts to be function calls or
+                f'Unsupported response from Gemini, expected all parts to be function calls, text, or thoughts, got: {part!r}'
            )
+
+        items.append(item)
     return ModelResponse(
         parts=items,
         model_name=model_name,
@@ -665,6 +732,7 @@ def _process_response_from_parts(
         provider_response_id=vendor_id,
         provider_details=vendor_details,
         provider_name=provider_name,
+        finish_reason=finish_reason,
     )

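Besides the finish-reason and thought-signature handling, the Google diff adds a google_cached_content setting that is passed through as cached_content on the request config. A hedged usage sketch (the cache resource name is a placeholder, and a Google API key is assumed to be configured in the environment):

from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

# Reuse a previously created Gemini context cache for every request made by this agent.
settings = GoogleModelSettings(google_cached_content='cachedContents/your-cache-id')
agent = Agent(GoogleModel('gemini-2.5-flash'), model_settings=settings)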