pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -7,16 +7,17 @@ specific LLM being used.
7
7
  from __future__ import annotations as _annotations
8
8
 
9
9
  import base64
10
+ import warnings
10
11
  from abc import ABC, abstractmethod
11
12
  from collections.abc import AsyncIterator, Iterator
12
13
  from contextlib import asynccontextmanager, contextmanager
13
14
  from dataclasses import dataclass, field, replace
14
15
  from datetime import datetime
15
16
  from functools import cache, cached_property
16
- from typing import Any, Generic, TypeVar, overload
17
+ from typing import Any, Generic, Literal, TypeVar, overload
17
18
 
18
19
  import httpx
19
- from typing_extensions import Literal, TypeAliasType, TypedDict
20
+ from typing_extensions import TypeAliasType, TypedDict
20
21
 
21
22
  from .. import _utils
22
23
  from .._output import OutputObjectDefinition
@@ -366,7 +367,7 @@ KnownModelName = TypeAliasType(
366
367
  """
367
368
 
368
369
 
369
- @dataclass(repr=False)
370
+ @dataclass(repr=False, kw_only=True)
370
371
  class ModelRequestParameters:
371
372
  """Configuration for an agent's request to a model, specifically related to tools and output handling."""
372
373
 
@@ -551,6 +552,7 @@ class StreamedResponse(ABC):
551
552
  """Streamed response from an LLM when calling a tool."""
552
553
 
553
554
  model_request_parameters: ModelRequestParameters
555
+
554
556
  final_result_event: FinalResultEvent | None = field(default=None, init=False)
555
557
 
556
558
  _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
@@ -684,19 +686,29 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
684
686
  try:
685
687
  provider, model_name = model.split(':', maxsplit=1)
686
688
  except ValueError:
689
+ provider = None
687
690
  model_name = model
688
- # TODO(Marcelo): We should deprecate this way.
689
691
  if model_name.startswith(('gpt', 'o1', 'o3')):
690
692
  provider = 'openai'
691
693
  elif model_name.startswith('claude'):
692
694
  provider = 'anthropic'
693
695
  elif model_name.startswith('gemini'):
694
696
  provider = 'google-gla'
697
+
698
+ if provider is not None:
699
+ warnings.warn(
700
+ f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider}:{model_name}'.",
701
+ DeprecationWarning,
702
+ )
695
703
  else:
696
704
  raise UserError(f'Unknown model: {model}')
697
705
 
698
- if provider == 'vertexai':
699
- provider = 'google-vertex' # pragma: no cover
706
+ if provider == 'vertexai': # pragma: no cover
707
+ warnings.warn(
708
+ "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead.",
709
+ DeprecationWarning,
710
+ )
711
+ provider = 'google-vertex'
700
712
 
701
713
  if provider == 'cohere':
702
714
  from .cohere import CohereModel
@@ -716,6 +728,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
716
728
  'openrouter',
717
729
  'together',
718
730
  'vercel',
731
+ 'litellm',
719
732
  ):
720
733
  from .openai import OpenAIChatModel
721
734
 
@@ -909,5 +922,5 @@ def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestPar
909
922
  elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
910
923
  if tool_def.kind == 'output':
911
924
  return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
912
- elif tool_def.kind == 'deferred':
925
+ elif tool_def.defer:
913
926
  return FinalResultEvent(tool_name=None, tool_call_id=None)
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
6
6
  from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from datetime import datetime, timezone
9
- from typing import Any, Literal, Union, cast, overload
9
+ from typing import Any, Literal, cast, overload
10
10
 
11
11
  from typing_extensions import assert_never
12
12
 
@@ -99,7 +99,7 @@ except ImportError as _import_error:
99
99
  LatestAnthropicModelNames = ModelParam
100
100
  """Latest Anthropic models."""
101
101
 
102
- AnthropicModelName = Union[str, LatestAnthropicModelNames]
102
+ AnthropicModelName = str | LatestAnthropicModelNames
103
103
  """Possible Anthropic model names.
104
104
 
105
105
  Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
@@ -290,7 +290,7 @@ class AnthropicModel(Model):
290
290
  for item in response.content:
291
291
  if isinstance(item, BetaTextBlock):
292
292
  items.append(TextPart(content=item.text))
293
- elif isinstance(item, (BetaWebSearchToolResultBlock, BetaCodeExecutionToolResultBlock)):
293
+ elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
294
294
  items.append(
295
295
  BuiltinToolReturnPart(
296
296
  provider_name='anthropic',
@@ -327,10 +327,10 @@ class AnthropicModel(Model):
327
327
  )
328
328
 
329
329
  return ModelResponse(
330
- items,
330
+ parts=items,
331
331
  usage=_map_usage(response),
332
332
  model_name=response.model,
333
- provider_request_id=response.id,
333
+ provider_response_id=response.id,
334
334
  provider_name=self._provider.name,
335
335
  )
336
336
 
@@ -536,7 +536,7 @@ class AnthropicModel(Model):
536
536
  }
537
537
 
538
538
 
539
- def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
539
+ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
540
540
  if isinstance(message, BetaMessage):
541
541
  response_usage = message.usage
542
542
  elif isinstance(message, BetaRawMessageStartEvent):
@@ -544,12 +544,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
544
544
  elif isinstance(message, BetaRawMessageDeltaEvent):
545
545
  response_usage = message.usage
546
546
  else:
547
- # No usage information provided in:
548
- # - RawMessageStopEvent
549
- # - RawContentBlockStartEvent
550
- # - RawContentBlockDeltaEvent
551
- # - RawContentBlockStopEvent
552
- return usage.RequestUsage()
547
+ assert_never(message)
553
548
 
554
549
  # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
555
550
  # `response_tokens`
@@ -586,10 +581,8 @@ class AnthropicStreamedResponse(StreamedResponse):
586
581
  current_block: BetaContentBlock | None = None
587
582
 
588
583
  async for event in self._response:
589
- self._usage += _map_usage(event)
590
-
591
584
  if isinstance(event, BetaRawMessageStartEvent):
592
- pass
585
+ self._usage = _map_usage(event)
593
586
 
594
587
  elif isinstance(event, BetaRawContentBlockStartEvent):
595
588
  current_block = event.content_block
@@ -652,9 +645,9 @@ class AnthropicStreamedResponse(StreamedResponse):
652
645
  pass
653
646
 
654
647
  elif isinstance(event, BetaRawMessageDeltaEvent):
655
- pass
648
+ self._usage = _map_usage(event)
656
649
 
657
- elif isinstance(event, (BetaRawContentBlockStopEvent, BetaRawMessageStopEvent)): # pragma: no branch
650
+ elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658
651
  current_block = None
659
652
 
660
653
  @property
@@ -2,13 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import typing
5
- import warnings
6
5
  from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
7
6
  from contextlib import asynccontextmanager
8
7
  from dataclasses import dataclass, field
9
8
  from datetime import datetime
10
9
  from itertools import count
11
- from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
10
+ from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
12
11
 
13
12
  import anyio
14
13
  import anyio.to_thread
@@ -125,7 +124,7 @@ LatestBedrockModelNames = Literal[
125
124
  ]
126
125
  """Latest Bedrock models."""
127
126
 
128
- BedrockModelName = Union[str, LatestBedrockModelNames]
127
+ BedrockModelName = str | LatestBedrockModelNames
129
128
  """Possible Bedrock model names.
130
129
 
131
130
  Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
@@ -301,9 +300,13 @@ class BedrockConverseModel(Model):
301
300
  input_tokens=response['usage']['inputTokens'],
302
301
  output_tokens=response['usage']['outputTokens'],
303
302
  )
304
- vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
303
+ response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
305
304
  return ModelResponse(
306
- items, usage=u, model_name=self.model_name, provider_request_id=vendor_id, provider_name=self._provider.name
305
+ parts=items,
306
+ usage=u,
307
+ model_name=self.model_name,
308
+ provider_response_id=response_id,
309
+ provider_name=self._provider.name,
307
310
  )
308
311
 
309
312
  @overload
@@ -486,7 +489,7 @@ class BedrockConverseModel(Model):
486
489
  else:
487
490
  # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
488
491
  pass
489
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
492
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
490
493
  pass
491
494
  else:
492
495
  assert isinstance(item, ToolCallPart)
@@ -542,7 +545,7 @@ class BedrockConverseModel(Model):
542
545
  content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
543
546
  else:
544
547
  raise NotImplementedError('Binary content is not supported yet.')
545
- elif isinstance(item, (ImageUrl, DocumentUrl, VideoUrl)):
548
+ elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
546
549
  downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
547
550
  format = downloaded_item['data_type']
548
551
  if item.kind == 'image-url':
@@ -597,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
597
600
  _provider_name: str
598
601
  _timestamp: datetime = field(default_factory=_utils.now_utc)
599
602
 
600
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
603
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
601
604
  """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
602
605
 
603
606
  This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -606,60 +609,55 @@ class BedrockStreamedResponse(StreamedResponse):
606
609
  chunk: ConverseStreamOutputTypeDef
607
610
  tool_id: str | None = None
608
611
  async for chunk in _AsyncIteratorWrapper(self._event_stream):
609
- # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
610
- if 'messageStart' in chunk:
611
- continue
612
- if 'messageStop' in chunk:
613
- continue
614
- if 'metadata' in chunk:
615
- if 'usage' in chunk['metadata']: # pragma: no branch
616
- self._usage += self._map_usage(chunk['metadata'])
617
- continue
618
- if 'contentBlockStart' in chunk:
619
- index = chunk['contentBlockStart']['contentBlockIndex']
620
- start = chunk['contentBlockStart']['start']
621
- if 'toolUse' in start: # pragma: no branch
622
- tool_use_start = start['toolUse']
623
- tool_id = tool_use_start['toolUseId']
624
- tool_name = tool_use_start['name']
625
- maybe_event = self._parts_manager.handle_tool_call_delta(
626
- vendor_part_id=index,
627
- tool_name=tool_name,
628
- args=None,
629
- tool_call_id=tool_id,
630
- )
631
- if maybe_event: # pragma: no branch
632
- yield maybe_event
633
- if 'contentBlockDelta' in chunk:
634
- index = chunk['contentBlockDelta']['contentBlockIndex']
635
- delta = chunk['contentBlockDelta']['delta']
636
- if 'reasoningContent' in delta:
637
- if text := delta['reasoningContent'].get('text'):
612
+ match chunk:
613
+ case {'messageStart': _}:
614
+ continue
615
+ case {'messageStop': _}:
616
+ continue
617
+ case {'metadata': metadata}:
618
+ if 'usage' in metadata: # pragma: no branch
619
+ self._usage += self._map_usage(metadata)
620
+ continue
621
+ case {'contentBlockStart': content_block_start}:
622
+ index = content_block_start['contentBlockIndex']
623
+ start = content_block_start['start']
624
+ if 'toolUse' in start: # pragma: no branch
625
+ tool_use_start = start['toolUse']
626
+ tool_id = tool_use_start['toolUseId']
627
+ tool_name = tool_use_start['name']
628
+ maybe_event = self._parts_manager.handle_tool_call_delta(
629
+ vendor_part_id=index,
630
+ tool_name=tool_name,
631
+ args=None,
632
+ tool_call_id=tool_id,
633
+ )
634
+ if maybe_event: # pragma: no branch
635
+ yield maybe_event
636
+ case {'contentBlockDelta': content_block_delta}:
637
+ index = content_block_delta['contentBlockIndex']
638
+ delta = content_block_delta['delta']
639
+ if 'reasoningContent' in delta:
638
640
  yield self._parts_manager.handle_thinking_delta(
639
641
  vendor_part_id=index,
640
- content=text,
642
+ content=delta['reasoningContent'].get('text'),
641
643
  signature=delta['reasoningContent'].get('signature'),
642
644
  )
643
- else: # pragma: no cover
644
- warnings.warn(
645
- f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
646
- 'Please report this to the maintainers.',
647
- UserWarning,
645
+ if 'text' in delta:
646
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
647
+ if maybe_event is not None: # pragma: no branch
648
+ yield maybe_event
649
+ if 'toolUse' in delta:
650
+ tool_use = delta['toolUse']
651
+ maybe_event = self._parts_manager.handle_tool_call_delta(
652
+ vendor_part_id=index,
653
+ tool_name=tool_use.get('name'),
654
+ args=tool_use.get('input'),
655
+ tool_call_id=tool_id,
648
656
  )
649
- if 'text' in delta:
650
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
651
- if maybe_event is not None: # pragma: no branch
652
- yield maybe_event
653
- if 'toolUse' in delta:
654
- tool_use = delta['toolUse']
655
- maybe_event = self._parts_manager.handle_tool_call_delta(
656
- vendor_part_id=index,
657
- tool_name=tool_use.get('name'),
658
- args=tool_use.get('input'),
659
- tool_call_id=tool_id,
660
- )
661
- if maybe_event: # pragma: no branch
662
- yield maybe_event
657
+ if maybe_event: # pragma: no branch
658
+ yield maybe_event
659
+ case _:
660
+ pass # pyright wants match statements to be exhaustive
663
661
 
664
662
  @property
665
663
  def model_name(self) -> str:
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass, field
5
- from typing import Literal, Union, cast
5
+ from typing import Literal, cast
6
6
 
7
7
  from typing_extensions import assert_never
8
8
 
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
72
72
  ]
73
73
  """Latest Cohere models."""
74
74
 
75
- CohereModelName = Union[str, LatestCohereModelNames]
75
+ CohereModelName = str | LatestCohereModelNames
76
76
  """Possible Cohere model names.
77
77
 
78
78
  Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
@@ -228,7 +228,7 @@ class CohereModel(Model):
228
228
  pass
229
229
  elif isinstance(item, ToolCallPart):
230
230
  tool_calls.append(self._map_tool_call(item))
231
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
231
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
232
232
  # This is currently never returned from cohere
233
233
  pass
234
234
  else:
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator
3
+ from collections.abc import AsyncIterator, Callable
4
4
  from contextlib import AsyncExitStack, asynccontextmanager, suppress
5
5
  from dataclasses import dataclass, field
6
- from typing import TYPE_CHECKING, Any, Callable
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from opentelemetry.trace import get_current_span
9
9
 
@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import inspect
4
4
  import re
5
- from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
5
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
6
6
  from contextlib import asynccontextmanager
7
- from dataclasses import dataclass, field
7
+ from dataclasses import KW_ONLY, dataclass, field
8
8
  from datetime import datetime
9
9
  from itertools import chain
10
- from typing import Any, Callable, Union
10
+ from typing import Any, TypeAlias
11
11
 
12
- from typing_extensions import TypeAlias, assert_never, overload
12
+ from typing_extensions import assert_never, overload
13
13
 
14
14
  from .. import _utils, usage
15
15
  from .._run_context import RunContext
@@ -44,8 +44,8 @@ class FunctionModel(Model):
44
44
  Apart from `__init__`, all methods are private or match those of the base class.
45
45
  """
46
46
 
47
- function: FunctionDef | None = None
48
- stream_function: StreamFunctionDef | None = None
47
+ function: FunctionDef | None
48
+ stream_function: StreamFunctionDef | None
49
49
 
50
50
  _model_name: str = field(repr=False)
51
51
  _system: str = field(default='function', repr=False)
@@ -120,10 +120,10 @@ class FunctionModel(Model):
120
120
  model_request_parameters: ModelRequestParameters,
121
121
  ) -> ModelResponse:
122
122
  agent_info = AgentInfo(
123
- model_request_parameters.function_tools,
124
- model_request_parameters.allow_text_output,
125
- model_request_parameters.output_tools,
126
- model_settings,
123
+ function_tools=model_request_parameters.function_tools,
124
+ allow_text_output=model_request_parameters.allow_text_output,
125
+ output_tools=model_request_parameters.output_tools,
126
+ model_settings=model_settings,
127
127
  )
128
128
 
129
129
  assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -149,10 +149,10 @@ class FunctionModel(Model):
149
149
  run_context: RunContext[Any] | None = None,
150
150
  ) -> AsyncIterator[StreamedResponse]:
151
151
  agent_info = AgentInfo(
152
- model_request_parameters.function_tools,
153
- model_request_parameters.allow_text_output,
154
- model_request_parameters.output_tools,
155
- model_settings,
152
+ function_tools=model_request_parameters.function_tools,
153
+ allow_text_output=model_request_parameters.allow_text_output,
154
+ output_tools=model_request_parameters.output_tools,
155
+ model_settings=model_settings,
156
156
  )
157
157
 
158
158
  assert self.stream_function is not None, (
@@ -182,7 +182,7 @@ class FunctionModel(Model):
182
182
  return self._system
183
183
 
184
184
 
185
- @dataclass(frozen=True)
185
+ @dataclass(frozen=True, kw_only=True)
186
186
  class AgentInfo:
187
187
  """Information about an agent.
188
188
 
@@ -212,13 +212,17 @@ class DeltaToolCall:
212
212
 
213
213
  name: str | None = None
214
214
  """Incremental change to the name of the tool."""
215
+
215
216
  json_args: str | None = None
216
217
  """Incremental change to the arguments as JSON"""
218
+
219
+ _: KW_ONLY
220
+
217
221
  tool_call_id: str | None = None
218
222
  """Incremental change to the tool call ID."""
219
223
 
220
224
 
221
- @dataclass
225
+ @dataclass(kw_only=True)
222
226
  class DeltaThinkingPart:
223
227
  """Incremental change to a thinking part.
224
228
 
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
237
241
  DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
238
242
  """A mapping of thinking call IDs to incremental changes."""
239
243
 
240
- # TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
241
- FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
244
+ FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
242
245
  """A function used to generate a non-streamed response."""
243
246
 
244
- # TODO: Change signature as indicated above
245
247
  StreamFunctionDef: TypeAlias = Callable[
246
- [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
248
+ [list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
247
249
  ]
248
250
  """A function used to generate a streamed response.
249
251
 
250
- While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
251
- really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
252
+ While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
253
+ really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
252
254
 
253
255
  E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
254
256
  """
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
326
328
  for message in messages:
327
329
  if isinstance(message, ModelRequest):
328
330
  for part in message.parts:
329
- if isinstance(part, (SystemPromptPart, UserPromptPart)):
331
+ if isinstance(part, SystemPromptPart | UserPromptPart):
330
332
  request_tokens += _estimate_string_tokens(part.content)
331
333
  elif isinstance(part, ToolReturnPart):
332
334
  request_tokens += _estimate_string_tokens(part.model_response_str())
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
8
- from typing import Annotated, Any, Literal, Protocol, Union, cast
8
+ from typing import Annotated, Any, Literal, Protocol, cast
9
9
  from uuid import uuid4
10
10
 
11
11
  import httpx
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
51
51
  ]
52
52
  """Latest Gemini models."""
53
53
 
54
- GeminiModelName = Union[str, LatestGeminiModelNames]
54
+ GeminiModelName = str | LatestGeminiModelNames
55
55
  """Possible Gemini model names.
56
56
 
57
57
  Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -211,7 +211,9 @@ class GeminiModel(Model):
211
211
  generation_config = _settings_to_generation_config(model_settings)
212
212
  if model_request_parameters.output_mode == 'native':
213
213
  if tools:
214
- raise UserError('Gemini does not support structured output and tools at the same time.')
214
+ raise UserError(
215
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
216
+ )
215
217
 
216
218
  generation_config['response_mime_type'] = 'application/json'
217
219
 
@@ -615,7 +617,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
615
617
  elif isinstance(item, TextPart):
616
618
  if item.content:
617
619
  parts.append(_GeminiTextPart(text=item.content))
618
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
620
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
619
621
  # This is currently never returned from gemini
620
622
  pass
621
623
  else:
@@ -690,7 +692,7 @@ def _process_response_from_parts(
690
692
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
691
693
  )
692
694
  return ModelResponse(
693
- parts=items, usage=usage, model_name=model_name, provider_request_id=vendor_id, provider_details=vendor_details
695
+ parts=items, usage=usage, model_name=model_name, provider_response_id=vendor_id, provider_details=vendor_details
694
696
  )
695
697
 
696
698
 
@@ -735,16 +737,13 @@ def _part_discriminator(v: Any) -> str:
735
737
 
736
738
  # See <https://ai.google.dev/api/caching#Part>
737
739
  # we don't currently support other part types
738
- # TODO discriminator
739
740
  _GeminiPartUnion = Annotated[
740
- Union[
741
- Annotated[_GeminiTextPart, pydantic.Tag('text')],
742
- Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
743
- Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
744
- Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
745
- Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
746
- Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
747
- ],
741
+ Annotated[_GeminiTextPart, pydantic.Tag('text')]
742
+ | Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
743
+ | Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
744
+ | Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
745
+ | Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
746
+ | Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
748
747
  pydantic.Discriminator(_part_discriminator),
749
748
  ]
750
749
 
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
8
- from typing import Any, Literal, Union, cast, overload
8
+ from typing import Any, Literal, cast, overload
9
9
  from uuid import uuid4
10
10
 
11
11
  from typing_extensions import assert_never
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
91
91
  ]
92
92
  """Latest Gemini models."""
93
93
 
94
- GoogleModelName = Union[str, LatestGoogleModelNames]
94
+ GoogleModelName = str | LatestGoogleModelNames
95
95
  """Possible Gemini model names.
96
96
 
97
97
  Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -264,6 +264,14 @@ class GoogleModel(Model):
264
264
  yield await self._process_streamed_response(response, model_request_parameters) # type: ignore
265
265
 
266
266
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
267
+ if model_request_parameters.builtin_tools:
268
+ if model_request_parameters.output_tools:
269
+ raise UserError(
270
+ 'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
271
+ )
272
+ if model_request_parameters.function_tools:
273
+ raise UserError('Gemini does not support user tools and built-in tools at the same time.')
274
+
267
275
  tools: list[ToolDict] = [
268
276
  ToolDict(function_declarations=[_function_declaration_from_tool(t)])
269
277
  for t in model_request_parameters.tool_defs.values()
@@ -334,7 +342,9 @@ class GoogleModel(Model):
334
342
  response_schema = None
335
343
  if model_request_parameters.output_mode == 'native':
336
344
  if tools:
337
- raise UserError('Gemini does not support structured output and tools at the same time.')
345
+ raise UserError(
346
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
347
+ )
338
348
  response_mime_type = 'application/json'
339
349
  output_object = model_request_parameters.output_object
340
350
  assert output_object is not None
@@ -349,7 +359,7 @@ class GoogleModel(Model):
349
359
  'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
350
360
  }
351
361
  if timeout := model_settings.get('timeout'):
352
- if isinstance(timeout, (int, float)):
362
+ if isinstance(timeout, int | float):
353
363
  http_options['timeout'] = int(1000 * timeout)
354
364
  else:
355
365
  raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
@@ -559,6 +569,10 @@ class GeminiStreamedResponse(StreamedResponse):
559
569
  )
560
570
  if maybe_event is not None: # pragma: no branch
561
571
  yield maybe_event
572
+ elif part.executable_code is not None:
573
+ pass
574
+ elif part.code_execution_result is not None:
575
+ pass
562
576
  else:
563
577
  assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
564
578
 
@@ -648,7 +662,7 @@ def _process_response_from_parts(
648
662
  parts=items,
649
663
  model_name=model_name,
650
664
  usage=usage,
651
- provider_request_id=vendor_id,
665
+ provider_response_id=vendor_id,
652
666
  provider_details=vendor_details,
653
667
  provider_name=provider_name,
654
668
  )