pydantic-ai-slim 1.0.8__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +20 -14
- pydantic_ai/_cli.py +1 -1
- pydantic_ai/_otel_messages.py +2 -0
- pydantic_ai/_parts_manager.py +82 -12
- pydantic_ai/_run_context.py +8 -1
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/ag_ui.py +86 -33
- pydantic_ai/builtin_tools.py +12 -0
- pydantic_ai/durable_exec/temporal/_model.py +14 -6
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/messages.py +69 -30
- pydantic_ai/models/anthropic.py +119 -45
- pydantic_ai/models/function.py +17 -8
- pydantic_ai/models/google.py +105 -16
- pydantic_ai/models/groq.py +68 -17
- pydantic_ai/models/openai.py +262 -41
- pydantic_ai/providers/__init__.py +1 -1
- pydantic_ai/result.py +21 -3
- pydantic_ai/toolsets/function.py +8 -2
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.9.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.9.dist-info}/RECORD +24 -24
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.9.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.9.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.8.dist-info → pydantic_ai_slim-1.0.9.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py
CHANGED
|
@@ -51,9 +51,12 @@ from . import (
|
|
|
51
51
|
try:
|
|
52
52
|
from google.genai import Client
|
|
53
53
|
from google.genai.types import (
|
|
54
|
+
CodeExecutionResult,
|
|
55
|
+
CodeExecutionResultDict,
|
|
54
56
|
ContentDict,
|
|
55
57
|
ContentUnionDict,
|
|
56
58
|
CountTokensConfigDict,
|
|
59
|
+
ExecutableCode,
|
|
57
60
|
ExecutableCodeDict,
|
|
58
61
|
FinishReason as GoogleFinishReason,
|
|
59
62
|
FunctionCallDict,
|
|
@@ -64,6 +67,7 @@ try:
|
|
|
64
67
|
GenerateContentResponse,
|
|
65
68
|
GenerationConfigDict,
|
|
66
69
|
GoogleSearchDict,
|
|
70
|
+
GroundingMetadata,
|
|
67
71
|
HttpOptionsDict,
|
|
68
72
|
MediaResolution,
|
|
69
73
|
Part,
|
|
@@ -434,6 +438,7 @@ class GoogleModel(Model):
|
|
|
434
438
|
usage = _metadata_as_usage(response)
|
|
435
439
|
return _process_response_from_parts(
|
|
436
440
|
parts,
|
|
441
|
+
candidate.grounding_metadata,
|
|
437
442
|
response.model_version or self._model_name,
|
|
438
443
|
self._provider.name,
|
|
439
444
|
usage,
|
|
@@ -569,6 +574,7 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
569
574
|
_provider_name: str
|
|
570
575
|
|
|
571
576
|
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
577
|
+
code_execution_tool_call_id: str | None = None
|
|
572
578
|
async for chunk in self._response:
|
|
573
579
|
self._usage = _metadata_as_usage(chunk)
|
|
574
580
|
|
|
@@ -582,6 +588,19 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
582
588
|
self.provider_details = {'finish_reason': raw_finish_reason.value}
|
|
583
589
|
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
584
590
|
|
|
591
|
+
# Google streams the grounding metadata (including the web search queries and results)
|
|
592
|
+
# _after_ the text that was generated using it, so it would show up out of order in the stream,
|
|
593
|
+
# and cause issues with the logic that doesn't consider text ahead of built-in tool calls as output.
|
|
594
|
+
# If that gets fixed (or we have a workaround), we can uncomment this:
|
|
595
|
+
# web_search_call, web_search_return = _map_grounding_metadata(
|
|
596
|
+
# candidate.grounding_metadata, self.provider_name
|
|
597
|
+
# )
|
|
598
|
+
# if web_search_call and web_search_return:
|
|
599
|
+
# yield self._parts_manager.handle_builtin_tool_call_part(vendor_part_id=uuid4(), part=web_search_call)
|
|
600
|
+
# yield self._parts_manager.handle_builtin_tool_return_part(
|
|
601
|
+
# vendor_part_id=uuid4(), part=web_search_return
|
|
602
|
+
# )
|
|
603
|
+
|
|
585
604
|
if candidate.content is None or candidate.content.parts is None:
|
|
586
605
|
if candidate.finish_reason == 'STOP': # pragma: no cover
|
|
587
606
|
# Normal completion - skip this chunk
|
|
@@ -590,6 +609,7 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
590
609
|
raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
|
|
591
610
|
else: # pragma: no cover
|
|
592
611
|
raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
|
|
612
|
+
|
|
593
613
|
parts = candidate.content.parts or []
|
|
594
614
|
for part in parts:
|
|
595
615
|
if part.thought_signature:
|
|
@@ -617,9 +637,21 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
617
637
|
if maybe_event is not None: # pragma: no branch
|
|
618
638
|
yield maybe_event
|
|
619
639
|
elif part.executable_code is not None:
|
|
620
|
-
|
|
640
|
+
code_execution_tool_call_id = _utils.generate_tool_call_id()
|
|
641
|
+
yield self._parts_manager.handle_builtin_tool_call_part(
|
|
642
|
+
vendor_part_id=uuid4(),
|
|
643
|
+
part=_map_executable_code(
|
|
644
|
+
part.executable_code, self.provider_name, code_execution_tool_call_id
|
|
645
|
+
),
|
|
646
|
+
)
|
|
621
647
|
elif part.code_execution_result is not None:
|
|
622
|
-
|
|
648
|
+
assert code_execution_tool_call_id is not None
|
|
649
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
650
|
+
vendor_part_id=uuid4(),
|
|
651
|
+
part=_map_code_execution_result(
|
|
652
|
+
part.code_execution_result, self.provider_name, code_execution_tool_call_id
|
|
653
|
+
),
|
|
654
|
+
)
|
|
623
655
|
else:
|
|
624
656
|
assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
|
|
625
657
|
|
|
@@ -639,7 +671,7 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
639
671
|
return self._timestamp
|
|
640
672
|
|
|
641
673
|
|
|
642
|
-
def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
|
|
674
|
+
def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict: # noqa: C901
|
|
643
675
|
parts: list[PartDict] = []
|
|
644
676
|
thought_signature: bytes | None = None
|
|
645
677
|
for item in m.parts:
|
|
@@ -663,12 +695,18 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
|
|
|
663
695
|
part['thought'] = True
|
|
664
696
|
elif isinstance(item, BuiltinToolCallPart):
|
|
665
697
|
if item.provider_name == provider_name:
|
|
666
|
-
if item.tool_name ==
|
|
667
|
-
part['executable_code'] = cast(ExecutableCodeDict, item.
|
|
698
|
+
if item.tool_name == CodeExecutionTool.kind:
|
|
699
|
+
part['executable_code'] = cast(ExecutableCodeDict, item.args_as_dict())
|
|
700
|
+
elif item.tool_name == WebSearchTool.kind:
|
|
701
|
+
# Web search calls are not sent back
|
|
702
|
+
pass
|
|
668
703
|
elif isinstance(item, BuiltinToolReturnPart):
|
|
669
704
|
if item.provider_name == provider_name:
|
|
670
|
-
if item.tool_name ==
|
|
671
|
-
part['code_execution_result'] = item.content
|
|
705
|
+
if item.tool_name == CodeExecutionTool.kind and isinstance(item.content, dict):
|
|
706
|
+
part['code_execution_result'] = cast(CodeExecutionResultDict, item.content) # pyright: ignore[reportUnknownMemberType]
|
|
707
|
+
elif item.tool_name == WebSearchTool.kind:
|
|
708
|
+
# Web search results are not sent back
|
|
709
|
+
pass
|
|
672
710
|
else:
|
|
673
711
|
assert_never(item)
|
|
674
712
|
|
|
@@ -679,6 +717,7 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
|
|
|
679
717
|
|
|
680
718
|
def _process_response_from_parts(
|
|
681
719
|
parts: list[Part],
|
|
720
|
+
grounding_metadata: GroundingMetadata | None,
|
|
682
721
|
model_name: GoogleModelName,
|
|
683
722
|
provider_name: str,
|
|
684
723
|
usage: usage.RequestUsage,
|
|
@@ -687,7 +726,17 @@ def _process_response_from_parts(
|
|
|
687
726
|
finish_reason: FinishReason | None = None,
|
|
688
727
|
) -> ModelResponse:
|
|
689
728
|
items: list[ModelResponsePart] = []
|
|
729
|
+
|
|
730
|
+
# We don't currently turn `candidate.url_context_metadata` into BuiltinToolCallPart and BuiltinToolReturnPart for UrlContextTool.
|
|
731
|
+
# Please file an issue if you need this.
|
|
732
|
+
|
|
733
|
+
web_search_call, web_search_return = _map_grounding_metadata(grounding_metadata, provider_name)
|
|
734
|
+
if web_search_call and web_search_return:
|
|
735
|
+
items.append(web_search_call)
|
|
736
|
+
items.append(web_search_return)
|
|
737
|
+
|
|
690
738
|
item: ModelResponsePart | None = None
|
|
739
|
+
code_execution_tool_call_id: str | None = None
|
|
691
740
|
for part in parts:
|
|
692
741
|
if part.thought_signature:
|
|
693
742
|
signature = base64.b64encode(part.thought_signature).decode('utf-8')
|
|
@@ -698,16 +747,11 @@ def _process_response_from_parts(
|
|
|
698
747
|
item.provider_name = provider_name
|
|
699
748
|
|
|
700
749
|
if part.executable_code is not None:
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
)
|
|
750
|
+
code_execution_tool_call_id = _utils.generate_tool_call_id()
|
|
751
|
+
item = _map_executable_code(part.executable_code, provider_name, code_execution_tool_call_id)
|
|
704
752
|
elif part.code_execution_result is not None:
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
tool_name='code_execution',
|
|
708
|
-
content=part.code_execution_result,
|
|
709
|
-
tool_call_id='not_provided',
|
|
710
|
-
)
|
|
753
|
+
assert code_execution_tool_call_id is not None
|
|
754
|
+
item = _map_code_execution_result(part.code_execution_result, provider_name, code_execution_tool_call_id)
|
|
711
755
|
elif part.text is not None:
|
|
712
756
|
if part.thought:
|
|
713
757
|
item = ThinkingPart(content=part.text)
|
|
@@ -799,3 +843,48 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
|
|
|
799
843
|
cache_audio_read_tokens=cache_audio_read_tokens,
|
|
800
844
|
details=details,
|
|
801
845
|
)
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def _map_executable_code(executable_code: ExecutableCode, provider_name: str, tool_call_id: str) -> BuiltinToolCallPart:
|
|
849
|
+
return BuiltinToolCallPart(
|
|
850
|
+
provider_name=provider_name,
|
|
851
|
+
tool_name=CodeExecutionTool.kind,
|
|
852
|
+
args=executable_code.model_dump(mode='json'),
|
|
853
|
+
tool_call_id=tool_call_id,
|
|
854
|
+
)
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
def _map_code_execution_result(
|
|
858
|
+
code_execution_result: CodeExecutionResult, provider_name: str, tool_call_id: str
|
|
859
|
+
) -> BuiltinToolReturnPart:
|
|
860
|
+
return BuiltinToolReturnPart(
|
|
861
|
+
provider_name=provider_name,
|
|
862
|
+
tool_name=CodeExecutionTool.kind,
|
|
863
|
+
content=code_execution_result.model_dump(mode='json'),
|
|
864
|
+
tool_call_id=tool_call_id,
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def _map_grounding_metadata(
|
|
869
|
+
grounding_metadata: GroundingMetadata | None, provider_name: str
|
|
870
|
+
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart] | tuple[None, None]:
|
|
871
|
+
if grounding_metadata and (web_search_queries := grounding_metadata.web_search_queries):
|
|
872
|
+
tool_call_id = _utils.generate_tool_call_id()
|
|
873
|
+
return (
|
|
874
|
+
BuiltinToolCallPart(
|
|
875
|
+
provider_name=provider_name,
|
|
876
|
+
tool_name=WebSearchTool.kind,
|
|
877
|
+
tool_call_id=tool_call_id,
|
|
878
|
+
args={'queries': web_search_queries},
|
|
879
|
+
),
|
|
880
|
+
BuiltinToolReturnPart(
|
|
881
|
+
provider_name=provider_name,
|
|
882
|
+
tool_name=WebSearchTool.kind,
|
|
883
|
+
tool_call_id=tool_call_id,
|
|
884
|
+
content=[chunk.web.model_dump(mode='json') for chunk in grounding_chunks if chunk.web]
|
|
885
|
+
if (grounding_chunks := grounding_metadata.grounding_chunks)
|
|
886
|
+
else None,
|
|
887
|
+
),
|
|
888
|
+
)
|
|
889
|
+
else:
|
|
890
|
+
return None, None
|
pydantic_ai/models/groq.py
CHANGED
|
@@ -8,11 +8,11 @@ from datetime import datetime
|
|
|
8
8
|
from typing import Any, Literal, cast, overload
|
|
9
9
|
|
|
10
10
|
from pydantic import BaseModel, Json, ValidationError
|
|
11
|
+
from pydantic_core import from_json
|
|
11
12
|
from typing_extensions import assert_never
|
|
12
13
|
|
|
13
|
-
from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
|
|
14
|
-
|
|
15
14
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
15
|
+
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
|
|
16
16
|
from .._run_context import RunContext
|
|
17
17
|
from .._thinking_part import split_content_into_text_and_thinking
|
|
18
18
|
from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
|
|
@@ -55,6 +55,7 @@ try:
|
|
|
55
55
|
from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
|
|
56
56
|
from groq.types import chat
|
|
57
57
|
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
|
|
58
|
+
from groq.types.chat.chat_completion_message import ExecutedTool
|
|
58
59
|
except ImportError as _import_error:
|
|
59
60
|
raise ImportError(
|
|
60
61
|
'Please install `groq` to use the Groq model, '
|
|
@@ -308,22 +309,15 @@ class GroqModel(Model):
|
|
|
308
309
|
timestamp = number_to_datetime(response.created)
|
|
309
310
|
choice = response.choices[0]
|
|
310
311
|
items: list[ModelResponsePart] = []
|
|
311
|
-
if choice.message.executed_tools:
|
|
312
|
-
for tool in choice.message.executed_tools:
|
|
313
|
-
tool_call_id = generate_tool_call_id()
|
|
314
|
-
items.append(
|
|
315
|
-
BuiltinToolCallPart(
|
|
316
|
-
tool_name=tool.type, args=tool.arguments, provider_name=self.system, tool_call_id=tool_call_id
|
|
317
|
-
)
|
|
318
|
-
)
|
|
319
|
-
items.append(
|
|
320
|
-
BuiltinToolReturnPart(
|
|
321
|
-
provider_name=self.system, tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
|
|
322
|
-
)
|
|
323
|
-
)
|
|
324
312
|
if choice.message.reasoning is not None:
|
|
325
313
|
# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
|
|
326
314
|
items.append(ThinkingPart(content=choice.message.reasoning))
|
|
315
|
+
if choice.message.executed_tools:
|
|
316
|
+
for tool in choice.message.executed_tools:
|
|
317
|
+
call_part, return_part = _map_executed_tool(tool, self.system)
|
|
318
|
+
if call_part and return_part: # pragma: no branch
|
|
319
|
+
items.append(call_part)
|
|
320
|
+
items.append(return_part)
|
|
327
321
|
if choice.message.content is not None:
|
|
328
322
|
# NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
|
|
329
323
|
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
|
|
@@ -400,7 +394,7 @@ class GroqModel(Model):
|
|
|
400
394
|
start_tag, end_tag = self.profile.thinking_tags
|
|
401
395
|
texts.append('\n'.join([start_tag, item.content, end_tag]))
|
|
402
396
|
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
403
|
-
#
|
|
397
|
+
# These are not currently sent back
|
|
404
398
|
pass
|
|
405
399
|
else:
|
|
406
400
|
assert_never(item)
|
|
@@ -513,8 +507,9 @@ class GroqStreamedResponse(StreamedResponse):
|
|
|
513
507
|
_timestamp: datetime
|
|
514
508
|
_provider_name: str
|
|
515
509
|
|
|
516
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
510
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
|
|
517
511
|
try:
|
|
512
|
+
executed_tool_call_id: str | None = None
|
|
518
513
|
async for chunk in self._response:
|
|
519
514
|
self._usage += _map_usage(chunk)
|
|
520
515
|
|
|
@@ -530,6 +525,28 @@ class GroqStreamedResponse(StreamedResponse):
|
|
|
530
525
|
self.provider_details = {'finish_reason': raw_finish_reason}
|
|
531
526
|
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
|
|
532
527
|
|
|
528
|
+
if choice.delta.reasoning is not None:
|
|
529
|
+
# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
|
|
530
|
+
yield self._parts_manager.handle_thinking_delta(
|
|
531
|
+
vendor_part_id='reasoning', content=choice.delta.reasoning
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
if choice.delta.executed_tools:
|
|
535
|
+
for tool in choice.delta.executed_tools:
|
|
536
|
+
call_part, return_part = _map_executed_tool(
|
|
537
|
+
tool, self.provider_name, streaming=True, tool_call_id=executed_tool_call_id
|
|
538
|
+
)
|
|
539
|
+
if call_part:
|
|
540
|
+
executed_tool_call_id = call_part.tool_call_id
|
|
541
|
+
yield self._parts_manager.handle_builtin_tool_call_part(
|
|
542
|
+
vendor_part_id=f'executed_tools-{tool.index}-call', part=call_part
|
|
543
|
+
)
|
|
544
|
+
if return_part:
|
|
545
|
+
executed_tool_call_id = None
|
|
546
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
547
|
+
vendor_part_id=f'executed_tools-{tool.index}-return', part=return_part
|
|
548
|
+
)
|
|
549
|
+
|
|
533
550
|
# Handle the text part of the response
|
|
534
551
|
content = choice.delta.content
|
|
535
552
|
if content is not None:
|
|
@@ -626,3 +643,37 @@ class _GroqToolUseFailedError(BaseModel):
|
|
|
626
643
|
# }
|
|
627
644
|
|
|
628
645
|
error: _GroqToolUseFailedInnerError
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def _map_executed_tool(
|
|
649
|
+
tool: ExecutedTool, provider_name: str, streaming: bool = False, tool_call_id: str | None = None
|
|
650
|
+
) -> tuple[BuiltinToolCallPart | None, BuiltinToolReturnPart | None]:
|
|
651
|
+
if tool.type == 'search':
|
|
652
|
+
if tool.search_results and (tool.search_results.images or tool.search_results.results):
|
|
653
|
+
results = tool.search_results.model_dump(mode='json')
|
|
654
|
+
else:
|
|
655
|
+
results = tool.output
|
|
656
|
+
|
|
657
|
+
tool_call_id = tool_call_id or generate_tool_call_id()
|
|
658
|
+
call_part = BuiltinToolCallPart(
|
|
659
|
+
tool_name=WebSearchTool.kind,
|
|
660
|
+
args=from_json(tool.arguments),
|
|
661
|
+
provider_name=provider_name,
|
|
662
|
+
tool_call_id=tool_call_id,
|
|
663
|
+
)
|
|
664
|
+
return_part = BuiltinToolReturnPart(
|
|
665
|
+
tool_name=WebSearchTool.kind,
|
|
666
|
+
content=results,
|
|
667
|
+
provider_name=provider_name,
|
|
668
|
+
tool_call_id=tool_call_id,
|
|
669
|
+
)
|
|
670
|
+
|
|
671
|
+
if streaming:
|
|
672
|
+
if results:
|
|
673
|
+
return None, return_part
|
|
674
|
+
else:
|
|
675
|
+
return call_part, None
|
|
676
|
+
else:
|
|
677
|
+
return call_part, return_part
|
|
678
|
+
else: # pragma: no cover
|
|
679
|
+
return None, None
|