pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (75) hide show
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +70 -9
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +149 -42
  31. pydantic_ai/models/__init__.py +6 -4
  32. pydantic_ai/models/anthropic.py +9 -16
  33. pydantic_ai/models/bedrock.py +50 -56
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +12 -13
  38. pydantic_ai/models/google.py +18 -4
  39. pydantic_ai/models/groq.py +126 -38
  40. pydantic_ai/models/huggingface.py +4 -4
  41. pydantic_ai/models/instrumented.py +35 -16
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +6 -6
  44. pydantic_ai/models/openai.py +35 -40
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +144 -41
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
6
6
  from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from datetime import datetime, timezone
9
- from typing import Any, Literal, Union, cast, overload
9
+ from typing import Any, Literal, cast, overload
10
10
 
11
11
  from typing_extensions import assert_never
12
12
 
@@ -99,7 +99,7 @@ except ImportError as _import_error:
99
99
  LatestAnthropicModelNames = ModelParam
100
100
  """Latest Anthropic models."""
101
101
 
102
- AnthropicModelName = Union[str, LatestAnthropicModelNames]
102
+ AnthropicModelName = str | LatestAnthropicModelNames
103
103
  """Possible Anthropic model names.
104
104
 
105
105
  Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
@@ -290,7 +290,7 @@ class AnthropicModel(Model):
290
290
  for item in response.content:
291
291
  if isinstance(item, BetaTextBlock):
292
292
  items.append(TextPart(content=item.text))
293
- elif isinstance(item, (BetaWebSearchToolResultBlock, BetaCodeExecutionToolResultBlock)):
293
+ elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
294
294
  items.append(
295
295
  BuiltinToolReturnPart(
296
296
  provider_name='anthropic',
@@ -327,7 +327,7 @@ class AnthropicModel(Model):
327
327
  )
328
328
 
329
329
  return ModelResponse(
330
- items,
330
+ parts=items,
331
331
  usage=_map_usage(response),
332
332
  model_name=response.model,
333
333
  provider_response_id=response.id,
@@ -536,7 +536,7 @@ class AnthropicModel(Model):
536
536
  }
537
537
 
538
538
 
539
- def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
539
+ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
540
540
  if isinstance(message, BetaMessage):
541
541
  response_usage = message.usage
542
542
  elif isinstance(message, BetaRawMessageStartEvent):
@@ -544,12 +544,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
544
544
  elif isinstance(message, BetaRawMessageDeltaEvent):
545
545
  response_usage = message.usage
546
546
  else:
547
- # No usage information provided in:
548
- # - RawMessageStopEvent
549
- # - RawContentBlockStartEvent
550
- # - RawContentBlockDeltaEvent
551
- # - RawContentBlockStopEvent
552
- return usage.RequestUsage()
547
+ assert_never(message)
553
548
 
554
549
  # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
555
550
  # `response_tokens`
@@ -586,10 +581,8 @@ class AnthropicStreamedResponse(StreamedResponse):
586
581
  current_block: BetaContentBlock | None = None
587
582
 
588
583
  async for event in self._response:
589
- self._usage += _map_usage(event)
590
-
591
584
  if isinstance(event, BetaRawMessageStartEvent):
592
- pass
585
+ self._usage = _map_usage(event)
593
586
 
594
587
  elif isinstance(event, BetaRawContentBlockStartEvent):
595
588
  current_block = event.content_block
@@ -652,9 +645,9 @@ class AnthropicStreamedResponse(StreamedResponse):
652
645
  pass
653
646
 
654
647
  elif isinstance(event, BetaRawMessageDeltaEvent):
655
- pass
648
+ self._usage = _map_usage(event)
656
649
 
657
- elif isinstance(event, (BetaRawContentBlockStopEvent, BetaRawMessageStopEvent)): # pragma: no branch
650
+ elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658
651
  current_block = None
659
652
 
660
653
  @property
@@ -2,13 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  import typing
5
- import warnings
6
5
  from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
7
6
  from contextlib import asynccontextmanager
8
7
  from dataclasses import dataclass, field
9
8
  from datetime import datetime
10
9
  from itertools import count
11
- from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
10
+ from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
12
11
 
13
12
  import anyio
14
13
  import anyio.to_thread
@@ -125,7 +124,7 @@ LatestBedrockModelNames = Literal[
125
124
  ]
126
125
  """Latest Bedrock models."""
127
126
 
128
- BedrockModelName = Union[str, LatestBedrockModelNames]
127
+ BedrockModelName = str | LatestBedrockModelNames
129
128
  """Possible Bedrock model names.
130
129
 
131
130
  Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
@@ -303,7 +302,7 @@ class BedrockConverseModel(Model):
303
302
  )
304
303
  response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
305
304
  return ModelResponse(
306
- items,
305
+ parts=items,
307
306
  usage=u,
308
307
  model_name=self.model_name,
309
308
  provider_response_id=response_id,
@@ -490,7 +489,7 @@ class BedrockConverseModel(Model):
490
489
  else:
491
490
  # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
492
491
  pass
493
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
492
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
494
493
  pass
495
494
  else:
496
495
  assert isinstance(item, ToolCallPart)
@@ -546,7 +545,7 @@ class BedrockConverseModel(Model):
546
545
  content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
547
546
  else:
548
547
  raise NotImplementedError('Binary content is not supported yet.')
549
- elif isinstance(item, (ImageUrl, DocumentUrl, VideoUrl)):
548
+ elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
550
549
  downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
551
550
  format = downloaded_item['data_type']
552
551
  if item.kind == 'image-url':
@@ -601,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
601
600
  _provider_name: str
602
601
  _timestamp: datetime = field(default_factory=_utils.now_utc)
603
602
 
604
- async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
603
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
605
604
  """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
606
605
 
607
606
  This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -610,60 +609,55 @@ class BedrockStreamedResponse(StreamedResponse):
610
609
  chunk: ConverseStreamOutputTypeDef
611
610
  tool_id: str | None = None
612
611
  async for chunk in _AsyncIteratorWrapper(self._event_stream):
613
- # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
614
- if 'messageStart' in chunk:
615
- continue
616
- if 'messageStop' in chunk:
617
- continue
618
- if 'metadata' in chunk:
619
- if 'usage' in chunk['metadata']: # pragma: no branch
620
- self._usage += self._map_usage(chunk['metadata'])
621
- continue
622
- if 'contentBlockStart' in chunk:
623
- index = chunk['contentBlockStart']['contentBlockIndex']
624
- start = chunk['contentBlockStart']['start']
625
- if 'toolUse' in start: # pragma: no branch
626
- tool_use_start = start['toolUse']
627
- tool_id = tool_use_start['toolUseId']
628
- tool_name = tool_use_start['name']
629
- maybe_event = self._parts_manager.handle_tool_call_delta(
630
- vendor_part_id=index,
631
- tool_name=tool_name,
632
- args=None,
633
- tool_call_id=tool_id,
634
- )
635
- if maybe_event: # pragma: no branch
636
- yield maybe_event
637
- if 'contentBlockDelta' in chunk:
638
- index = chunk['contentBlockDelta']['contentBlockIndex']
639
- delta = chunk['contentBlockDelta']['delta']
640
- if 'reasoningContent' in delta:
641
- if text := delta['reasoningContent'].get('text'):
612
+ match chunk:
613
+ case {'messageStart': _}:
614
+ continue
615
+ case {'messageStop': _}:
616
+ continue
617
+ case {'metadata': metadata}:
618
+ if 'usage' in metadata: # pragma: no branch
619
+ self._usage += self._map_usage(metadata)
620
+ continue
621
+ case {'contentBlockStart': content_block_start}:
622
+ index = content_block_start['contentBlockIndex']
623
+ start = content_block_start['start']
624
+ if 'toolUse' in start: # pragma: no branch
625
+ tool_use_start = start['toolUse']
626
+ tool_id = tool_use_start['toolUseId']
627
+ tool_name = tool_use_start['name']
628
+ maybe_event = self._parts_manager.handle_tool_call_delta(
629
+ vendor_part_id=index,
630
+ tool_name=tool_name,
631
+ args=None,
632
+ tool_call_id=tool_id,
633
+ )
634
+ if maybe_event: # pragma: no branch
635
+ yield maybe_event
636
+ case {'contentBlockDelta': content_block_delta}:
637
+ index = content_block_delta['contentBlockIndex']
638
+ delta = content_block_delta['delta']
639
+ if 'reasoningContent' in delta:
642
640
  yield self._parts_manager.handle_thinking_delta(
643
641
  vendor_part_id=index,
644
- content=text,
642
+ content=delta['reasoningContent'].get('text'),
645
643
  signature=delta['reasoningContent'].get('signature'),
646
644
  )
647
- else: # pragma: no cover
648
- warnings.warn(
649
- f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
650
- 'Please report this to the maintainers.',
651
- UserWarning,
645
+ if 'text' in delta:
646
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
647
+ if maybe_event is not None: # pragma: no branch
648
+ yield maybe_event
649
+ if 'toolUse' in delta:
650
+ tool_use = delta['toolUse']
651
+ maybe_event = self._parts_manager.handle_tool_call_delta(
652
+ vendor_part_id=index,
653
+ tool_name=tool_use.get('name'),
654
+ args=tool_use.get('input'),
655
+ tool_call_id=tool_id,
652
656
  )
653
- if 'text' in delta:
654
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
655
- if maybe_event is not None: # pragma: no branch
656
- yield maybe_event
657
- if 'toolUse' in delta:
658
- tool_use = delta['toolUse']
659
- maybe_event = self._parts_manager.handle_tool_call_delta(
660
- vendor_part_id=index,
661
- tool_name=tool_use.get('name'),
662
- args=tool_use.get('input'),
663
- tool_call_id=tool_id,
664
- )
665
- if maybe_event: # pragma: no branch
666
- yield maybe_event
657
+ if maybe_event: # pragma: no branch
658
+ yield maybe_event
659
+ case _:
660
+ pass # pyright wants match statements to be exhaustive
667
661
 
668
662
  @property
669
663
  def model_name(self) -> str:
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass, field
5
- from typing import Literal, Union, cast
5
+ from typing import Literal, cast
6
6
 
7
7
  from typing_extensions import assert_never
8
8
 
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
72
72
  ]
73
73
  """Latest Cohere models."""
74
74
 
75
- CohereModelName = Union[str, LatestCohereModelNames]
75
+ CohereModelName = str | LatestCohereModelNames
76
76
  """Possible Cohere model names.
77
77
 
78
78
  Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
@@ -228,7 +228,7 @@ class CohereModel(Model):
228
228
  pass
229
229
  elif isinstance(item, ToolCallPart):
230
230
  tool_calls.append(self._map_tool_call(item))
231
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
231
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
232
232
  # This is currently never returned from cohere
233
233
  pass
234
234
  else:
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator
3
+ from collections.abc import AsyncIterator, Callable
4
4
  from contextlib import AsyncExitStack, asynccontextmanager, suppress
5
5
  from dataclasses import dataclass, field
6
- from typing import TYPE_CHECKING, Any, Callable
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from opentelemetry.trace import get_current_span
9
9
 
@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import inspect
4
4
  import re
5
- from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
5
+ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
6
6
  from contextlib import asynccontextmanager
7
- from dataclasses import dataclass, field
7
+ from dataclasses import KW_ONLY, dataclass, field
8
8
  from datetime import datetime
9
9
  from itertools import chain
10
- from typing import Any, Callable, Union
10
+ from typing import Any, TypeAlias
11
11
 
12
- from typing_extensions import TypeAlias, assert_never, overload
12
+ from typing_extensions import assert_never, overload
13
13
 
14
14
  from .. import _utils, usage
15
15
  from .._run_context import RunContext
@@ -44,8 +44,8 @@ class FunctionModel(Model):
44
44
  Apart from `__init__`, all methods are private or match those of the base class.
45
45
  """
46
46
 
47
- function: FunctionDef | None = None
48
- stream_function: StreamFunctionDef | None = None
47
+ function: FunctionDef | None
48
+ stream_function: StreamFunctionDef | None
49
49
 
50
50
  _model_name: str = field(repr=False)
51
51
  _system: str = field(default='function', repr=False)
@@ -120,10 +120,10 @@ class FunctionModel(Model):
120
120
  model_request_parameters: ModelRequestParameters,
121
121
  ) -> ModelResponse:
122
122
  agent_info = AgentInfo(
123
- model_request_parameters.function_tools,
124
- model_request_parameters.allow_text_output,
125
- model_request_parameters.output_tools,
126
- model_settings,
123
+ function_tools=model_request_parameters.function_tools,
124
+ allow_text_output=model_request_parameters.allow_text_output,
125
+ output_tools=model_request_parameters.output_tools,
126
+ model_settings=model_settings,
127
127
  )
128
128
 
129
129
  assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -149,10 +149,10 @@ class FunctionModel(Model):
149
149
  run_context: RunContext[Any] | None = None,
150
150
  ) -> AsyncIterator[StreamedResponse]:
151
151
  agent_info = AgentInfo(
152
- model_request_parameters.function_tools,
153
- model_request_parameters.allow_text_output,
154
- model_request_parameters.output_tools,
155
- model_settings,
152
+ function_tools=model_request_parameters.function_tools,
153
+ allow_text_output=model_request_parameters.allow_text_output,
154
+ output_tools=model_request_parameters.output_tools,
155
+ model_settings=model_settings,
156
156
  )
157
157
 
158
158
  assert self.stream_function is not None, (
@@ -182,7 +182,7 @@ class FunctionModel(Model):
182
182
  return self._system
183
183
 
184
184
 
185
- @dataclass(frozen=True)
185
+ @dataclass(frozen=True, kw_only=True)
186
186
  class AgentInfo:
187
187
  """Information about an agent.
188
188
 
@@ -212,13 +212,17 @@ class DeltaToolCall:
212
212
 
213
213
  name: str | None = None
214
214
  """Incremental change to the name of the tool."""
215
+
215
216
  json_args: str | None = None
216
217
  """Incremental change to the arguments as JSON"""
218
+
219
+ _: KW_ONLY
220
+
217
221
  tool_call_id: str | None = None
218
222
  """Incremental change to the tool call ID."""
219
223
 
220
224
 
221
- @dataclass
225
+ @dataclass(kw_only=True)
222
226
  class DeltaThinkingPart:
223
227
  """Incremental change to a thinking part.
224
228
 
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
237
241
  DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
238
242
  """A mapping of thinking call IDs to incremental changes."""
239
243
 
240
- # TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
241
- FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
244
+ FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
242
245
  """A function used to generate a non-streamed response."""
243
246
 
244
- # TODO: Change signature as indicated above
245
247
  StreamFunctionDef: TypeAlias = Callable[
246
- [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
248
+ [list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
247
249
  ]
248
250
  """A function used to generate a streamed response.
249
251
 
250
- While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
251
- really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
252
+ While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
253
+ really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
252
254
 
253
255
  E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
254
256
  """
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
326
328
  for message in messages:
327
329
  if isinstance(message, ModelRequest):
328
330
  for part in message.parts:
329
- if isinstance(part, (SystemPromptPart, UserPromptPart)):
331
+ if isinstance(part, SystemPromptPart | UserPromptPart):
330
332
  request_tokens += _estimate_string_tokens(part.content)
331
333
  elif isinstance(part, ToolReturnPart):
332
334
  request_tokens += _estimate_string_tokens(part.model_response_str())
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
8
- from typing import Annotated, Any, Literal, Protocol, Union, cast
8
+ from typing import Annotated, Any, Literal, Protocol, cast
9
9
  from uuid import uuid4
10
10
 
11
11
  import httpx
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
51
51
  ]
52
52
  """Latest Gemini models."""
53
53
 
54
- GeminiModelName = Union[str, LatestGeminiModelNames]
54
+ GeminiModelName = str | LatestGeminiModelNames
55
55
  """Possible Gemini model names.
56
56
 
57
57
  Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -211,7 +211,9 @@ class GeminiModel(Model):
211
211
  generation_config = _settings_to_generation_config(model_settings)
212
212
  if model_request_parameters.output_mode == 'native':
213
213
  if tools:
214
- raise UserError('Gemini does not support structured output and tools at the same time.')
214
+ raise UserError(
215
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
216
+ )
215
217
 
216
218
  generation_config['response_mime_type'] = 'application/json'
217
219
 
@@ -615,7 +617,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
615
617
  elif isinstance(item, TextPart):
616
618
  if item.content:
617
619
  parts.append(_GeminiTextPart(text=item.content))
618
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
620
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
619
621
  # This is currently never returned from gemini
620
622
  pass
621
623
  else:
@@ -735,16 +737,13 @@ def _part_discriminator(v: Any) -> str:
735
737
 
736
738
  # See <https://ai.google.dev/api/caching#Part>
737
739
  # we don't currently support other part types
738
- # TODO discriminator
739
740
  _GeminiPartUnion = Annotated[
740
- Union[
741
- Annotated[_GeminiTextPart, pydantic.Tag('text')],
742
- Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
743
- Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
744
- Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
745
- Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
746
- Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
747
- ],
741
+ Annotated[_GeminiTextPart, pydantic.Tag('text')]
742
+ | Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
743
+ | Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
744
+ | Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
745
+ | Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
746
+ | Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
748
747
  pydantic.Discriminator(_part_discriminator),
749
748
  ]
750
749
 
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
5
5
  from contextlib import asynccontextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
8
- from typing import Any, Literal, Union, cast, overload
8
+ from typing import Any, Literal, cast, overload
9
9
  from uuid import uuid4
10
10
 
11
11
  from typing_extensions import assert_never
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
91
91
  ]
92
92
  """Latest Gemini models."""
93
93
 
94
- GoogleModelName = Union[str, LatestGoogleModelNames]
94
+ GoogleModelName = str | LatestGoogleModelNames
95
95
  """Possible Gemini model names.
96
96
 
97
97
  Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -264,6 +264,14 @@ class GoogleModel(Model):
264
264
  yield await self._process_streamed_response(response, model_request_parameters) # type: ignore
265
265
 
266
266
  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
267
+ if model_request_parameters.builtin_tools:
268
+ if model_request_parameters.output_tools:
269
+ raise UserError(
270
+ 'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
271
+ )
272
+ if model_request_parameters.function_tools:
273
+ raise UserError('Gemini does not support user tools and built-in tools at the same time.')
274
+
267
275
  tools: list[ToolDict] = [
268
276
  ToolDict(function_declarations=[_function_declaration_from_tool(t)])
269
277
  for t in model_request_parameters.tool_defs.values()
@@ -334,7 +342,9 @@ class GoogleModel(Model):
334
342
  response_schema = None
335
343
  if model_request_parameters.output_mode == 'native':
336
344
  if tools:
337
- raise UserError('Gemini does not support structured output and tools at the same time.')
345
+ raise UserError(
346
+ 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
347
+ )
338
348
  response_mime_type = 'application/json'
339
349
  output_object = model_request_parameters.output_object
340
350
  assert output_object is not None
@@ -349,7 +359,7 @@ class GoogleModel(Model):
349
359
  'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
350
360
  }
351
361
  if timeout := model_settings.get('timeout'):
352
- if isinstance(timeout, (int, float)):
362
+ if isinstance(timeout, int | float):
353
363
  http_options['timeout'] = int(1000 * timeout)
354
364
  else:
355
365
  raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
@@ -559,6 +569,10 @@ class GeminiStreamedResponse(StreamedResponse):
559
569
  )
560
570
  if maybe_event is not None: # pragma: no branch
561
571
  yield maybe_event
572
+ elif part.executable_code is not None:
573
+ pass
574
+ elif part.code_execution_result is not None:
575
+ pass
562
576
  else:
563
577
  assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
564
578