pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +28 -2
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +323 -156
- pydantic_ai/_function_schema.py +5 -5
- pydantic_ai/_griffe.py +2 -1
- pydantic_ai/_otel_messages.py +2 -2
- pydantic_ai/_output.py +31 -35
- pydantic_ai/_parts_manager.py +7 -5
- pydantic_ai/_run_context.py +3 -1
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_tool_manager.py +32 -28
- pydantic_ai/_utils.py +14 -26
- pydantic_ai/ag_ui.py +82 -51
- pydantic_ai/agent/__init__.py +84 -17
- pydantic_ai/agent/abstract.py +35 -4
- pydantic_ai/agent/wrapper.py +6 -0
- pydantic_ai/builtin_tools.py +2 -2
- pydantic_ai/common_tools/duckduckgo.py +4 -2
- pydantic_ai/durable_exec/temporal/__init__.py +70 -17
- pydantic_ai/durable_exec/temporal/_agent.py +93 -11
- pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
- pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
- pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
- pydantic_ai/durable_exec/temporal/_model.py +2 -2
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
- pydantic_ai/exceptions.py +45 -2
- pydantic_ai/format_prompt.py +2 -2
- pydantic_ai/mcp.py +15 -27
- pydantic_ai/messages.py +156 -44
- pydantic_ai/models/__init__.py +20 -7
- pydantic_ai/models/anthropic.py +10 -17
- pydantic_ai/models/bedrock.py +55 -57
- pydantic_ai/models/cohere.py +3 -3
- pydantic_ai/models/fallback.py +2 -2
- pydantic_ai/models/function.py +25 -23
- pydantic_ai/models/gemini.py +13 -14
- pydantic_ai/models/google.py +19 -5
- pydantic_ai/models/groq.py +127 -39
- pydantic_ai/models/huggingface.py +5 -5
- pydantic_ai/models/instrumented.py +49 -21
- pydantic_ai/models/mcp_sampling.py +3 -1
- pydantic_ai/models/mistral.py +8 -8
- pydantic_ai/models/openai.py +37 -42
- pydantic_ai/models/test.py +24 -4
- pydantic_ai/output.py +27 -32
- pydantic_ai/profiles/__init__.py +3 -3
- pydantic_ai/profiles/groq.py +1 -1
- pydantic_ai/profiles/openai.py +25 -4
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +2 -3
- pydantic_ai/providers/bedrock.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/result.py +173 -52
- pydantic_ai/retries.py +52 -31
- pydantic_ai/run.py +12 -5
- pydantic_ai/tools.py +127 -23
- pydantic_ai/toolsets/__init__.py +4 -1
- pydantic_ai/toolsets/_dynamic.py +4 -4
- pydantic_ai/toolsets/abstract.py +18 -2
- pydantic_ai/toolsets/approval_required.py +32 -0
- pydantic_ai/toolsets/combined.py +7 -12
- pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
- pydantic_ai/toolsets/filtered.py +1 -1
- pydantic_ai/toolsets/function.py +58 -21
- pydantic_ai/toolsets/wrapper.py +2 -1
- pydantic_ai/usage.py +44 -8
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
- pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
- pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py CHANGED

@@ -7,16 +7,17 @@ specific LLM being used.
 from __future__ import annotations as _annotations
 
 import base64
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from functools import cache, cached_property
-from typing import Any, Generic, TypeVar, overload
+from typing import Any, Generic, Literal, TypeVar, overload
 
 import httpx
-from typing_extensions import
+from typing_extensions import TypeAliasType, TypedDict
 
 from .. import _utils
 from .._output import OutputObjectDefinition
@@ -366,7 +367,7 @@ KnownModelName = TypeAliasType(
 """
 
 
-@dataclass(repr=False)
+@dataclass(repr=False, kw_only=True)
 class ModelRequestParameters:
     """Configuration for an agent's request to a model, specifically related to tools and output handling."""
 
@@ -551,6 +552,7 @@ class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""
 
     model_request_parameters: ModelRequestParameters
+
     final_result_event: FinalResultEvent | None = field(default=None, init=False)
 
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
@@ -684,19 +686,29 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
     try:
         provider, model_name = model.split(':', maxsplit=1)
     except ValueError:
+        provider = None
         model_name = model
-        # TODO(Marcelo): We should deprecate this way.
         if model_name.startswith(('gpt', 'o1', 'o3')):
             provider = 'openai'
         elif model_name.startswith('claude'):
             provider = 'anthropic'
        elif model_name.startswith('gemini'):
             provider = 'google-gla'
+
+        if provider is not None:
+            warnings.warn(
+                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider}:{model_name}'.",
+                DeprecationWarning,
+            )
         else:
             raise UserError(f'Unknown model: {model}')
 
-    if provider == 'vertexai':
-        provider = 'google-vertex'
+    if provider == 'vertexai':  # pragma: no cover
+        warnings.warn(
+            "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead.",
+            DeprecationWarning,
+        )
+        provider = 'google-vertex'
 
     if provider == 'cohere':
         from .cohere import CohereModel
@@ -716,6 +728,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'openrouter',
         'together',
         'vercel',
+        'litellm',
     ):
         from .openai import OpenAIChatModel
 
@@ -909,5 +922,5 @@ def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestPar
     elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
         if tool_def.kind == 'output':
             return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
-        elif tool_def.
+        elif tool_def.defer:
             return FinalResultEvent(tool_name=None, tool_call_id=None)
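The `infer_model` hunks above deprecate bare model names in favour of `provider:model` strings and fold the old 'vertexai' alias into 'google-vertex'. A minimal sketch of the new behaviour, assuming pydantic-ai 1.0 with the `openai` extra and an `OPENAI_API_KEY` available (the model name is illustrative):

import warnings

from pydantic_ai.models import infer_model

model = infer_model('openai:gpt-4o')  # preferred: explicit 'provider:model' string

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    infer_model('gpt-4o')  # still inferred as 'openai', but now warns
assert any(issubclass(w.category, DeprecationWarning) for w in caught)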
pydantic_ai/models/anthropic.py CHANGED

@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
 
@@ -99,7 +99,7 @@ except ImportError as _import_error:
 LatestAnthropicModelNames = ModelParam
 """Latest Anthropic models."""
 
-AnthropicModelName =
+AnthropicModelName = str | LatestAnthropicModelNames
 """Possible Anthropic model names.
 
 Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
@@ -290,7 +290,7 @@ class AnthropicModel(Model):
         for item in response.content:
             if isinstance(item, BetaTextBlock):
                 items.append(TextPart(content=item.text))
-            elif isinstance(item,
+            elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
                 items.append(
                     BuiltinToolReturnPart(
                         provider_name='anthropic',
@@ -327,10 +327,10 @@ class AnthropicModel(Model):
         )
 
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
-
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -536,7 +536,7 @@ class AnthropicModel(Model):
 }
 
 
-def _map_usage(message: BetaMessage |
+def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -544,12 +544,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques
     elif isinstance(message, BetaRawMessageDeltaEvent):
         response_usage = message.usage
     else:
-
-        # - RawMessageStopEvent
-        # - RawContentBlockStartEvent
-        # - RawContentBlockDeltaEvent
-        # - RawContentBlockStopEvent
-        return usage.RequestUsage()
+        assert_never(message)
 
     # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
     # `response_tokens`
@@ -586,10 +581,8 @@ class AnthropicStreamedResponse(StreamedResponse):
         current_block: BetaContentBlock | None = None
 
         async for event in self._response:
-            self._usage += _map_usage(event)
-
             if isinstance(event, BetaRawMessageStartEvent):
-
+                self._usage = _map_usage(event)
 
             elif isinstance(event, BetaRawContentBlockStartEvent):
                 current_block = event.content_block
@@ -652,9 +645,9 @@ class AnthropicStreamedResponse(StreamedResponse):
                     pass
 
             elif isinstance(event, BetaRawMessageDeltaEvent):
-
+                self._usage = _map_usage(event)
 
-            elif isinstance(event,
+            elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent):  # pragma: no branch
                 current_block = None
 
     @property
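Beyond the usage-mapping fixes, the hunks above switch `ModelResponse` construction to keyword arguments and rename the response ID field to `provider_response_id`. A small sketch of the keyword style; field values here are placeholders, not real Anthropic output:

from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(
    parts=[TextPart(content='Hello!')],
    model_name='claude-sonnet-4-0',   # placeholder model name
    provider_name='anthropic',
    provider_response_id='msg_0123',  # placeholder ID
)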
pydantic_ai/models/bedrock.py CHANGED

@@ -2,13 +2,12 @@ from __future__ import annotations
 
 import functools
 import typing
-import warnings
 from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import count
-from typing import TYPE_CHECKING, Any, Generic, Literal,
+from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
 
 import anyio
 import anyio.to_thread
@@ -125,7 +124,7 @@ LatestBedrockModelNames = Literal[
 ]
 """Latest Bedrock models."""
 
-BedrockModelName =
+BedrockModelName = str | LatestBedrockModelNames
 """Possible Bedrock model names.
 
 Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
@@ -301,9 +300,13 @@ class BedrockConverseModel(Model):
             input_tokens=response['usage']['inputTokens'],
             output_tokens=response['usage']['outputTokens'],
         )
-
+        response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
         return ModelResponse(
-            items,
+            parts=items,
+            usage=u,
+            model_name=self.model_name,
+            provider_response_id=response_id,
+            provider_name=self._provider.name,
         )
 
     @overload
@@ -486,7 +489,7 @@ class BedrockConverseModel(Model):
                 else:
                     # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
                     pass
-            elif isinstance(item,
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
                 pass
             else:
                 assert isinstance(item, ToolCallPart)
@@ -542,7 +545,7 @@ class BedrockConverseModel(Model):
                     content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
                 else:
                     raise NotImplementedError('Binary content is not supported yet.')
-            elif isinstance(item,
+            elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
                 downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
                 format = downloaded_item['data_type']
                 if item.kind == 'image-url':
@@ -597,7 +600,7 @@ class BedrockStreamedResponse(StreamedResponse):
     _provider_name: str
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
@@ -606,60 +609,55 @@ class BedrockStreamedResponse(StreamedResponse):
         chunk: ConverseStreamOutputTypeDef
         tool_id: str | None = None
         async for chunk in _AsyncIteratorWrapper(self._event_stream):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if text := delta['reasoningContent'].get('text'):
+            match chunk:
+                case {'messageStart': _}:
+                    continue
+                case {'messageStop': _}:
+                    continue
+                case {'metadata': metadata}:
+                    if 'usage' in metadata:  # pragma: no branch
+                        self._usage += self._map_usage(metadata)
+                    continue
+                case {'contentBlockStart': content_block_start}:
+                    index = content_block_start['contentBlockIndex']
+                    start = content_block_start['start']
+                    if 'toolUse' in start:  # pragma: no branch
+                        tool_use_start = start['toolUse']
+                        tool_id = tool_use_start['toolUseId']
+                        tool_name = tool_use_start['name']
+                        maybe_event = self._parts_manager.handle_tool_call_delta(
+                            vendor_part_id=index,
+                            tool_name=tool_name,
+                            args=None,
+                            tool_call_id=tool_id,
+                        )
+                        if maybe_event:  # pragma: no branch
+                            yield maybe_event
+                case {'contentBlockDelta': content_block_delta}:
+                    index = content_block_delta['contentBlockIndex']
+                    delta = content_block_delta['delta']
+                    if 'reasoningContent' in delta:
                         yield self._parts_manager.handle_thinking_delta(
                             vendor_part_id=index,
-                            content=text,
+                            content=delta['reasoningContent'].get('text'),
                             signature=delta['reasoningContent'].get('signature'),
                         )
-
-
-
-
-
+                    if 'text' in delta:
+                        maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
+                        if maybe_event is not None:  # pragma: no branch
+                            yield maybe_event
+                    if 'toolUse' in delta:
+                        tool_use = delta['toolUse']
+                        maybe_event = self._parts_manager.handle_tool_call_delta(
+                            vendor_part_id=index,
+                            tool_name=tool_use.get('name'),
+                            args=tool_use.get('input'),
+                            tool_call_id=tool_id,
                         )
-
-
-
-
-            if 'toolUse' in delta:
-                tool_use = delta['toolUse']
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=index,
-                    tool_name=tool_use.get('name'),
-                    args=tool_use.get('input'),
-                    tool_call_id=tool_id,
-                )
-                if maybe_event:  # pragma: no branch
-                    yield maybe_event
+                        if maybe_event:  # pragma: no branch
+                            yield maybe_event
+                case _:
+                    pass  # pyright wants match statements to be exhaustive
 
     @property
     def model_name(self) -> str:
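The streaming rewrite above replaces a chain of key checks with structural pattern matching. A self-contained sketch of the mapping-pattern style it adopts; the event shapes mirror the Converse stream chunks shown in the diff:

def describe_chunk(chunk: dict) -> str:
    # Mapping patterns match when the given key is present and bind its value.
    match chunk:
        case {'messageStart': _} | {'messageStop': _}:
            return 'message boundary'
        case {'metadata': metadata}:
            return 'usage metadata' if 'usage' in metadata else 'other metadata'
        case {'contentBlockDelta': delta}:
            return f"delta for block {delta['contentBlockIndex']}"
        case _:
            return 'unhandled'


print(describe_chunk({'contentBlockDelta': {'contentBlockIndex': 0, 'delta': {'text': 'hi'}}}))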
pydantic_ai/models/cohere.py CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from collections.abc import Iterable
 from dataclasses import dataclass, field
-from typing import Literal,
+from typing import Literal, cast
 
 from typing_extensions import assert_never
 
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
 ]
 """Latest Cohere models."""
 
-CohereModelName =
+CohereModelName = str | LatestCohereModelNames
 """Possible Cohere model names.
 
 Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
@@ -228,7 +228,7 @@ class CohereModel(Model):
                 pass
             elif isinstance(item, ToolCallPart):
                 tool_calls.append(self._map_tool_call(item))
-            elif isinstance(item,
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from cohere
                 pass
             else:
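The updated `isinstance` checks pass a PEP 604 union instead of a tuple of types, which is valid at runtime on Python 3.10+. A trivial demonstration:

value: object = 4.2

# Equivalent checks; 1.0 standardizes on the union form.
assert isinstance(value, int | float)
assert isinstance(value, (int, float))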
pydantic_ai/models/fallback.py CHANGED

@@ -1,9 +1,9 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 from opentelemetry.trace import get_current_span
 
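The import move mirrors the typing docs: `typing.Callable` has been a deprecated alias of `collections.abc.Callable` since Python 3.9. A one-line illustration of the style 1.0 standardizes on:

from collections.abc import Callable

on_error: Callable[[Exception], None] = lambda exc: print(f'model failed: {exc}')
on_error(RuntimeError('boom'))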
pydantic_ai/models/function.py CHANGED

@@ -2,14 +2,14 @@ from __future__ import annotations as _annotations
 
 import inspect
 import re
-from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import KW_ONLY, dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Any,
+from typing import Any, TypeAlias
 
-from typing_extensions import
+from typing_extensions import assert_never, overload
 
 from .. import _utils, usage
 from .._run_context import RunContext
@@ -44,8 +44,8 @@ class FunctionModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    function: FunctionDef | None
-    stream_function: StreamFunctionDef | None
+    function: FunctionDef | None
+    stream_function: StreamFunctionDef | None
 
     _model_name: str = field(repr=False)
     _system: str = field(default='function', repr=False)
@@ -120,10 +120,10 @@ class FunctionModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         agent_info = AgentInfo(
-            model_request_parameters.function_tools,
-            model_request_parameters.allow_text_output,
-            model_request_parameters.output_tools,
-            model_settings,
+            function_tools=model_request_parameters.function_tools,
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=model_request_parameters.output_tools,
+            model_settings=model_settings,
         )
 
         assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
@@ -149,10 +149,10 @@ class FunctionModel(Model):
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
-            model_request_parameters.function_tools,
-            model_request_parameters.allow_text_output,
-            model_request_parameters.output_tools,
-            model_settings,
+            function_tools=model_request_parameters.function_tools,
+            allow_text_output=model_request_parameters.allow_text_output,
+            output_tools=model_request_parameters.output_tools,
+            model_settings=model_settings,
         )
 
         assert self.stream_function is not None, (
@@ -182,7 +182,7 @@ class FunctionModel(Model):
         return self._system
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=True, kw_only=True)
 class AgentInfo:
     """Information about an agent.
 
@@ -212,13 +212,17 @@ class DeltaToolCall:
 
     name: str | None = None
     """Incremental change to the name of the tool."""
+
     json_args: str | None = None
     """Incremental change to the arguments as JSON"""
+
+    _: KW_ONLY
+
     tool_call_id: str | None = None
     """Incremental change to the tool call ID."""
 
 
-@dataclass
+@dataclass(kw_only=True)
 class DeltaThinkingPart:
     """Incremental change to a thinking part.
 
@@ -237,18 +241,16 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
 DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
 """A mapping of thinking call IDs to incremental changes."""
 
-
-FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
 """A function used to generate a non-streamed response."""
 
-# TODO: Change signature as indicated above
 StreamFunctionDef: TypeAlias = Callable[
-    [list[ModelMessage], AgentInfo], AsyncIterator[
+    [list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
 ]
 """A function used to generate a streamed response.
 
-While this is defined as having return type of `AsyncIterator[
-really be considered as `
+While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
+really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
 
 E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
 """
@@ -326,7 +328,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
     for message in messages:
         if isinstance(message, ModelRequest):
             for part in message.parts:
-                if isinstance(part,
+                if isinstance(part, SystemPromptPart | UserPromptPart):
                     request_tokens += _estimate_string_tokens(part.content)
                 elif isinstance(part, ToolReturnPart):
                     request_tokens += _estimate_string_tokens(part.model_response_str())
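`AgentInfo` is now constructed with keyword arguments and declared `kw_only`; callbacks that merely receive it are unaffected. A minimal `FunctionModel` sketch under that assumption (the canned reply is illustrative):

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel


def canned_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Callbacks still receive AgentInfo positionally; only constructing
    # AgentInfo yourself now requires keyword arguments.
    return ModelResponse(parts=[TextPart(content='canned reply')])


agent = Agent(FunctionModel(canned_model))
print(agent.run_sync('hello').output)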
pydantic_ai/models/gemini.py CHANGED

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol,
+from typing import Annotated, Any, Literal, Protocol, cast
 from uuid import uuid4
 
 import httpx
@@ -51,7 +51,7 @@ LatestGeminiModelNames = Literal[
 ]
 """Latest Gemini models."""
 
-GeminiModelName =
+GeminiModelName = str | LatestGeminiModelNames
 """Possible Gemini model names.
 
 Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -211,7 +211,9 @@ class GeminiModel(Model):
         generation_config = _settings_to_generation_config(model_settings)
         if model_request_parameters.output_mode == 'native':
             if tools:
-                raise UserError(
+                raise UserError(
+                    'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
+                )
 
             generation_config['response_mime_type'] = 'application/json'
 
@@ -615,7 +617,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         elif isinstance(item, TextPart):
             if item.content:
                 parts.append(_GeminiTextPart(text=item.content))
-        elif isinstance(item,
+        elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
             # This is currently never returned from gemini
             pass
         else:
@@ -690,7 +692,7 @@ def _process_response_from_parts(
             f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
         )
     return ModelResponse(
-        parts=items, usage=usage, model_name=model_name,
+        parts=items, usage=usage, model_name=model_name, provider_response_id=vendor_id, provider_details=vendor_details
    )
 
 
@@ -735,16 +737,13 @@ def _part_discriminator(v: Any) -> str:
 
 # See <https://ai.google.dev/api/caching#Part>
 # we don't currently support other part types
-# TODO discriminator
 _GeminiPartUnion = Annotated[
-
-
-
-
-
-
-        Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
-    ],
+    Annotated[_GeminiTextPart, pydantic.Tag('text')]
+    | Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')]
+    | Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')]
+    | Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')]
+    | Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')]
+    | Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
     pydantic.Discriminator(_part_discriminator),
 ]
 
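The expanded error message spells out the workaround: on Gemini, `NativeOutput` cannot be combined with tools, but `ToolOutput` can. A hedged sketch of the suggested switch; the model name, schema, and tool are illustrative:

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.output import ToolOutput


class CityInfo(BaseModel):
    city: str
    country: str


# NativeOutput(CityInfo) would raise UserError here because a tool is registered;
# ToolOutput delivers the schema as a tool call instead.
agent = Agent('google-gla:gemini-2.0-flash', output_type=ToolOutput(CityInfo))


@agent.tool_plain
def lookup_timezone(city: str) -> str:
    return 'Europe/Paris'  # placeholder implementation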
pydantic_ai/models/google.py CHANGED

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal,
+from typing import Any, Literal, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
@@ -91,7 +91,7 @@ LatestGoogleModelNames = Literal[
 ]
 """Latest Gemini models."""
 
-GoogleModelName =
+GoogleModelName = str | LatestGoogleModelNames
 """Possible Gemini model names.
 
 Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
@@ -264,6 +264,14 @@ class GoogleModel(Model):
             yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
+        if model_request_parameters.builtin_tools:
+            if model_request_parameters.output_tools:
+                raise UserError(
+                    'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.'
+                )
+            if model_request_parameters.function_tools:
+                raise UserError('Gemini does not support user tools and built-in tools at the same time.')
+
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
             for t in model_request_parameters.tool_defs.values()
@@ -334,7 +342,7 @@ class GoogleModel(Model):
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
-                raise UserError(
+                raise UserError(
+                    'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.'
+                )
             response_mime_type = 'application/json'
             output_object = model_request_parameters.output_object
             assert output_object is not None
@@ -349,7 +359,7 @@ class GoogleModel(Model):
             'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         }
         if timeout := model_settings.get('timeout'):
-            if isinstance(timeout,
+            if isinstance(timeout, int | float):
                http_options['timeout'] = int(1000 * timeout)
             else:
                 raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
@@ -559,6 +569,10 @@ class GeminiStreamedResponse(StreamedResponse):
                     )
                    if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
+                elif part.executable_code is not None:
+                    pass
+                elif part.code_execution_result is not None:
+                    pass
                 else:
                     assert part.function_response is not None, f'Unexpected part: {part}'  # pragma: no cover
 
@@ -648,7 +662,7 @@ def _process_response_from_parts(
         parts=items,
         model_name=model_name,
         usage=usage,
-
+        provider_response_id=vendor_id,
         provider_details=vendor_details,
         provider_name=provider_name,
    )