pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +70 -9
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +4 -2
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +149 -42
- pydantic_ai/models/__init__.py +6 -4
- pydantic_ai/models/anthropic.py +9 -16
- pydantic_ai/models/bedrock.py +50 -56
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +12 -13
- pydantic_ai/models/google.py +18 -4
- pydantic_ai/models/groq.py +126 -38
- pydantic_ai/models/huggingface.py +4 -4
- pydantic_ai/models/instrumented.py +35 -16
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +6 -6
- pydantic_ai/models/openai.py +35 -40
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +144 -41
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
|
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
|
-
from typing import Any, Literal,
|
|
9
|
+
from typing import Any, Literal, cast, overload
|
|
10
10
|
|
|
11
11
|
from typing_extensions import assert_never
|
|
12
12
|
|
|
@@ -99,7 +99,7 @@ except ImportError as _import_error:
|
|
|
99
99
|
LatestAnthropicModelNames = ModelParam
|
|
100
100
|
"""Latest Anthropic models."""
|
|
101
101
|
|
|
102
|
-
AnthropicModelName =
|
|
102
|
+
AnthropicModelName = str | LatestAnthropicModelNames
|
|
103
103
|
"""Possible Anthropic model names.
|
|
104
104
|
|
|
105
105
|
Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -290,7 +290,7 @@ class AnthropicModel(Model):
|
|
|
290
290
|
for item in response.content:
|
|
291
291
|
if isinstance(item, BetaTextBlock):
|
|
292
292
|
items.append(TextPart(content=item.text))
|
|
293
|
-
elif isinstance(item,
|
|
293
|
+
elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
|
|
294
294
|
items.append(
|
|
295
295
|
BuiltinToolReturnPart(
|
|
296
296
|
provider_name='anthropic',
|
|
@@ -327,7 +327,7 @@ class AnthropicModel(Model):
|
|
|
327
327
|
)
|
|
328
328
|
|
|
329
329
|
return ModelResponse(
|
|
330
|
-
items,
|
|
330
|
+
parts=items,
|
|
331
331
|
usage=_map_usage(response),
|
|
332
332
|
model_name=response.model,
|
|
333
333
|
provider_response_id=response.id,
|
|
@@ -536,7 +536,7 @@ class AnthropicModel(Model):
|
|
|
536
536
|
}
|
|
537
537
|
|
|
538
538
|
|
|
539
|
-
def _map_usage(message: BetaMessage |
|
|
539
|
+
def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
|
|
540
540
|
if isinstance(message, BetaMessage):
|
|
541
541
|
response_usage = message.usage
|
|
542
542
|
elif isinstance(message, BetaRawMessageStartEvent):
|
|
@@ -544,12 +544,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
|
|
|
544
544
|
elif isinstance(message, BetaRawMessageDeltaEvent):
|
|
545
545
|
response_usage = message.usage
|
|
546
546
|
else:
|
|
547
|
-
|
|
548
|
-
# - RawMessageStopEvent
|
|
549
|
-
# - RawContentBlockStartEvent
|
|
550
|
-
# - RawContentBlockDeltaEvent
|
|
551
|
-
# - RawContentBlockStopEvent
|
|
552
|
-
return usage.RequestUsage()
|
|
547
|
+
assert_never(message)
|
|
553
548
|
|
|
554
549
|
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
|
|
555
550
|
# `response_tokens`
|
|
@@ -586,10 +581,8 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
586
581
|
current_block: BetaContentBlock | None = None
|
|
587
582
|
|
|
588
583
|
async for event in self._response:
|
|
589
|
-
self._usage += _map_usage(event)
|
|
590
|
-
|
|
591
584
|
if isinstance(event, BetaRawMessageStartEvent):
|
|
592
|
-
|
|
585
|
+
self._usage = _map_usage(event)
|
|
593
586
|
|
|
594
587
|
elif isinstance(event, BetaRawContentBlockStartEvent):
|
|
595
588
|
current_block = event.content_block
|
|
@@ -652,9 +645,9 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
652
645
|
pass
|
|
653
646
|
|
|
654
647
|
elif isinstance(event, BetaRawMessageDeltaEvent):
|
|
655
|
-
|
|
648
|
+
self._usage = _map_usage(event)
|
|
656
649
|
|
|
657
|
-
elif isinstance(event,
|
|
650
|
+
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
|
|
658
651
|
current_block = None
|
|
659
652
|
|
|
660
653
|
@property
|
pydantic_ai/models/bedrock.py
CHANGED
|
@@ -2,13 +2,12 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
4
|
import typing
|
|
5
|
-
import warnings
|
|
6
5
|
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
|
|
7
6
|
from contextlib import asynccontextmanager
|
|
8
7
|
from dataclasses import dataclass, field
|
|
9
8
|
from datetime import datetime
|
|
10
9
|
from itertools import count
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Generic, Literal,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
|
|
12
11
|
|
|
13
12
|
import anyio
|
|
14
13
|
import anyio.to_thread
|
|
@@ -125,7 +124,7 @@ LatestBedrockModelNames = Literal[
|
|
|
125
124
|
]
|
|
126
125
|
"""Latest Bedrock models."""
|
|
127
126
|
|
|
128
|
-
BedrockModelName =
|
|
127
|
+
BedrockModelName = str | LatestBedrockModelNames
|
|
129
128
|
"""Possible Bedrock model names.
|
|
130
129
|
|
|
131
130
|
Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
|
|
@@ -303,7 +302,7 @@ class BedrockConverseModel(Model):
|
|
|
303
302
|
)
|
|
304
303
|
response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
|
|
305
304
|
return ModelResponse(
|
|
306
|
-
items,
|
|
305
|
+
parts=items,
|
|
307
306
|
usage=u,
|
|
308
307
|
model_name=self.model_name,
|
|
309
308
|
provider_response_id=response_id,
|
|
@@ -490,7 +489,7 @@ class BedrockConverseModel(Model):
|
|
|
490
489
|
else:
|
|
491
490
|
# NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
|
|
492
491
|
pass
|
|
493
|
-
elif isinstance(item,
|
|
492
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
|
|
494
493
|
pass
|
|
495
494
|
else:
|
|
496
495
|
assert isinstance(item, ToolCallPart)
|
|
@@ -546,7 +545,7 @@ class BedrockConverseModel(Model):
|
|
|
546
545
|
content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
|
|
547
546
|
else:
|
|
548
547
|
raise NotImplementedError('Binary content is not supported yet.')
|
|
549
|
-
elif isinstance(item,
|
|
548
|
+
elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
|
|
550
549
|
downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
|
|
551
550
|
format = downloaded_item['data_type']
|
|
552
551
|
if item.kind == 'image-url':
|
|
@@ -601,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
601
600
|
_provider_name: str
|
|
602
601
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
603
602
|
|
|
604
|
-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
603
|
+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
|
|
605
604
|
"""Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
|
|
606
605
|
|
|
607
606
|
This method should be implemented by subclasses to translate the vendor-specific stream of events into
|
|
@@ -610,60 +609,55 @@ class BedrockStreamedResponse(StreamedResponse):
|
|
|
610
609
|
chunk: ConverseStreamOutputTypeDef
|
|
611
610
|
tool_id: str | None = None
|
|
612
611
|
async for chunk in _AsyncIteratorWrapper(self._event_stream):
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
if text := delta['reasoningContent'].get('text'):
|
|
612
|
+
match chunk:
|
|
613
|
+
case {'messageStart': _}:
|
|
614
|
+
continue
|
|
615
|
+
case {'messageStop': _}:
|
|
616
|
+
continue
|
|
617
|
+
case {'metadata': metadata}:
|
|
618
|
+
if 'usage' in metadata: # pragma: no branch
|
|
619
|
+
self._usage += self._map_usage(metadata)
|
|
620
|
+
continue
|
|
621
|
+
case {'contentBlockStart': content_block_start}:
|
|
622
|
+
index = content_block_start['contentBlockIndex']
|
|
623
|
+
start = content_block_start['start']
|
|
624
|
+
if 'toolUse' in start: # pragma: no branch
|
|
625
|
+
tool_use_start = start['toolUse']
|
|
626
|
+
tool_id = tool_use_start['toolUseId']
|
|
627
|
+
tool_name = tool_use_start['name']
|
|
628
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
629
|
+
vendor_part_id=index,
|
|
630
|
+
tool_name=tool_name,
|
|
631
|
+
args=None,
|
|
632
|
+
tool_call_id=tool_id,
|
|
633
|
+
)
|
|
634
|
+
if maybe_event: # pragma: no branch
|
|
635
|
+
yield maybe_event
|
|
636
|
+
case {'contentBlockDelta': content_block_delta}:
|
|
637
|
+
index = content_block_delta['contentBlockIndex']
|
|
638
|
+
delta = content_block_delta['delta']
|
|
639
|
+
if 'reasoningContent' in delta:
|
|
642
640
|
yield self._parts_manager.handle_thinking_delta(
|
|
643
641
|
vendor_part_id=index,
|
|
644
|
-
content=text,
|
|
642
|
+
content=delta['reasoningContent'].get('text'),
|
|
645
643
|
signature=delta['reasoningContent'].get('signature'),
|
|
646
644
|
)
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
645
|
+
if 'text' in delta:
|
|
646
|
+
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
|
|
647
|
+
if maybe_event is not None: # pragma: no branch
|
|
648
|
+
yield maybe_event
|
|
649
|
+
if 'toolUse' in delta:
|
|
650
|
+
tool_use = delta['toolUse']
|
|
651
|
+
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
652
|
+
vendor_part_id=index,
|
|
653
|
+
tool_name=tool_use.get('name'),
|
|
654
|
+
args=tool_use.get('input'),
|
|
655
|
+
tool_call_id=tool_id,
|
|
652
656
|
)
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
if 'toolUse' in delta:
|
|
658
|
-
tool_use = delta['toolUse']
|
|
659
|
-
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
660
|
-
vendor_part_id=index,
|
|
661
|
-
tool_name=tool_use.get('name'),
|
|
662
|
-
args=tool_use.get('input'),
|
|
663
|
-
tool_call_id=tool_id,
|
|
664
|
-
)
|
|
665
|
-
if maybe_event: # pragma: no branch
|
|
666
|
-
yield maybe_event
|
|
657
|
+
if maybe_event: # pragma: no branch
|
|
658
|
+
yield maybe_event
|
|
659
|
+
case _:
|
|
660
|
+
pass # pyright wants match statements to be exhaustive
|
|
667
661
|
|
|
668
662
|
@property
|
|
669
663
|
def model_name(self) -> str:
|
pydantic_ai/models/cohere.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import Literal,
|
|
5
|
+
from typing import Literal, cast
|
|
6
6
|
|
|
7
7
|
from typing_extensions import assert_never
|
|
8
8
|
|
|
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
|
|
|
72
72
|
]
|
|
73
73
|
"""Latest Cohere models."""
|
|
74
74
|
|
|
75
|
-
CohereModelName =
|
|
75
|
+
CohereModelName = str | LatestCohereModelNames
|
|
76
76
|
"""Possible Cohere model names.
|
|
77
77
|
|
|
78
78
|
Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -228,7 +228,7 @@ class CohereModel(Model):
|
|
|
228
228
|
pass
|
|
229
229
|
elif isinstance(item, ToolCallPart):
|
|
230
230
|
tool_calls.append(self._map_tool_call(item))
|
|
231
|
-
elif isinstance(item,
|
|
231
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
232
232
|
# This is currently never returned from cohere
|
|
233
233
|
pass
|
|
234
234
|
else:
|
pydantic_ai/models/fallback.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncIterator
|
|
3
|
+
from collections.abc import AsyncIterator, Callable
|
|
4
4
|
from contextlib import AsyncExitStack, asynccontextmanager, suppress
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from opentelemetry.trace import get_current_span
|
|
9
9
|
|
pydantic_ai/models/function.py
CHANGED
|
@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
4
|
import re
|
|
5
|
-
from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
|
|
5
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
|
-
from dataclasses import dataclass, field
|
|
7
|
+
from dataclasses import KW_ONLY, dataclass, field
|
|
8
8
|
from datetime import datetime
|
|
9
9
|
from itertools import chain
|
|
10
|
-
from typing import Any,
|
|
10
|
+
from typing import Any, TypeAlias
|
|
11
11
|
|
|
12
|
-
from typing_extensions import
|
|
12
|
+
from typing_extensions import assert_never, overload
|
|
13
13
|
|
|
14
14
|
from .. import _utils, usage
|
|
15
15
|
from .._run_context import RunContext
|
|
@@ -44,8 +44,8 @@ class FunctionModel(Model):
|
|
|
44
44
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
function: FunctionDef | None
|
|
48
|
-
stream_function: StreamFunctionDef | None
|
|
47
|
+
function: FunctionDef | None
|
|
48
|
+
stream_function: StreamFunctionDef | None
|
|
49
49
|
|
|
50
50
|
_model_name: str = field(repr=False)
|
|
51
51
|
_system: str = field(default='function', repr=False)
|
|
@@ -120,10 +120,10 @@ class FunctionModel(Model):
|
|
|
120
120
|
model_request_parameters: ModelRequestParameters,
|
|
121
121
|
) -> ModelResponse:
|
|
122
122
|
agent_info = AgentInfo(
|
|
123
|
-
model_request_parameters.function_tools,
|
|
124
|
-
model_request_parameters.allow_text_output,
|
|
125
|
-
model_request_parameters.output_tools,
|
|
126
|
-
model_settings,
|
|
123
|
+
function_tools=model_request_parameters.function_tools,
|
|
124
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
125
|
+
output_tools=model_request_parameters.output_tools,
|
|
126
|
+
model_settings=model_settings,
|
|
127
127
|
)
|
|
128
128
|
|
|
129
129
|
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
|
|
@@ -149,10 +149,10 @@ class FunctionModel(Model):
|
|
|
149
149
|
run_context: RunContext[Any] | None = None,
|
|
150
150
|
) -> AsyncIterator[StreamedResponse]:
|
|
151
151
|
agent_info = AgentInfo(
|
|
152
|
-
model_request_parameters.function_tools,
|
|
153
|
-
model_request_parameters.allow_text_output,
|
|
154
|
-
model_request_parameters.output_tools,
|
|
155
|
-
model_settings,
|
|
152
|
+
function_tools=model_request_parameters.function_tools,
|
|
153
|
+
allow_text_output=model_request_parameters.allow_text_output,
|
|
154
|
+
output_tools=model_request_parameters.output_tools,
|
|
155
|
+
model_settings=model_settings,
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
assert self.stream_function is not None, (
|
|
@@ -182,7 +182,7 @@ class FunctionModel(Model):
|
|
|
182
182
|
return self._system
|
|
183
183
|
|
|
184
184
|
|
|
185
|
-
@dataclass(frozen=True)
|
|
185
|
+
@dataclass(frozen=True, kw_only=True)
|
|
186
186
|
class AgentInfo:
|
|
187
187
|
"""Information about an agent.
|
|
188
188
|
|
|
@@ -212,13 +212,17 @@ class DeltaToolCall:
|
|
|
212
212
|
|
|
213
213
|
name: str | None = None
|
|
214
214
|
"""Incremental change to the name of the tool."""
|
|
215
|
+
|
|
215
216
|
json_args: str | None = None
|
|
216
217
|
"""Incremental change to the arguments as JSON"""
|
|
218
|
+
|
|
219
|
+
_: KW_ONLY
|
|
220
|
+
|
|
217
221
|
tool_call_id: str | None = None
|
|
218
222
|
"""Incremental change to the tool call ID."""
|
|
219
223
|
|
|
220
224
|
|
|
221
|
-
@dataclass
|
|
225
|
+
@dataclass(kw_only=True)
|
|
222
226
|
class DeltaThinkingPart:
|
|
223
227
|
"""Incremental change to a thinking part.
|
|
224
228
|
|
|
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
|
237
241
|
DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
|
|
238
242
|
"""A mapping of thinking call IDs to incremental changes."""
|
|
239
243
|
|
|
240
|
-
|
|
241
|
-
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
|
|
244
|
+
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
|
|
242
245
|
"""A function used to generate a non-streamed response."""
|
|
243
246
|
|
|
244
|
-
# TODO: Change signature as indicated above
|
|
245
247
|
StreamFunctionDef: TypeAlias = Callable[
|
|
246
|
-
[list[ModelMessage], AgentInfo], AsyncIterator[
|
|
248
|
+
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
247
249
|
]
|
|
248
250
|
"""A function used to generate a streamed response.
|
|
249
251
|
|
|
250
|
-
While this is defined as having return type of `AsyncIterator[
|
|
251
|
-
really be considered as `
|
|
252
|
+
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
|
|
253
|
+
really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
|
|
252
254
|
|
|
253
255
|
E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
|
|
254
256
|
"""
|
|
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
|
|
|
326
328
|
for message in messages:
|
|
327
329
|
if isinstance(message, ModelRequest):
|
|
328
330
|
for part in message.parts:
|
|
329
|
-
if isinstance(part,
|
|
331
|
+
if isinstance(part, SystemPromptPart | UserPromptPart):
|
|
330
332
|
request_tokens += _estimate_string_tokens(part.content)
|
|
331
333
|
elif isinstance(part, ToolReturnPart):
|
|
332
334
|
request_tokens += _estimate_string_tokens(part.model_response_str())
|
pydantic_ai/models/gemini.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
|
-
from typing import Annotated, Any, Literal, Protocol,
|
|
8
|
+
from typing import Annotated, Any, Literal, Protocol, cast
|
|
9
9
|
from uuid import uuid4
|
|
10
10
|
|
|
11
11
|
import httpx
|
|
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
|
|
|
51
51
|
]
|
|
52
52
|
"""Latest Gemini models."""
|
|
53
53
|
|
|
54
|
-
GeminiModelName =
|
|
54
|
+
GeminiModelName = str | LatestGeminiModelNames
|
|
55
55
|
"""Possible Gemini model names.
|
|
56
56
|
|
|
57
57
|
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -211,7 +211,9 @@ class GeminiModel(Model):
|
|
|
211
211
|
generation_config = _settings_to_generation_config(model_settings)
|
|
212
212
|
if model_request_parameters.output_mode == 'native':
|
|
213
213
|
if tools:
|
|
214
|
-
raise UserError(
|
|
214
|
+
raise UserError(
|
|
215
|
+
'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
|
|
216
|
+
)
|
|
215
217
|
|
|
216
218
|
generation_config['response_mime_type'] = 'application/json'
|
|
217
219
|
|
|
@@ -615,7 +617,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
|
|
|
615
617
|
elif isinstance(item, TextPart):
|
|
616
618
|
if item.content:
|
|
617
619
|
parts.append(_GeminiTextPart(text=item.content))
|
|
618
|
-
elif isinstance(item,
|
|
620
|
+
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
|
|
619
621
|
# This is currently never returned from gemini
|
|
620
622
|
pass
|
|
621
623
|
else:
|
|
@@ -735,16 +737,13 @@ def _part_discriminator(v: Any) -> str:
|
|
|
735
737
|
|
|
736
738
|
# See <https://ai.google.dev/api/caching#Part>
|
|
737
739
|
# we don't currently support other part types
|
|
738
|
-
# TODO discriminator
|
|
739
740
|
_GeminiPartUnion = Annotated[
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
|
|
747
|
-
],
|
|
741
|
+
Annotated[_GeminiTextPart, pydantic.Tag('text')]
|
|
742
|
+
| Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
|
|
743
|
+
| Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
|
|
744
|
+
| Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
|
|
745
|
+
| Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
|
|
746
|
+
| Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
|
|
748
747
|
pydantic.Discriminator(_part_discriminator),
|
|
749
748
|
]
|
|
750
749
|
|
pydantic_ai/models/google.py
CHANGED
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
|
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from datetime import datetime
|
|
8
|
-
from typing import Any, Literal,
|
|
8
|
+
from typing import Any, Literal, cast, overload
|
|
9
9
|
from uuid import uuid4
|
|
10
10
|
|
|
11
11
|
from typing_extensions import assert_never
|
|
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
|
|
|
91
91
|
]
|
|
92
92
|
"""Latest Gemini models."""
|
|
93
93
|
|
|
94
|
-
GoogleModelName =
|
|
94
|
+
GoogleModelName = str | LatestGoogleModelNames
|
|
95
95
|
"""Possible Gemini model names.
|
|
96
96
|
|
|
97
97
|
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
|
|
@@ -264,6 +264,14 @@ class GoogleModel(Model):
|
|
|
264
264
|
yield await self._process_streamed_response(response, model_request_parameters) # type: ignore
|
|
265
265
|
|
|
266
266
|
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
|
|
267
|
+
if model_request_parameters.builtin_tools:
|
|
268
|
+
if model_request_parameters.output_tools:
|
|
269
|
+
raise UserError(
|
|
270
|
+
'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
|
|
271
|
+
)
|
|
272
|
+
if model_request_parameters.function_tools:
|
|
273
|
+
raise UserError('Gemini does not support user tools and built-in tools at the same time.')
|
|
274
|
+
|
|
267
275
|
tools: list[ToolDict] = [
|
|
268
276
|
ToolDict(function_declarations=[_function_declaration_from_tool(t)])
|
|
269
277
|
for t in model_request_parameters.tool_defs.values()
|
|
@@ -334,7 +342,9 @@ class GoogleModel(Model):
|
|
|
334
342
|
response_schema = None
|
|
335
343
|
if model_request_parameters.output_mode == 'native':
|
|
336
344
|
if tools:
|
|
337
|
-
raise UserError(
|
|
345
|
+
raise UserError(
|
|
346
|
+
'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
|
|
347
|
+
)
|
|
338
348
|
response_mime_type = 'application/json'
|
|
339
349
|
output_object = model_request_parameters.output_object
|
|
340
350
|
assert output_object is not None
|
|
@@ -349,7 +359,7 @@ class GoogleModel(Model):
|
|
|
349
359
|
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
350
360
|
}
|
|
351
361
|
if timeout := model_settings.get('timeout'):
|
|
352
|
-
if isinstance(timeout,
|
|
362
|
+
if isinstance(timeout, int | float):
|
|
353
363
|
http_options['timeout'] = int(1000 * timeout)
|
|
354
364
|
else:
|
|
355
365
|
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
|
|
@@ -559,6 +569,10 @@ class GeminiStreamedResponse(StreamedResponse):
|
|
|
559
569
|
)
|
|
560
570
|
if maybe_event is not None: # pragma: no branch
|
|
561
571
|
yield maybe_event
|
|
572
|
+
elif part.executable_code is not None:
|
|
573
|
+
pass
|
|
574
|
+
elif part.code_execution_result is not None:
|
|
575
|
+
pass
|
|
562
576
|
else:
|
|
563
577
|
assert part.function_response is not None, f'Unexpected part: {part}' # pragma: no cover
|
|
564
578
|
|