pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective registries.
- pydantic_ai/__init__.py +6 -0
- pydantic_ai/_agent_graph.py +67 -20
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_output.py +20 -12
- pydantic_ai/_run_context.py +6 -2
- pydantic_ai/_utils.py +26 -8
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -25
- pydantic_ai/agent/abstract.py +146 -9
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/__init__.py +11 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/exceptions.py +6 -1
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/messages.py +46 -8
- pydantic_ai/models/__init__.py +87 -38
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +26 -23
- pydantic_ai/models/groq.py +13 -5
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +251 -52
- pydantic_ai/models/outlines.py +563 -0
- pydantic_ai/models/test.py +6 -3
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/__init__.py +25 -12
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +91 -24
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/providers/outlines.py +40 -0
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/result.py +173 -8
- pydantic_ai/run.py +40 -24
- pydantic_ai/settings.py +8 -0
- pydantic_ai/tools.py +10 -6
- pydantic_ai/toolsets/fastmcp.py +215 -0
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from .._run_context import RunContext
 from ..builtin_tools import AbstractBuiltinTool
 from ..exceptions import UserError
 from ..messages import (
+    BaseToolCallPart,
     BinaryImage,
     FilePart,
     FileUrl,
@@ -35,14 +36,18 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     ModelResponseStreamEvent,
+    PartEndEvent,
     PartStartEvent,
     TextPart,
+    ThinkingPart,
     ToolCallPart,
     VideoUrl,
 )
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
+from ..providers import infer_provider
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage
@@ -129,15 +134,8 @@ KnownModelName = TypeAliasType(
     'cerebras:qwen-3-235b-a22b-thinking-2507',
     'cohere:c4ai-aya-expanse-32b',
     'cohere:c4ai-aya-expanse-8b',
-    'cohere:command',
-    'cohere:command-light',
-    'cohere:command-light-nightly',
     'cohere:command-nightly',
-    'cohere:command-r',
-    'cohere:command-r-03-2024',
     'cohere:command-r-08-2024',
-    'cohere:command-r-plus',
-    'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
     'deepseek:deepseek-chat',
@@ -416,9 +414,17 @@ class Model(ABC):
         they need to customize the preparation flow further, but most implementations should simply call
         ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
         """
-        model_settings = merge_model_settings(self.settings, model_settings)
-        model_request_parameters = self.customize_request_parameters(model_request_parameters)
-        return model_settings, model_request_parameters
+        model_settings = merge_model_settings(self.settings, model_settings)
+
+        if builtin_tools := model_request_parameters.builtin_tools:
+            # Deduplicate builtin tools
+            model_request_parameters = replace(
+                model_request_parameters,
+                builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
+            )
+
+        model_request_parameters = self.customize_request_parameters(model_request_parameters)
+        return model_settings, model_request_parameters
 
     @property
     @abstractmethod
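The deduplication above leans on dict insertion order: keying by `unique_id` keeps one entry per tool, at the position where that `unique_id` first appeared. A minimal sketch of the idiom, with plain strings standing in for builtin tool objects:

    tools = ['web_search', 'code_execution', 'web_search']
    deduped = list({t: t for t in tools}.values())
    assert deduped == ['web_search', 'code_execution']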
@@ -541,7 +547,44 @@ class StreamedResponse(ABC):
                 async for event in iterator:
                     yield event
 
-            self._event_iterator = iterator_with_final_event(self._get_event_iterator())
+            async def iterator_with_part_end(
+                iterator: AsyncIterator[ModelResponseStreamEvent],
+            ) -> AsyncIterator[ModelResponseStreamEvent]:
+                last_start_event: PartStartEvent | None = None
+
+                def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent | None:
+                    if not last_start_event:
+                        return None
+
+                    index = last_start_event.index
+                    part = self._parts_manager.get_parts()[index]
+                    if not isinstance(part, TextPart | ThinkingPart | BaseToolCallPart):
+                        # Parts other than these 3 don't have deltas, so don't need an end part.
+                        return None
+
+                    return PartEndEvent(
+                        index=index,
+                        part=part,
+                        next_part_kind=next_part.part_kind if next_part else None,
+                    )
+
+                async for event in iterator:
+                    if isinstance(event, PartStartEvent):
+                        if last_start_event:
+                            end_event = part_end_event(event.part)
+                            if end_event:
+                                yield end_event
+
+                            event.previous_part_kind = last_start_event.part.part_kind
+                        last_start_event = event
+
+                    yield event
+
+                end_event = part_end_event()
+                if end_event:
+                    yield end_event
+
+            self._event_iterator = iterator_with_part_end(iterator_with_final_event(self._get_event_iterator()))
         return self._event_iterator
 
     @abstractmethod
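With this wrapper, every part that streams deltas (text, thinking, tool calls) is now explicitly closed. A sketch of the event sequence a consumer would see for a response that streams a thinking part followed by a text part, assuming the `next_part_kind`/`previous_part_kind` semantics above (fields abridged, other events omitted):

    PartStartEvent(index=0, part=ThinkingPart(...))
    PartDeltaEvent(index=0, delta=...)
    PartEndEvent(index=0, part=ThinkingPart(...), next_part_kind='text')
    PartStartEvent(index=1, part=TextPart(...), previous_part_kind='thinking')
    PartDeltaEvent(index=1, delta=...)
    PartEndEvent(index=1, part=TextPart(...))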
@@ -644,41 +687,39 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         return TestModel()
 
     try:
-        provider, model_name = model.split(':', maxsplit=1)
+        provider_name, model_name = model.split(':', maxsplit=1)
     except ValueError:
-        provider = None
+        provider_name = None
         model_name = model
         if model_name.startswith(('gpt', 'o1', 'o3')):
-            provider = 'openai'
+            provider_name = 'openai'
         elif model_name.startswith('claude'):
-            provider = 'anthropic'
+            provider_name = 'anthropic'
         elif model_name.startswith('gemini'):
-            provider = 'google-gla'
+            provider_name = 'google-gla'
 
-        if provider is not None:
+        if provider_name is not None:
             warnings.warn(
-                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider}:{model_name}'.",
+                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider_name}:{model_name}'.",
                 DeprecationWarning,
             )
         else:
             raise UserError(f'Unknown model: {model}')
 
-    if provider == 'vertexai':  # pragma: no cover
+    if provider_name == 'vertexai':  # pragma: no cover
         warnings.warn(
             "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead.",
             DeprecationWarning,
         )
-        provider = 'google-vertex'
+        provider_name = 'google-vertex'
 
-    if provider == 'gateway':
-        from ..providers.gateway import infer_model as infer_model_from_gateway
+    provider = infer_provider(provider_name)
 
-        return infer_model_from_gateway(model_name)
-    elif provider == 'cohere':
-        from .cohere import CohereModel
-
-        return CohereModel(model_name, provider=provider)
-    elif provider in (
+    model_kind = provider_name
+    if model_kind.startswith('gateway/'):
+        model_kind = provider_name.removeprefix('gateway/')
+    if model_kind in (
+        'openai',
         'azure',
         'deepseek',
         'cerebras',
@@ -688,42 +729,50 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'heroku',
         'moonshotai',
         'ollama',
-        'openai',
-        'openai-chat',
         'openrouter',
         'together',
         'vercel',
         'litellm',
         'nebius',
+        'ovhcloud',
     ):
+        model_kind = 'openai-chat'
+    elif model_kind in ('google-gla', 'google-vertex'):
+        model_kind = 'google'
+
+    if model_kind == 'openai-chat':
         from .openai import OpenAIChatModel
 
         return OpenAIChatModel(model_name, provider=provider)
-    elif provider == 'openai-responses':
+    elif model_kind == 'openai-responses':
         from .openai import OpenAIResponsesModel
 
-        return OpenAIResponsesModel(model_name, provider='openai')
-    elif provider in ('google-gla', 'google-vertex'):
+        return OpenAIResponsesModel(model_name, provider=provider)
+    elif model_kind == 'google':
         from .google import GoogleModel
 
         return GoogleModel(model_name, provider=provider)
-    elif provider == 'groq':
+    elif model_kind == 'groq':
        from .groq import GroqModel
 
         return GroqModel(model_name, provider=provider)
-    elif provider == 'mistral':
+    elif model_kind == 'cohere':
+        from .cohere import CohereModel
+
+        return CohereModel(model_name, provider=provider)
+    elif model_kind == 'mistral':
         from .mistral import MistralModel
 
         return MistralModel(model_name, provider=provider)
-    elif provider == 'anthropic':
+    elif model_kind == 'anthropic':
         from .anthropic import AnthropicModel
 
         return AnthropicModel(model_name, provider=provider)
-    elif provider == 'bedrock':
+    elif model_kind == 'bedrock':
         from .bedrock import BedrockConverseModel
 
         return BedrockConverseModel(model_name, provider=provider)
-    elif provider == 'huggingface':
+    elif model_kind == 'huggingface':
         from .huggingface import HuggingFaceModel
 
         return HuggingFaceModel(model_name, provider=provider)
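Taken together, `infer_model` now resolves the provider eagerly via `infer_provider()` and dispatches on a normalized `model_kind`, so `gateway/`-prefixed providers reuse the ordinary branches. A hedged sketch of the resulting behavior (model names are illustrative):

    from pydantic_ai.models import infer_model

    infer_model('cohere:command-r-plus-08-2024')        # cohere now dispatched through the shared elif chain
    infer_model('ovhcloud:some-model')                  # new OVHcloud provider, routed to OpenAIChatModel
    infer_model('gateway/anthropic:claude-sonnet-4-0')  # 'gateway/' is stripped before kind dispatch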
pydantic_ai/models/anthropic.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
 import io
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
@@ -13,7 +13,7 @@ from typing_extensions import assert_never
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._utils import guard_tool_call_id as _guard_tool_call_id
-from ..builtin_tools import CodeExecutionTool, MemoryTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -68,6 +68,9 @@ try:
         BetaContentBlockParam,
         BetaImageBlockParam,
         BetaInputJSONDelta,
+        BetaMCPToolResultBlock,
+        BetaMCPToolUseBlock,
+        BetaMCPToolUseBlockParam,
         BetaMemoryTool20250818Param,
         BetaMessage,
         BetaMessageParam,
@@ -82,6 +85,8 @@ try:
         BetaRawMessageStreamEvent,
         BetaRedactedThinkingBlock,
         BetaRedactedThinkingBlockParam,
+        BetaRequestMCPServerToolConfigurationParam,
+        BetaRequestMCPServerURLDefinitionParam,
         BetaServerToolUseBlock,
         BetaServerToolUseBlockParam,
         BetaSignatureDelta,
@@ -162,7 +167,7 @@ class AnthropicModel(Model):
         self,
         model_name: AnthropicModelName,
         *,
-        provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
+        provider: Literal['anthropic', 'gateway'] | Provider[AsyncAnthropicClient] = 'anthropic',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -179,7 +184,7 @@ class AnthropicModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/anthropic' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -264,7 +269,7 @@ class AnthropicModel(Model):
     ) -> BetaMessage | AsyncStream[BetaRawMessageStreamEvent]:
         # standalone function to make it easier to override
         tools = self._get_tools(model_request_parameters)
-        tools, beta_features = self._add_builtin_tools(tools, model_request_parameters)
+        tools, mcp_servers, beta_features = self._add_builtin_tools(tools, model_request_parameters)
 
         tool_choice: BetaToolChoiceParam | None
 
@@ -300,6 +305,7 @@ class AnthropicModel(Model):
                 model=self._model_name,
                 tools=tools or OMIT,
                 tool_choice=tool_choice or OMIT,
+                mcp_servers=mcp_servers or OMIT,
                 stream=stream,
                 thinking=model_settings.get('anthropic_thinking', OMIT),
                 stop_sequences=model_settings.get('stop_sequences', OMIT),
@@ -318,11 +324,14 @@ class AnthropicModel(Model):
     def _process_response(self, response: BetaMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
+        builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
         for item in response.content:
             if isinstance(item, BetaTextBlock):
                 items.append(TextPart(content=item.text))
             elif isinstance(item, BetaServerToolUseBlock):
-                items.append(_map_server_tool_use_block(item, self.system))
+                call_part = _map_server_tool_use_block(item, self.system)
+                builtin_tool_calls[call_part.tool_call_id] = call_part
+                items.append(call_part)
             elif isinstance(item, BetaWebSearchToolResultBlock):
                 items.append(_map_web_search_tool_result_block(item, self.system))
             elif isinstance(item, BetaCodeExecutionToolResultBlock):
@@ -333,6 +342,13 @@ class AnthropicModel(Model):
                 )
             elif isinstance(item, BetaThinkingBlock):
                 items.append(ThinkingPart(content=item.thinking, signature=item.signature, provider_name=self.system))
+            elif isinstance(item, BetaMCPToolUseBlock):
+                call_part = _map_mcp_server_use_block(item, self.system)
+                builtin_tool_calls[call_part.tool_call_id] = call_part
+                items.append(call_part)
+            elif isinstance(item, BetaMCPToolResultBlock):
+                call_part = builtin_tool_calls.get(item.tool_use_id)
+                items.append(_map_mcp_server_result_block(item, call_part, self.system))
             else:
                 assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
                 items.append(
@@ -383,8 +399,9 @@ class AnthropicModel(Model):
 
     def _add_builtin_tools(
         self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
-    ) -> tuple[list[BetaToolUnionParam], list[str]]:
+    ) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], list[str]]:
         beta_features: list[str] = []
+        mcp_servers: list[BetaRequestMCPServerURLDefinitionParam] = []
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 user_location = UserLocation(type='approximate', **tool.user_location) if tool.user_location else None
@@ -408,11 +425,26 @@ class AnthropicModel(Model):
                 tools = [tool for tool in tools if tool['name'] != 'memory']
                 tools.append(BetaMemoryTool20250818Param(name='memory', type='memory_20250818'))
                 beta_features.append('context-management-2025-06-27')
+            elif isinstance(tool, MCPServerTool) and tool.url:
+                mcp_server_url_definition_param = BetaRequestMCPServerURLDefinitionParam(
+                    type='url',
+                    name=tool.id,
+                    url=tool.url,
+                )
+                if tool.allowed_tools is not None:  # pragma: no branch
+                    mcp_server_url_definition_param['tool_configuration'] = BetaRequestMCPServerToolConfigurationParam(
+                        enabled=bool(tool.allowed_tools),
+                        allowed_tools=tool.allowed_tools,
+                    )
+                if tool.authorization_token:  # pragma: no cover
+                    mcp_server_url_definition_param['authorization_token'] = tool.authorization_token
+                mcp_servers.append(mcp_server_url_definition_param)
+                beta_features.append('mcp-client-2025-04-04')
             else:  # pragma: no cover
                 raise UserError(
                     f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
                 )
-        return tools, beta_features
+        return tools, mcp_servers, beta_features
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
@@ -458,6 +490,8 @@ class AnthropicModel(Model):
                     | BetaCodeExecutionToolResultBlockParam
                     | BetaThinkingBlockParam
                     | BetaRedactedThinkingBlockParam
+                    | BetaMCPToolUseBlockParam
+                    | BetaMCPToolResultBlock
                 ] = []
                 for response_part in m.parts:
                     if isinstance(response_part, TextPart):
@@ -508,7 +542,7 @@ class AnthropicModel(Model):
                                 input=response_part.args_as_dict(),
                             )
                             assistant_content_params.append(server_tool_use_block_param)
-                        elif response_part.tool_name == CodeExecutionTool.kind:
+                        elif response_part.tool_name == CodeExecutionTool.kind:
                             server_tool_use_block_param = BetaServerToolUseBlockParam(
                                 id=tool_use_id,
                                 type='server_tool_use',
@@ -516,6 +550,21 @@ class AnthropicModel(Model):
                                 input=response_part.args_as_dict(),
                             )
                             assistant_content_params.append(server_tool_use_block_param)
+                        elif (
+                            response_part.tool_name.startswith(MCPServerTool.kind)
+                            and (server_id := response_part.tool_name.split(':', 1)[1])
+                            and (args := response_part.args_as_dict())
+                            and (tool_name := args.get('tool_name'))
+                            and (tool_args := args.get('tool_args'))
+                        ):  # pragma: no branch
+                            mcp_tool_use_block_param = BetaMCPToolUseBlockParam(
+                                id=tool_use_id,
+                                type='mcp_tool_use',
+                                server_name=server_id,
+                                name=tool_name,
+                                input=tool_args,
+                            )
+                            assistant_content_params.append(mcp_tool_use_block_param)
                 elif isinstance(response_part, BuiltinToolReturnPart):
                     if response_part.provider_name == self.system:
                         tool_use_id = _guard_tool_call_id(t=response_part)
@@ -547,6 +596,16 @@ class AnthropicModel(Model):
                                     ),
                                 )
                             )
+                        elif response_part.tool_name.startswith(MCPServerTool.kind) and isinstance(
+                            response_part.content, dict
+                        ):  # pragma: no branch
+                            assistant_content_params.append(
+                                BetaMCPToolResultBlock(
+                                    tool_use_id=tool_use_id,
+                                    type='mcp_tool_result',
+                                    **cast(dict[str, Any], response_part.content),  # pyright: ignore[reportUnknownMemberType]
+                                )
+                            )
                 elif isinstance(response_part, FilePart):  # pragma: no cover
                     # Files generated by models are not sent back to models that don't themselves generate files.
                     pass
@@ -661,6 +720,7 @@ class AnthropicStreamedResponse(StreamedResponse):
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
         current_block: BetaContentBlock | None = None
 
+        builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
         async for event in self._response:
             if isinstance(event, BetaRawMessageStartEvent):
                 self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
@@ -698,9 +758,11 @@ class AnthropicStreamedResponse(StreamedResponse):
                     if maybe_event is not None:  # pragma: no branch
                         yield maybe_event
                 elif isinstance(current_block, BetaServerToolUseBlock):
+                    call_part = _map_server_tool_use_block(current_block, self.provider_name)
+                    builtin_tool_calls[call_part.tool_call_id] = call_part
                     yield self._parts_manager.handle_part(
                         vendor_part_id=event.index,
-                        part=_map_server_tool_use_block(current_block, self.provider_name),
+                        part=call_part,
                     )
                 elif isinstance(current_block, BetaWebSearchToolResultBlock):
                     yield self._parts_manager.handle_part(
@@ -712,6 +774,32 @@ class AnthropicStreamedResponse(StreamedResponse):
                         vendor_part_id=event.index,
                         part=_map_code_execution_tool_result_block(current_block, self.provider_name),
                     )
+                elif isinstance(current_block, BetaMCPToolUseBlock):
+                    call_part = _map_mcp_server_use_block(current_block, self.provider_name)
+                    builtin_tool_calls[call_part.tool_call_id] = call_part
+
+                    args_json = call_part.args_as_json_str()
+                    # Drop the final `{}}` so that we can add tool args deltas
+                    args_json_delta = args_json[:-3]
+                    assert args_json_delta.endswith('"tool_args":'), (
+                        f'Expected {args_json_delta!r} to end in `"tool_args":`'
+                    )
+
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=event.index, part=replace(call_part, args=None)
+                    )
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=event.index,
+                        args=args_json_delta,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                elif isinstance(current_block, BetaMCPToolResultBlock):
+                    call_part = builtin_tool_calls.get(current_block.tool_use_id)
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=event.index,
+                        part=_map_mcp_server_result_block(current_block, call_part, self.provider_name),
+                    )
 
             elif isinstance(event, BetaRawContentBlockDeltaEvent):
                 if isinstance(event.delta, BetaTextDelta):
@@ -749,7 +837,16 @@ class AnthropicStreamedResponse(StreamedResponse):
                 self.provider_details = {'finish_reason': raw_finish_reason}
                 self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
-            elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent):  # pragma: no branch
+            elif isinstance(event, BetaRawContentBlockStopEvent):  # pragma: no branch
+                if isinstance(current_block, BetaMCPToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=event.index,
+                        args='}',
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                current_block = None
+            elif isinstance(event, BetaRawMessageStopEvent):  # pragma: no branch
                 current_block = None
 
     @property
@@ -817,3 +914,27 @@ def _map_code_execution_tool_result_block(
         content=code_execution_tool_result_content_ta.dump_python(item.content, mode='json'),
         tool_call_id=item.tool_use_id,
     )
+
+
+def _map_mcp_server_use_block(item: BetaMCPToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
+    return BuiltinToolCallPart(
+        provider_name=provider_name,
+        tool_name=':'.join([MCPServerTool.kind, item.server_name]),
+        args={
+            'action': 'call_tool',
+            'tool_name': item.name,
+            'tool_args': cast(dict[str, Any], item.input),
+        },
+        tool_call_id=item.id,
+    )
+
+
+def _map_mcp_server_result_block(
+    item: BetaMCPToolResultBlock, call_part: BuiltinToolCallPart | None, provider_name: str
+) -> BuiltinToolReturnPart:
+    return BuiltinToolReturnPart(
+        provider_name=provider_name,
+        tool_name=call_part.tool_name if call_part else MCPServerTool.kind,
+        content=item.model_dump(mode='json', include={'content', 'is_error'}),
+        tool_call_id=item.tool_use_id,
+    )
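These blocks wire Anthropic's hosted MCP connector into the builtin-tool machinery. A hedged sketch of how an `MCPServerTool` might be passed (the `id`/`url`/`allowed_tools`/`authorization_token` fields appear in the diff; the keyword-argument constructor and the server URL are assumptions):

    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import MCPServerTool

    agent = Agent(
        'anthropic:claude-sonnet-4-0',
        builtin_tools=[
            MCPServerTool(
                id='docs',
                url='https://example.com/mcp',  # hypothetical MCP server URL
                allowed_tools=['search_docs'],
            )
        ],
    )

Per `_map_mcp_server_use_block` above, server-side tool calls then surface as `BuiltinToolCallPart`s whose `tool_name` joins `MCPServerTool.kind` and the server name, with the actual tool name and arguments nested under `args`.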
pydantic_ai/models/bedrock.py
CHANGED
@@ -207,7 +207,7 @@ class BedrockConverseModel(Model):
         self,
         model_name: BedrockModelName,
         *,
-        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
+        provider: Literal['bedrock', 'gateway'] | Provider[BaseClient] = 'bedrock',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -226,7 +226,7 @@ class BedrockConverseModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)
 
@@ -701,8 +701,8 @@ class BedrockStreamedResponse(StreamedResponse):
                     signature=signature,
                     provider_name=self.provider_name if signature else None,
                 )
-            if 'text' in delta:
-                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
+            if text := delta.get('text'):
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
             if maybe_event is not None:  # pragma: no branch
                 yield maybe_event
             if 'toolUse' in delta:
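`AnthropicModel` and `BedrockConverseModel` gain the same `'gateway'` provider shorthand. A sketch (the model ID is illustrative):

    from pydantic_ai.models.bedrock import BedrockConverseModel

    # provider='gateway' now resolves via infer_provider('gateway/bedrock')
    model = BedrockConverseModel('us.anthropic.claude-sonnet-4-20250514-v1:0', provider='gateway')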
pydantic_ai/models/cohere.py
CHANGED
@@ -62,15 +62,8 @@ except ImportError as _import_error:
 LatestCohereModelNames = Literal[
     'c4ai-aya-expanse-32b',
     'c4ai-aya-expanse-8b',
-    'command',
-    'command-light',
-    'command-light-nightly',
     'command-nightly',
-    'command-r',
-    'command-r-03-2024',
     'command-r-08-2024',
-    'command-r-plus',
-    'command-r-plus-04-2024',
     'command-r-plus-08-2024',
     'command-r7b-12-2024',
 ]
pydantic_ai/models/gemini.py
CHANGED
@@ -38,7 +38,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
-from ..providers import Provider, infer_provider
+from ..providers import Provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
@@ -131,7 +131,14 @@ class GeminiModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            if provider == 'google-gla':
+                from pydantic_ai.providers.google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleGLAProvider()  # type: ignore[reportDeprecated]
+            else:
+                from pydantic_ai.providers.google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleVertexProvider()  # type: ignore[reportDeprecated]
         self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)
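The deprecated `GeminiModel` no longer goes through `infer_provider`, binding its string providers directly to the equally deprecated GLA/Vertex providers. A sketch of what the branch above does for each string (model name illustrative):

    from pydantic_ai.models.gemini import GeminiModel

    GeminiModel('gemini-2.0-flash', provider='google-gla')     # constructs GoogleGLAProvider() directly
    GeminiModel('gemini-2.0-flash', provider='google-vertex')  # constructs GoogleVertexProvider() directly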