pydantic-ai-slim 1.0.8__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -51,9 +51,12 @@ from . import (
 try:
     from google.genai import Client
     from google.genai.types import (
+        CodeExecutionResult,
+        CodeExecutionResultDict,
         ContentDict,
         ContentUnionDict,
         CountTokensConfigDict,
+        ExecutableCode,
         ExecutableCodeDict,
         FinishReason as GoogleFinishReason,
         FunctionCallDict,
@@ -64,6 +67,7 @@ try:
         GenerateContentResponse,
         GenerationConfigDict,
         GoogleSearchDict,
+        GroundingMetadata,
         HttpOptionsDict,
         MediaResolution,
         Part,
@@ -434,6 +438,7 @@ class GoogleModel(Model):
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,
+            candidate.grounding_metadata,
             response.model_version or self._model_name,
             self._provider.name,
             usage,
@@ -569,6 +574,7 @@ class GeminiStreamedResponse(StreamedResponse):
     _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
+        code_execution_tool_call_id: str | None = None
         async for chunk in self._response:
             self._usage = _metadata_as_usage(chunk)
 
@@ -582,6 +588,19 @@ class GeminiStreamedResponse(StreamedResponse):
                 self.provider_details = {'finish_reason': raw_finish_reason.value}
                 self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
+            # Google streams the grounding metadata (including the web search queries and results)
+            # _after_ the text that was generated using it, so it would show up out of order in the stream,
+            # and cause issues with the logic that doesn't consider text ahead of built-in tool calls as output.
+            # If that gets fixed (or we have a workaround), we can uncomment this:
+            # web_search_call, web_search_return = _map_grounding_metadata(
+            #     candidate.grounding_metadata, self.provider_name
+            # )
+            # if web_search_call and web_search_return:
+            #     yield self._parts_manager.handle_builtin_tool_call_part(vendor_part_id=uuid4(), part=web_search_call)
+            #     yield self._parts_manager.handle_builtin_tool_return_part(
+            #         vendor_part_id=uuid4(), part=web_search_return
+            #     )
+
             if candidate.content is None or candidate.content.parts is None:
                 if candidate.finish_reason == 'STOP':  # pragma: no cover
                     # Normal completion - skip this chunk
@@ -590,6 +609,7 @@ class GeminiStreamedResponse(StreamedResponse):
                     raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
                 else:  # pragma: no cover
                     raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+
             parts = candidate.content.parts or []
             for part in parts:
                 if part.thought_signature:
@@ -617,9 +637,21 @@ class GeminiStreamedResponse(StreamedResponse):
                     if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 elif part.executable_code is not None:
-                    pass
+                    code_execution_tool_call_id = _utils.generate_tool_call_id()
+                    yield self._parts_manager.handle_builtin_tool_call_part(
+                        vendor_part_id=uuid4(),
+                        part=_map_executable_code(
+                            part.executable_code, self.provider_name, code_execution_tool_call_id
+                        ),
+                    )
                 elif part.code_execution_result is not None:
-                    pass
+                    assert code_execution_tool_call_id is not None
+                    yield self._parts_manager.handle_builtin_tool_return_part(
+                        vendor_part_id=uuid4(),
+                        part=_map_code_execution_result(
+                            part.code_execution_result, self.provider_name, code_execution_tool_call_id
+                        ),
+                    )
                 else:
                     assert part.function_response is not None, f'Unexpected part: {part}'  # pragma: no cover
 
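Note: after this hunk, a streamed Gemini code-execution round surfaces as a BuiltinToolCallPart followed by a BuiltinToolReturnPart that share one generated tool_call_id. A minimal sketch of the resulting pair; the dict contents and the 'google-gla' provider name are illustrative assumptions, not captured output:

from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart

tool_call_id = 'pyd_ai_example'  # the real code generates this via _utils.generate_tool_call_id()
call = BuiltinToolCallPart(
    tool_name='code_execution',  # CodeExecutionTool.kind
    args={'language': 'PYTHON', 'code': 'print(1 + 1)'},  # taken from part.executable_code
    tool_call_id=tool_call_id,
    provider_name='google-gla',
)
result = BuiltinToolReturnPart(
    tool_name='code_execution',
    content={'outcome': 'OUTCOME_OK', 'output': '2\n'},  # taken from part.code_execution_result
    tool_call_id=tool_call_id,  # same id ties the return to its call
    provider_name='google-gla',
)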
@@ -639,7 +671,7 @@ class GeminiStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:
+def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict:  # noqa: C901
     parts: list[PartDict] = []
     thought_signature: bytes | None = None
     for item in m.parts:
@@ -663,12 +695,18 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
             part['thought'] = True
         elif isinstance(item, BuiltinToolCallPart):
             if item.provider_name == provider_name:
-                if item.tool_name == 'code_execution':  # pragma: no branch
-                    part['executable_code'] = cast(ExecutableCodeDict, item.args)
+                if item.tool_name == CodeExecutionTool.kind:
+                    part['executable_code'] = cast(ExecutableCodeDict, item.args_as_dict())
+                elif item.tool_name == WebSearchTool.kind:
+                    # Web search calls are not sent back
+                    pass
         elif isinstance(item, BuiltinToolReturnPart):
             if item.provider_name == provider_name:
-                if item.tool_name == 'code_execution':  # pragma: no branch
-                    part['code_execution_result'] = item.content
+                if item.tool_name == CodeExecutionTool.kind and isinstance(item.content, dict):
+                    part['code_execution_result'] = cast(CodeExecutionResultDict, item.content)  # pyright: ignore[reportUnknownMemberType]
+                elif item.tool_name == WebSearchTool.kind:
+                    # Web search results are not sent back
+                    pass
         else:
             assert_never(item)
 
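Note: this hunk is the request-side mapping, which folds a prior code-execution call/return pair back into the provider's part fields when message history is sent back to Google. A rough standalone sketch of that direction; the plain dict stands in for google.genai's typed PartDict, and the values are assumptions:

from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart

call = BuiltinToolCallPart(
    tool_name='code_execution',
    args='{"language": "PYTHON", "code": "print(1 + 1)"}',  # args may arrive as a JSON string
    tool_call_id='pyd_ai_example',
    provider_name='google-gla',
)
ret = BuiltinToolReturnPart(
    tool_name='code_execution',
    content={'outcome': 'OUTCOME_OK', 'output': '2\n'},
    tool_call_id='pyd_ai_example',
    provider_name='google-gla',
)

part: dict[str, object] = {}
part['executable_code'] = call.args_as_dict()  # normalizes dict or JSON-string args, hence the switch away from item.args
if isinstance(ret.content, dict):
    part['code_execution_result'] = ret.content
print(part)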
@@ -679,6 +717,7 @@ def _content_model_response(m: ModelResponse, provider_name: str) -> ContentDict
 
 def _process_response_from_parts(
     parts: list[Part],
+    grounding_metadata: GroundingMetadata | None,
     model_name: GoogleModelName,
     provider_name: str,
     usage: usage.RequestUsage,
@@ -687,7 +726,17 @@ def _process_response_from_parts(
     finish_reason: FinishReason | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
+
+    # We don't currently turn `candidate.url_context_metadata` into BuiltinToolCallPart and BuiltinToolReturnPart for UrlContextTool.
+    # Please file an issue if you need this.
+
+    web_search_call, web_search_return = _map_grounding_metadata(grounding_metadata, provider_name)
+    if web_search_call and web_search_return:
+        items.append(web_search_call)
+        items.append(web_search_return)
+
     item: ModelResponsePart | None = None
+    code_execution_tool_call_id: str | None = None
     for part in parts:
         if part.thought_signature:
             signature = base64.b64encode(part.thought_signature).decode('utf-8')
@@ -698,16 +747,11 @@ def _process_response_from_parts(
             item.provider_name = provider_name
 
         if part.executable_code is not None:
-            item = BuiltinToolCallPart(
-                provider_name=provider_name, args=part.executable_code.model_dump(), tool_name='code_execution'
-            )
+            code_execution_tool_call_id = _utils.generate_tool_call_id()
+            item = _map_executable_code(part.executable_code, provider_name, code_execution_tool_call_id)
         elif part.code_execution_result is not None:
-            item = BuiltinToolReturnPart(
-                provider_name=provider_name,
-                tool_name='code_execution',
-                content=part.code_execution_result,
-                tool_call_id='not_provided',
-            )
+            assert code_execution_tool_call_id is not None
+            item = _map_code_execution_result(part.code_execution_result, provider_name, code_execution_tool_call_id)
         elif part.text is not None:
             if part.thought:
                 item = ThinkingPart(content=part.text)
@@ -799,3 +843,48 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
         cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
+
+
+def _map_executable_code(executable_code: ExecutableCode, provider_name: str, tool_call_id: str) -> BuiltinToolCallPart:
+    return BuiltinToolCallPart(
+        provider_name=provider_name,
+        tool_name=CodeExecutionTool.kind,
+        args=executable_code.model_dump(mode='json'),
+        tool_call_id=tool_call_id,
+    )
+
+
+def _map_code_execution_result(
+    code_execution_result: CodeExecutionResult, provider_name: str, tool_call_id: str
+) -> BuiltinToolReturnPart:
+    return BuiltinToolReturnPart(
+        provider_name=provider_name,
+        tool_name=CodeExecutionTool.kind,
+        content=code_execution_result.model_dump(mode='json'),
+        tool_call_id=tool_call_id,
+    )
+
+
+def _map_grounding_metadata(
+    grounding_metadata: GroundingMetadata | None, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart] | tuple[None, None]:
+    if grounding_metadata and (web_search_queries := grounding_metadata.web_search_queries):
+        tool_call_id = _utils.generate_tool_call_id()
+        return (
+            BuiltinToolCallPart(
+                provider_name=provider_name,
+                tool_name=WebSearchTool.kind,
+                tool_call_id=tool_call_id,
+                args={'queries': web_search_queries},
+            ),
+            BuiltinToolReturnPart(
+                provider_name=provider_name,
+                tool_name=WebSearchTool.kind,
+                tool_call_id=tool_call_id,
+                content=[chunk.web.model_dump(mode='json') for chunk in grounding_chunks if chunk.web]
+                if (grounding_chunks := grounding_metadata.grounding_chunks)
+                else None,
+            ),
+        )
+    else:
+        return None, None
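Note: taken together, the hunks above (apparently the Google model module) mean grounded web searches now show up in results as built-in tool parts. A hedged usage sketch, assuming the model name shown and Google credentials in the environment; the shapes in the comments follow _map_grounding_metadata above:

from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart, ModelResponse

agent = Agent('google-gla:gemini-2.5-flash', builtin_tools=[WebSearchTool()])  # model name is an assumption
result = agent.run_sync('What was the top news story yesterday?')

for message in result.all_messages():
    if isinstance(message, ModelResponse):
        for part in message.parts:
            if isinstance(part, BuiltinToolCallPart):
                print('web search queries:', part.args)   # {'queries': [...]}
            elif isinstance(part, BuiltinToolReturnPart):
                print('grounding chunks:', part.content)  # list of web sources, or None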
@@ -8,11 +8,11 @@ from datetime import datetime
 from typing import Any, Literal, cast, overload
 
 from pydantic import BaseModel, Json, ValidationError
+from pydantic_core import from_json
 from typing_extensions import assert_never
 
-from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
@@ -55,6 +55,7 @@ try:
     from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
+    from groq.types.chat.chat_completion_message import ExecutedTool
 except ImportError as _import_error:
     raise ImportError(
         'Please install `groq` to use the Groq model, '
@@ -308,22 +309,15 @@ class GroqModel(Model):
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
-        if choice.message.executed_tools:
-            for tool in choice.message.executed_tools:
-                tool_call_id = generate_tool_call_id()
-                items.append(
-                    BuiltinToolCallPart(
-                        tool_name=tool.type, args=tool.arguments, provider_name=self.system, tool_call_id=tool_call_id
-                    )
-                )
-                items.append(
-                    BuiltinToolReturnPart(
-                        provider_name=self.system, tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
-                    )
-                )
         if choice.message.reasoning is not None:
             # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
             items.append(ThinkingPart(content=choice.message.reasoning))
+        if choice.message.executed_tools:
+            for tool in choice.message.executed_tools:
+                call_part, return_part = _map_executed_tool(tool, self.system)
+                if call_part and return_part:  # pragma: no branch
+                    items.append(call_part)
+                    items.append(return_part)
         if choice.message.content is not None:
             # NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
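Note: a hedged usage sketch for the Groq side, assuming a Groq model that executes tools server-side and populates executed_tools (the model name below is an assumption) and a GROQ_API_KEY in the environment:

from pydantic_ai import Agent
from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart, ModelResponse

agent = Agent('groq:compound-beta')  # assumed model name with server-side tool execution
result = agent.run_sync('Search the web for the latest pydantic-ai release and summarize it.')

response = result.all_messages()[-1]
assert isinstance(response, ModelResponse)
for part in response.parts:
    if isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):
        print(type(part).__name__, part.tool_name, part.tool_call_id)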
@@ -400,7 +394,7 @@ class GroqModel(Model):
                 start_tag, end_tag = self.profile.thinking_tags
                 texts.append('\n'.join([start_tag, item.content, end_tag]))
             elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
-                # This is currently never returned from groq
+                # These are not currently sent back
                 pass
             else:
                 assert_never(item)
@@ -513,8 +507,9 @@ class GroqStreamedResponse(StreamedResponse):
     _timestamp: datetime
     _provider_name: str
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         try:
+            executed_tool_call_id: str | None = None
             async for chunk in self._response:
                 self._usage += _map_usage(chunk)
 
@@ -530,6 +525,28 @@ class GroqStreamedResponse(StreamedResponse):
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
+                if choice.delta.reasoning is not None:
+                    # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
+                    yield self._parts_manager.handle_thinking_delta(
+                        vendor_part_id='reasoning', content=choice.delta.reasoning
+                    )
+
+                if choice.delta.executed_tools:
+                    for tool in choice.delta.executed_tools:
+                        call_part, return_part = _map_executed_tool(
+                            tool, self.provider_name, streaming=True, tool_call_id=executed_tool_call_id
+                        )
+                        if call_part:
+                            executed_tool_call_id = call_part.tool_call_id
+                            yield self._parts_manager.handle_builtin_tool_call_part(
+                                vendor_part_id=f'executed_tools-{tool.index}-call', part=call_part
+                            )
+                        if return_part:
+                            executed_tool_call_id = None
+                            yield self._parts_manager.handle_builtin_tool_return_part(
+                                vendor_part_id=f'executed_tools-{tool.index}-return', part=return_part
+                            )
+
                 # Handle the text part of the response
                 content = choice.delta.content
                 if content is not None:
@@ -626,3 +643,37 @@ class _GroqToolUseFailedError(BaseModel):
     # }
 
     error: _GroqToolUseFailedInnerError
+
+
+def _map_executed_tool(
+    tool: ExecutedTool, provider_name: str, streaming: bool = False, tool_call_id: str | None = None
+) -> tuple[BuiltinToolCallPart | None, BuiltinToolReturnPart | None]:
+    if tool.type == 'search':
+        if tool.search_results and (tool.search_results.images or tool.search_results.results):
+            results = tool.search_results.model_dump(mode='json')
+        else:
+            results = tool.output
+
+        tool_call_id = tool_call_id or generate_tool_call_id()
+        call_part = BuiltinToolCallPart(
+            tool_name=WebSearchTool.kind,
+            args=from_json(tool.arguments),
+            provider_name=provider_name,
+            tool_call_id=tool_call_id,
+        )
+        return_part = BuiltinToolReturnPart(
+            tool_name=WebSearchTool.kind,
+            content=results,
+            provider_name=provider_name,
+            tool_call_id=tool_call_id,
+        )
+
+        if streaming:
+            if results:
+                return None, return_part
+            else:
+                return call_part, None
+        else:
+            return call_part, return_part
+    else:  # pragma: no cover
+        return None, None
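Note: as a reading aid, this is roughly the pair _map_executed_tool produces for a completed Groq 'search' executed tool. The argument and result values are illustrative stand-ins, not real Groq output:

from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart

tool_call_id = 'pyd_ai_example'  # generated once and shared by both parts
call_part = BuiltinToolCallPart(
    tool_name='web_search',  # WebSearchTool.kind
    args={'query': 'latest pydantic-ai release'},  # parsed from tool.arguments via from_json()
    provider_name='groq',
    tool_call_id=tool_call_id,
)
return_part = BuiltinToolReturnPart(
    tool_name='web_search',
    content={'results': [{'title': 'Release v1.0.9', 'url': 'https://example.com'}]},  # illustrative
    provider_name='groq',
    tool_call_id=tool_call_id,
)

# Non-streaming responses get both parts at once.
# When streaming, the call part is emitted as soon as the arguments arrive (no results yet);
# the return part follows once results exist, reusing the remembered tool_call_id.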