pydantic-ai-slim 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -1
- pydantic_ai/_otel_messages.py +2 -0
- pydantic_ai/_parts_manager.py +82 -12
- pydantic_ai/_run_context.py +8 -1
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/ag_ui.py +86 -33
- pydantic_ai/agent/__init__.py +2 -1
- pydantic_ai/builtin_tools.py +12 -0
- pydantic_ai/durable_exec/temporal/_model.py +14 -6
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/format_prompt.py +109 -17
- pydantic_ai/messages.py +65 -30
- pydantic_ai/models/anthropic.py +119 -45
- pydantic_ai/models/function.py +17 -8
- pydantic_ai/models/google.py +132 -33
- pydantic_ai/models/groq.py +68 -17
- pydantic_ai/models/openai.py +262 -41
- pydantic_ai/providers/__init__.py +1 -1
- pydantic_ai/result.py +21 -3
- pydantic_ai/toolsets/function.py +8 -2
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.10.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.10.dist-info}/RECORD +26 -26
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.10.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.10.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.10.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from typing import Any, Literal, cast, overload
|
|
9
9
|
|
|
10
|
+
from pydantic import TypeAdapter
|
|
10
11
|
from typing_extensions import assert_never
|
|
11
12
|
|
|
12
13
|
from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
|
|
@@ -60,7 +61,9 @@ try:
|
|
|
60
61
|
BetaCitationsDelta,
|
|
61
62
|
BetaCodeExecutionTool20250522Param,
|
|
62
63
|
BetaCodeExecutionToolResultBlock,
|
|
64
|
+
BetaCodeExecutionToolResultBlockContent,
|
|
63
65
|
BetaCodeExecutionToolResultBlockParam,
|
|
66
|
+
BetaCodeExecutionToolResultBlockParamContentParam,
|
|
64
67
|
BetaContentBlock,
|
|
65
68
|
BetaContentBlockParam,
|
|
66
69
|
BetaImageBlockParam,
|
|
@@ -97,7 +100,9 @@ try:
|
|
|
97
100
|
BetaToolUseBlockParam,
|
|
98
101
|
BetaWebSearchTool20250305Param,
|
|
99
102
|
BetaWebSearchToolResultBlock,
|
|
103
|
+
BetaWebSearchToolResultBlockContent,
|
|
100
104
|
BetaWebSearchToolResultBlockParam,
|
|
105
|
+
BetaWebSearchToolResultBlockParamContentParam,
|
|
101
106
|
)
|
|
102
107
|
from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
|
|
103
108
|
from anthropic.types.model_param import ModelParam
|
|
@@ -302,24 +307,12 @@ class AnthropicModel(Model):
|
|
|
302
307
|
for item in response.content:
|
|
303
308
|
if isinstance(item, BetaTextBlock):
|
|
304
309
|
items.append(TextPart(content=item.text))
|
|
305
|
-
elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
|
|
306
|
-
items.append(
|
|
307
|
-
BuiltinToolReturnPart(
|
|
308
|
-
provider_name=self.system,
|
|
309
|
-
tool_name=item.type,
|
|
310
|
-
content=item.content,
|
|
311
|
-
tool_call_id=item.tool_use_id,
|
|
312
|
-
)
|
|
313
|
-
)
|
|
314
310
|
elif isinstance(item, BetaServerToolUseBlock):
|
|
315
|
-
items.append(
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
tool_call_id=item.id,
|
|
321
|
-
)
|
|
322
|
-
)
|
|
311
|
+
items.append(_map_server_tool_use_block(item, self.system))
|
|
312
|
+
elif isinstance(item, BetaWebSearchToolResultBlock):
|
|
313
|
+
items.append(_map_web_search_tool_result_block(item, self.system))
|
|
314
|
+
elif isinstance(item, BetaCodeExecutionToolResultBlock):
|
|
315
|
+
items.append(_map_code_execution_tool_result_block(item, self.system))
|
|
323
316
|
elif isinstance(item, BetaRedactedThinkingBlock):
|
|
324
317
|
items.append(
|
|
325
318
|
ThinkingPart(id='redacted_thinking', content='', signature=item.data, provider_name=self.system)
|
|
@@ -485,27 +478,54 @@ class AnthropicModel(Model):
|
|
|
485
478
|
)
|
|
486
479
|
elif isinstance(response_part, BuiltinToolCallPart):
|
|
487
480
|
if response_part.provider_name == self.system:
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
481
|
+
tool_use_id = _guard_tool_call_id(t=response_part)
|
|
482
|
+
if response_part.tool_name == WebSearchTool.kind:
|
|
483
|
+
server_tool_use_block_param = BetaServerToolUseBlockParam(
|
|
484
|
+
id=tool_use_id,
|
|
485
|
+
type='server_tool_use',
|
|
486
|
+
name='web_search',
|
|
487
|
+
input=response_part.args_as_dict(),
|
|
488
|
+
)
|
|
489
|
+
assistant_content_params.append(server_tool_use_block_param)
|
|
490
|
+
elif response_part.tool_name == CodeExecutionTool.kind: # pragma: no branch
|
|
491
|
+
server_tool_use_block_param = BetaServerToolUseBlockParam(
|
|
492
|
+
id=tool_use_id,
|
|
493
|
+
type='server_tool_use',
|
|
494
|
+
name='code_execution',
|
|
495
|
+
input=response_part.args_as_dict(),
|
|
496
|
+
)
|
|
497
|
+
assistant_content_params.append(server_tool_use_block_param)
|
|
495
498
|
elif isinstance(response_part, BuiltinToolReturnPart):
|
|
496
499
|
if response_part.provider_name == self.system:
|
|
497
500
|
tool_use_id = _guard_tool_call_id(t=response_part)
|
|
498
|
-
if response_part.tool_name
|
|
499
|
-
|
|
500
|
-
|
|
501
|
+
if response_part.tool_name in (
|
|
502
|
+
WebSearchTool.kind,
|
|
503
|
+
'web_search_tool_result', # Backward compatibility
|
|
504
|
+
) and isinstance(response_part.content, dict | list):
|
|
505
|
+
assistant_content_params.append(
|
|
506
|
+
BetaWebSearchToolResultBlockParam(
|
|
507
|
+
tool_use_id=tool_use_id,
|
|
508
|
+
type='web_search_tool_result',
|
|
509
|
+
content=cast(
|
|
510
|
+
BetaWebSearchToolResultBlockParamContentParam,
|
|
511
|
+
response_part.content, # pyright: ignore[reportUnknownMemberType]
|
|
512
|
+
),
|
|
513
|
+
)
|
|
501
514
|
)
|
|
502
|
-
elif response_part.tool_name
|
|
503
|
-
|
|
504
|
-
|
|
515
|
+
elif response_part.tool_name in ( # pragma: no branch
|
|
516
|
+
CodeExecutionTool.kind,
|
|
517
|
+
'code_execution_tool_result', # Backward compatibility
|
|
518
|
+
) and isinstance(response_part.content, dict):
|
|
519
|
+
assistant_content_params.append(
|
|
520
|
+
BetaCodeExecutionToolResultBlockParam(
|
|
521
|
+
tool_use_id=tool_use_id,
|
|
522
|
+
type='code_execution_tool_result',
|
|
523
|
+
content=cast(
|
|
524
|
+
BetaCodeExecutionToolResultBlockParamContentParam,
|
|
525
|
+
response_part.content, # pyright: ignore[reportUnknownMemberType]
|
|
526
|
+
),
|
|
527
|
+
)
|
|
505
528
|
)
|
|
506
|
-
else:
|
|
507
|
-
raise ValueError(f'Unsupported tool name: {response_part.tool_name}')
|
|
508
|
-
assistant_content_params.append(server_tool_result_block_param)
|
|
509
529
|
else:
|
|
510
530
|
assert_never(response_part)
|
|
511
531
|
if len(assistant_content_params) > 0:
|
|
@@ -646,7 +666,7 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
646
666
|
)
|
|
647
667
|
elif isinstance(current_block, BetaToolUseBlock):
|
|
648
668
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
649
|
-
vendor_part_id=
|
|
669
|
+
vendor_part_id=event.index,
|
|
650
670
|
tool_name=current_block.name,
|
|
651
671
|
args=cast(dict[str, Any], current_block.input) or None,
|
|
652
672
|
tool_call_id=current_block.id,
|
|
@@ -654,7 +674,20 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
654
674
|
if maybe_event is not None: # pragma: no branch
|
|
655
675
|
yield maybe_event
|
|
656
676
|
elif isinstance(current_block, BetaServerToolUseBlock):
|
|
657
|
-
|
|
677
|
+
yield self._parts_manager.handle_builtin_tool_call_part(
|
|
678
|
+
vendor_part_id=event.index,
|
|
679
|
+
part=_map_server_tool_use_block(current_block, self.provider_name),
|
|
680
|
+
)
|
|
681
|
+
elif isinstance(current_block, BetaWebSearchToolResultBlock):
|
|
682
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
683
|
+
vendor_part_id=event.index,
|
|
684
|
+
part=_map_web_search_tool_result_block(current_block, self.provider_name),
|
|
685
|
+
)
|
|
686
|
+
elif isinstance(current_block, BetaCodeExecutionToolResultBlock):
|
|
687
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
688
|
+
vendor_part_id=event.index,
|
|
689
|
+
part=_map_code_execution_tool_result_block(current_block, self.provider_name),
|
|
690
|
+
)
|
|
658
691
|
|
|
659
692
|
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
|
660
693
|
if isinstance(event.delta, BetaTextDelta):
|
|
@@ -675,21 +708,13 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
675
708
|
signature=event.delta.signature,
|
|
676
709
|
provider_name=self.provider_name,
|
|
677
710
|
)
|
|
678
|
-
elif (
|
|
679
|
-
current_block
|
|
680
|
-
and event.delta.type == 'input_json_delta'
|
|
681
|
-
and isinstance(current_block, BetaToolUseBlock)
|
|
682
|
-
): # pragma: no branch
|
|
711
|
+
elif isinstance(event.delta, BetaInputJSONDelta):
|
|
683
712
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
684
|
-
vendor_part_id=
|
|
685
|
-
tool_name='',
|
|
713
|
+
vendor_part_id=event.index,
|
|
686
714
|
args=event.delta.partial_json,
|
|
687
|
-
tool_call_id=current_block.id,
|
|
688
715
|
)
|
|
689
716
|
if maybe_event is not None: # pragma: no branch
|
|
690
717
|
yield maybe_event
|
|
691
|
-
elif isinstance(event.delta, BetaInputJSONDelta):
|
|
692
|
-
pass
|
|
693
718
|
# TODO(Marcelo): We need to handle citations.
|
|
694
719
|
elif isinstance(event.delta, BetaCitationsDelta):
|
|
695
720
|
pass
|
|
@@ -717,3 +742,52 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
717
742
|
def timestamp(self) -> datetime:
|
|
718
743
|
"""Get the timestamp of the response."""
|
|
719
744
|
return self._timestamp
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _map_server_tool_use_block(item: BetaServerToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
|
|
748
|
+
if item.name == 'web_search':
|
|
749
|
+
return BuiltinToolCallPart(
|
|
750
|
+
provider_name=provider_name,
|
|
751
|
+
tool_name=WebSearchTool.kind,
|
|
752
|
+
args=cast(dict[str, Any], item.input) or None,
|
|
753
|
+
tool_call_id=item.id,
|
|
754
|
+
)
|
|
755
|
+
elif item.name == 'code_execution':
|
|
756
|
+
return BuiltinToolCallPart(
|
|
757
|
+
provider_name=provider_name,
|
|
758
|
+
tool_name=CodeExecutionTool.kind,
|
|
759
|
+
args=cast(dict[str, Any], item.input) or None,
|
|
760
|
+
tool_call_id=item.id,
|
|
761
|
+
)
|
|
762
|
+
else:
|
|
763
|
+
assert_never(item.name)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
web_search_tool_result_content_ta: TypeAdapter[BetaWebSearchToolResultBlockContent] = TypeAdapter(
|
|
767
|
+
BetaWebSearchToolResultBlockContent
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def _map_web_search_tool_result_block(item: BetaWebSearchToolResultBlock, provider_name: str) -> BuiltinToolReturnPart:
|
|
772
|
+
return BuiltinToolReturnPart(
|
|
773
|
+
provider_name=provider_name,
|
|
774
|
+
tool_name=WebSearchTool.kind,
|
|
775
|
+
content=web_search_tool_result_content_ta.dump_python(item.content, mode='json'),
|
|
776
|
+
tool_call_id=item.tool_use_id,
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
code_execution_tool_result_content_ta: TypeAdapter[BetaCodeExecutionToolResultBlockContent] = TypeAdapter(
|
|
781
|
+
BetaCodeExecutionToolResultBlockContent
|
|
782
|
+
)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def _map_code_execution_tool_result_block(
|
|
786
|
+
item: BetaCodeExecutionToolResultBlock, provider_name: str
|
|
787
|
+
) -> BuiltinToolReturnPart:
|
|
788
|
+
return BuiltinToolReturnPart(
|
|
789
|
+
provider_name=provider_name,
|
|
790
|
+
tool_name=CodeExecutionTool.kind,
|
|
791
|
+
content=code_execution_tool_result_content_ta.dump_python(item.content, mode='json'),
|
|
792
|
+
tool_call_id=item.tool_use_id,
|
|
793
|
+
)
|
pydantic_ai/models/function.py
CHANGED
|
@@ -247,18 +247,20 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
|
247
247
|
DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
|
|
248
248
|
"""A mapping of thinking call IDs to incremental changes."""
|
|
249
249
|
|
|
250
|
+
BuiltinToolCallsReturns: TypeAlias = dict[int, BuiltinToolCallPart | BuiltinToolReturnPart]
|
|
251
|
+
|
|
250
252
|
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
|
|
251
253
|
"""A function used to generate a non-streamed response."""
|
|
252
254
|
|
|
253
255
|
StreamFunctionDef: TypeAlias = Callable[
|
|
254
|
-
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
256
|
+
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns]
|
|
255
257
|
]
|
|
256
258
|
"""A function used to generate a streamed response.
|
|
257
259
|
|
|
258
|
-
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
|
|
260
|
+
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinTools]`, it should
|
|
259
261
|
really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
|
|
260
262
|
|
|
261
|
-
E.g. you need to yield all text, all `DeltaToolCalls`, or all `
|
|
263
|
+
E.g. you need to yield all text, all `DeltaToolCalls`, all `DeltaThinkingCalls`, or all `BuiltinToolCallsReturns`, not mix them.
|
|
262
264
|
"""
|
|
263
265
|
|
|
264
266
|
|
|
@@ -267,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
267
269
|
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
268
270
|
|
|
269
271
|
_model_name: str
|
|
270
|
-
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
272
|
+
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns]
|
|
271
273
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
272
274
|
|
|
273
275
|
def __post_init__(self):
|
|
@@ -305,6 +307,16 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
305
307
|
)
|
|
306
308
|
if maybe_event is not None: # pragma: no branch
|
|
307
309
|
yield maybe_event
|
|
310
|
+
elif isinstance(delta, BuiltinToolCallPart):
|
|
311
|
+
if content := delta.args_as_json_str(): # pragma: no branch
|
|
312
|
+
response_tokens = _estimate_string_tokens(content)
|
|
313
|
+
self._usage += usage.RequestUsage(output_tokens=response_tokens)
|
|
314
|
+
yield self._parts_manager.handle_builtin_tool_call_part(vendor_part_id=dtc_index, part=delta)
|
|
315
|
+
elif isinstance(delta, BuiltinToolReturnPart):
|
|
316
|
+
if content := delta.model_response_str(): # pragma: no branch
|
|
317
|
+
response_tokens = _estimate_string_tokens(content)
|
|
318
|
+
self._usage += usage.RequestUsage(output_tokens=response_tokens)
|
|
319
|
+
yield self._parts_manager.handle_builtin_tool_return_part(vendor_part_id=dtc_index, part=delta)
|
|
308
320
|
else:
|
|
309
321
|
assert_never(delta)
|
|
310
322
|
|
|
@@ -351,11 +363,8 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
|
|
|
351
363
|
response_tokens += _estimate_string_tokens(part.content)
|
|
352
364
|
elif isinstance(part, ToolCallPart):
|
|
353
365
|
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
|
|
354
|
-
# TODO(Marcelo): We need to add coverage here.
|
|
355
366
|
elif isinstance(part, BuiltinToolCallPart): # pragma: no cover
|
|
356
|
-
|
|
357
|
-
response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
|
|
358
|
-
# TODO(Marcelo): We need to add coverage here.
|
|
367
|
+
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
|
|
359
368
|
elif isinstance(part, BuiltinToolReturnPart): # pragma: no cover
|
|
360
369
|
response_tokens += _estimate_string_tokens(part.model_response_str())
|
|
361
370
|
else:
|
pydantic_ai/models/google.py
CHANGED
|
@@ -51,10 +51,15 @@ from . import (
|
|
|
51
51
|
try:
|
|
52
52
|
from google.genai import Client
|
|
53
53
|
from google.genai.types import (
|
|
54
|
+
BlobDict,
|
|
55
|
+
CodeExecutionResult,
|
|
56
|
+
CodeExecutionResultDict,
|
|
54
57
|
ContentDict,
|
|
55
58
|
ContentUnionDict,
|
|
56
59
|
CountTokensConfigDict,
|
|
60
|
+
ExecutableCode,
|
|
57
61
|
ExecutableCodeDict,
|
|
62
|
+
FileDataDict,
|
|
58
63
|
FinishReason as GoogleFinishReason,
|
|
59
64
|
FunctionCallDict,
|
|
60
65
|
FunctionCallingConfigDict,
|
|
@@ -64,6 +69,7 @@ try:
|
|
|
64
69
|
GenerateContentResponse,
|
|
65
70
|
GenerationConfigDict,
|
|
66
71
|
GoogleSearchDict,
|
|
72
|
+
GroundingMetadata,
|
|
67
73
|
HttpOptionsDict,
|
|
68
74
|
MediaResolution,
|
|
69
75
|
Part,
|
|
@@ -75,6 +81,7 @@ try:
|
|
|
75
81
|
ToolDict,
|
|
76
82
|
ToolListUnionDict,
|
|
77
83
|
UrlContextDict,
|
|
84
|
+
VideoMetadataDict,
|
|
78
85
|
)
|
|
79
86
|
|
|
80
87
|
from ..providers.google import GoogleProvider
|
|
@@ -434,6 +441,7 @@ class GoogleModel(Model):
|
|
|
434
441
|
usage = _metadata_as_usage(response)
|
|
435
442
|
return _process_response_from_parts(
|
|
436
443
|
parts,
|
|
444
|
+
candidate.grounding_metadata,
|
|
437
445
|
response.model_version or self._model_name,
|
|
438
446
|
self._provider.name,
|
|
439
447
|
usage,
|
|
@@ -520,17 +528,17 @@ class GoogleModel(Model):
|
|
|
520
528
|
if isinstance(item, str):
|
|
521
529
|
content.append({'text': item})
|
|
522
530
|
elif isinstance(item, BinaryContent):
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
inline_data_dict = {'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}}
|
|
531
|
+
inline_data_dict: BlobDict = {'data': item.data, 'mime_type': item.media_type}
|
|
532
|
+
part_dict: PartDict = {'inline_data': inline_data_dict}
|
|
526
533
|
if item.vendor_metadata:
|
|
527
|
-
|
|
528
|
-
content.append(
|
|
534
|
+
part_dict['video_metadata'] = cast(VideoMetadataDict, item.vendor_metadata)
|
|
535
|
+
content.append(part_dict)
|
|
529
536
|
elif isinstance(item, VideoUrl) and item.is_youtube:
|
|
530
|
-
file_data_dict = {'
|
|
537
|
+
file_data_dict: FileDataDict = {'file_uri': item.url, 'mime_type': item.media_type}
|
|
538
|
+
part_dict: PartDict = {'file_data': file_data_dict}
|
|
531
539
|
if item.vendor_metadata: # pragma: no branch
|
|
532
|
-
|
|
533
|
-
content.append(
|
|
540
|
+
part_dict['video_metadata'] = cast(VideoMetadataDict, item.vendor_metadata)
|
|
541
|
+
content.append(part_dict)
|
|
534
542
|
elif isinstance(item, FileUrl):
|
|
535
543
|
if item.force_download or (
|
|
536
544
|
# google-gla does not support passing file urls directly, except for youtube videos
|
|
@@ -538,13 +546,15 @@ class GoogleModel(Model):
|
|
|
538
546
|
self.system == 'google-gla'
|
|
539
547
|
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
|
|
540
548
|
):
|
|
541
|
-
downloaded_item = await download_item(item, data_format='
|
|
542
|
-
inline_data = {
|
|
543
|
-
|
|
549
|
+
downloaded_item = await download_item(item, data_format='bytes')
|
|
550
|
+
inline_data: BlobDict = {
|
|
551
|
+
'data': downloaded_item['data'],
|
|
552
|
+
'mime_type': downloaded_item['data_type'],
|
|
553
|
+
}
|
|
554
|
+
content.append({'inline_data': inline_data})
|
|
544
555
|
else:
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
) # pragma: lax no cover
|
|
556
|
+
file_data_dict: FileDataDict = {'file_uri': item.url, 'mime_type': item.media_type}
|
|
557
|
+
content.append({'file_data': file_data_dict}) # pragma: lax no cover
|
|
548
558
|
else:
|
|
549
559
|
assert_never(item)
|
|
550
560
|
return content
|
|
@@ -569,10 +579,13 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
569
579
|
_provider_name: str
|
|
570
580
|
|
|
571
581
|
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
582
|
+
code_execution_tool_call_id: str | None = None
|
|
572
583
|
async for chunk in self._response:
|
|
573
584
|
self._usage = _metadata_as_usage(chunk)
|
|
574
585
|
|
|
575
|
-
|
|
586
|
+
if not chunk.candidates:
|
|
587
|
+
continue # pragma: no cover
|
|
588
|
+
|
|
576
589
|
candidate = chunk.candidates[0]
|
|
577
590
|
|
|
578
591
|
if chunk.response_id: # pragma: no branch
|
|
@@ -582,6 +595,19 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
582
595
|
self.provider_details = {'finish_reason': raw_finish_reason.value}
|
|
583
596
|
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
584
597
|
|
|
598
|
+
# Google streams the grounding metadata (including the web search queries and results)
|
|
599
|
+
# _after_ the text that was generated using it, so it would show up out of order in the stream,
|
|
600
|
+
# and cause issues with the logic that doesn't consider text ahead of built-in tool calls as output.
|
|
601
|
+
# If that gets fixed (or we have a workaround), we can uncomment this:
|
|
602
|
+
# web_search_call, web_search_return = _map_grounding_metadata(
|
|
603
|
+
# candidate.grounding_metadata, self.provider_name
|
|
604
|
+
# )
|
|
605
|
+
# if web_search_call and web_search_return:
|
|
606
|
+
# yield self._parts_manager.handle_builtin_tool_call_part(vendor_part_id=uuid4(), part=web_search_call)
|
|
607
|
+
# yield self._parts_manager.handle_builtin_tool_return_part(
|
|
608
|
+
# vendor_part_id=uuid4(), part=web_search_return
|
|
609
|
+
# )
|
|
610
|
+
|
|
585
611
|
if candidate.content is None or candidate.content.parts is None:
|
|
586
612
|
if candidate.finish_reason == 'STOP': # pragma: no cover
|
|
587
613
|
# Normal completion - skip this chunk
|
|
@@ -590,7 +616,11 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
590
616
|
raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
|
|
591
617
|
else: # pragma: no cover
|
|
592
618
|
raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
|
|
593
|
-
|
|
619
|
+
|
|
620
|
+
parts = candidate.content.parts
|
|
621
|
+
if not parts:
|
|
622
|
+
continue # pragma: no cover
|
|
623
|
+
|
|
594
624
|
for part in parts:
|
|
595
625
|
if part.thought_signature:
|
|
596
626
|
signature = base64.b64encode(part.thought_signature).decode('utf-8')
|
|
@@ -617,9 +647,21 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
617
647
|
if maybe_event is not None: # pragma: no branch
|
|
618
648
|
yield maybe_event
|
|
619
649
|
elif part.executable_code is not None:
|
|
620
|
-
|
|
650
|
+
code_execution_tool_call_id = _utils.generate_tool_call_id()
|
|
651
|
+
yield self._parts_manager.handle_builtin_tool_call_part(
|
|
652
|
+
vendor_part_id=uuid4(),
|
|
653
|
+
part=_map_executable_code(
|
|
654
|
+
part.executable_code, self.provider_name, code_execution_tool_call_id
|
|
655
|
+
),
|
|
656
|
+
)
|
|
621
657
|
elif part.code_execution_result is not None:
|
|
622
|
-
|
|
658
|
+
assert code_execution_tool_call_id is not None
|
|
659
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
660
|
+
vendor_part_id=uuid4(),
|
|
661
|
+
part=_map_code_execution_result(
|
|
662
|
+
part.code_execution_result, self.provider_name, code_execution_tool_call_id
|
|
663
|
+
),
|
|
664
|
+
)
|
|
623
665
|
else:
|
|
624
666
|
assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
|
|
625
667
|
|
|
@@ -639,7 +681,7 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
639
681
|
return self._timestamp
|
|
640
682
|
|
|
641
683
|
|
|
642
|
-
def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
|
|
684
|
+
def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict: # noqa: C901
|
|
643
685
|
parts: list[PartDict] = []
|
|
644
686
|
thought_signature: bytes | None = None
|
|
645
687
|
for item in m.parts:
|
|
@@ -663,12 +705,18 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
|
|
|
663
705
|
part['thought'] = True
|
|
664
706
|
elif isinstance(item, BuiltinToolCallPart):
|
|
665
707
|
if item.provider_name == provider_name:
|
|
666
|
-
if item.tool_name ==
|
|
667
|
-
part['executable_code'] = cast(ExecutableCodeDict, item.
|
|
708
|
+
if item.tool_name == CodeExecutionTool.kind:
|
|
709
|
+
part['executable_code'] = cast(ExecutableCodeDict, item.args_as_dict())
|
|
710
|
+
elif item.tool_name == WebSearchTool.kind:
|
|
711
|
+
# Web search calls are not sent back
|
|
712
|
+
pass
|
|
668
713
|
elif isinstance(item, BuiltinToolReturnPart):
|
|
669
714
|
if item.provider_name == provider_name:
|
|
670
|
-
if item.tool_name ==
|
|
671
|
-
part['code_execution_result'] = item.content
|
|
715
|
+
if item.tool_name == CodeExecutionTool.kind and isinstance(item.content, dict):
|
|
716
|
+
part['code_execution_result'] = cast(CodeExecutionResultDict, item.content) # pyright: ignore[reportUnknownMemberType]
|
|
717
|
+
elif item.tool_name == WebSearchTool.kind:
|
|
718
|
+
# Web search results are not sent back
|
|
719
|
+
pass
|
|
672
720
|
else:
|
|
673
721
|
assert_never(item)
|
|
674
722
|
|
|
@@ -679,6 +727,7 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
|
|
|
679
727
|
|
|
680
728
|
def _process_response_from_parts(
|
|
681
729
|
parts: list[Part],
|
|
730
|
+
grounding_metadata: GroundingMetadata | None,
|
|
682
731
|
model_name: GoogleModelName,
|
|
683
732
|
provider_name: str,
|
|
684
733
|
usage: usage.RequestUsage,
|
|
@@ -687,7 +736,17 @@ def _process_response_from_parts(
|
|
|
687
736
|
finish_reason: FinishReason | None = None,
|
|
688
737
|
) -> ModelResponse:
|
|
689
738
|
items: list[ModelResponsePart] = []
|
|
739
|
+
|
|
740
|
+
# We don't currently turn `candidate.url_context_metadata` into BuiltinToolCallPart and BuiltinToolReturnPart for UrlContextTool.
|
|
741
|
+
# Please file an issue if you need this.
|
|
742
|
+
|
|
743
|
+
web_search_call, web_search_return = _map_grounding_metadata(grounding_metadata, provider_name)
|
|
744
|
+
if web_search_call and web_search_return:
|
|
745
|
+
items.append(web_search_call)
|
|
746
|
+
items.append(web_search_return)
|
|
747
|
+
|
|
690
748
|
item: ModelResponsePart | None = None
|
|
749
|
+
code_execution_tool_call_id: str | None = None
|
|
691
750
|
for part in parts:
|
|
692
751
|
if part.thought_signature:
|
|
693
752
|
signature = base64.b64encode(part.thought_signature).decode('utf-8')
|
|
@@ -698,16 +757,11 @@ def _process_response_from_parts(
|
|
|
698
757
|
item.provider_name = provider_name
|
|
699
758
|
|
|
700
759
|
if part.executable_code is not None:
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
)
|
|
760
|
+
code_execution_tool_call_id = _utils.generate_tool_call_id()
|
|
761
|
+
item = _map_executable_code(part.executable_code, provider_name, code_execution_tool_call_id)
|
|
704
762
|
elif part.code_execution_result is not None:
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
tool_name='code_execution',
|
|
708
|
-
content=part.code_execution_result,
|
|
709
|
-
tool_call_id='not_provided',
|
|
710
|
-
)
|
|
763
|
+
assert code_execution_tool_call_id is not None
|
|
764
|
+
item = _map_code_execution_result(part.code_execution_result, provider_name, code_execution_tool_call_id)
|
|
711
765
|
elif part.text is not None:
|
|
712
766
|
if part.thought:
|
|
713
767
|
item = ThinkingPart(content=part.text)
|
|
@@ -778,7 +832,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
|
|
|
778
832
|
if not metadata_details:
|
|
779
833
|
continue
|
|
780
834
|
for detail in metadata_details:
|
|
781
|
-
if not detail.modality or not detail.token_count:
|
|
835
|
+
if not detail.modality or not detail.token_count:
|
|
782
836
|
continue
|
|
783
837
|
details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
|
|
784
838
|
if detail.modality != 'AUDIO':
|
|
@@ -799,3 +853,48 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
|
|
|
799
853
|
cache_audio_read_tokens=cache_audio_read_tokens,
|
|
800
854
|
details=details,
|
|
801
855
|
)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def _map_executable_code(executable_code: ExecutableCode, provider_name: str, tool_call_id: str) -> BuiltinToolCallPart:
|
|
859
|
+
return BuiltinToolCallPart(
|
|
860
|
+
provider_name=provider_name,
|
|
861
|
+
tool_name=CodeExecutionTool.kind,
|
|
862
|
+
args=executable_code.model_dump(mode='json'),
|
|
863
|
+
tool_call_id=tool_call_id,
|
|
864
|
+
)
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
def _map_code_execution_result(
|
|
868
|
+
code_execution_result: CodeExecutionResult, provider_name: str, tool_call_id: str
|
|
869
|
+
) -> BuiltinToolReturnPart:
|
|
870
|
+
return BuiltinToolReturnPart(
|
|
871
|
+
provider_name=provider_name,
|
|
872
|
+
tool_name=CodeExecutionTool.kind,
|
|
873
|
+
content=code_execution_result.model_dump(mode='json'),
|
|
874
|
+
tool_call_id=tool_call_id,
|
|
875
|
+
)
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
def _map_grounding_metadata(
|
|
879
|
+
grounding_metadata: GroundingMetadata | None, provider_name: str
|
|
880
|
+
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart] | tuple[None, None]:
|
|
881
|
+
if grounding_metadata and (web_search_queries := grounding_metadata.web_search_queries):
|
|
882
|
+
tool_call_id = _utils.generate_tool_call_id()
|
|
883
|
+
return (
|
|
884
|
+
BuiltinToolCallPart(
|
|
885
|
+
provider_name=provider_name,
|
|
886
|
+
tool_name=WebSearchTool.kind,
|
|
887
|
+
tool_call_id=tool_call_id,
|
|
888
|
+
args={'queries': web_search_queries},
|
|
889
|
+
),
|
|
890
|
+
BuiltinToolReturnPart(
|
|
891
|
+
provider_name=provider_name,
|
|
892
|
+
tool_name=WebSearchTool.kind,
|
|
893
|
+
tool_call_id=tool_call_id,
|
|
894
|
+
content=[chunk.web.model_dump(mode='json') for chunk in grounding_chunks if chunk.web]
|
|
895
|
+
if (grounding_chunks := grounding_metadata.grounding_chunks)
|
|
896
|
+
else None,
|
|
897
|
+
),
|
|
898
|
+
)
|
|
899
|
+
else:
|
|
900
|
+
return None, None
|