pydantic-ai-slim 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +43 -23
- pydantic_ai/_cli.py +1 -1
- pydantic_ai/_otel_messages.py +2 -0
- pydantic_ai/_parts_manager.py +82 -12
- pydantic_ai/_run_context.py +8 -1
- pydantic_ai/_tool_manager.py +1 -0
- pydantic_ai/ag_ui.py +93 -40
- pydantic_ai/agent/__init__.py +2 -4
- pydantic_ai/builtin_tools.py +12 -0
- pydantic_ai/durable_exec/temporal/_model.py +14 -6
- pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
- pydantic_ai/messages.py +69 -30
- pydantic_ai/models/__init__.py +4 -6
- pydantic_ai/models/anthropic.py +119 -45
- pydantic_ai/models/function.py +17 -8
- pydantic_ai/models/google.py +105 -16
- pydantic_ai/models/groq.py +68 -17
- pydantic_ai/models/openai.py +262 -41
- pydantic_ai/providers/__init__.py +1 -1
- pydantic_ai/result.py +24 -8
- pydantic_ai/toolsets/function.py +8 -2
- pydantic_ai/usage.py +2 -2
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/RECORD +27 -27
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.7.dist-info → pydantic_ai_slim-1.0.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,7 +4,7 @@ from collections.abc import AsyncIterator, Callable
|
|
|
4
4
|
from contextlib import asynccontextmanager
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from datetime import datetime
|
|
7
|
-
from typing import Any
|
|
7
|
+
from typing import Any, cast
|
|
8
8
|
|
|
9
9
|
from pydantic import ConfigDict, with_config
|
|
10
10
|
from temporalio import activity, workflow
|
|
@@ -30,7 +30,8 @@ from ._run_context import TemporalRunContext
|
|
|
30
30
|
@with_config(ConfigDict(arbitrary_types_allowed=True))
|
|
31
31
|
class _RequestParams:
|
|
32
32
|
messages: list[ModelMessage]
|
|
33
|
-
model_settings
|
|
33
|
+
# `model_settings` can't be a `ModelSettings` because Temporal would end up dropping fields only defined on its subclasses.
|
|
34
|
+
model_settings: dict[str, Any] | None
|
|
34
35
|
model_request_parameters: ModelRequestParameters
|
|
35
36
|
serialized_run_context: Any
|
|
36
37
|
|
|
@@ -82,7 +83,11 @@ class TemporalModel(WrapperModel):
|
|
|
82
83
|
|
|
83
84
|
@activity.defn(name=f'{activity_name_prefix}__model_request')
|
|
84
85
|
async def request_activity(params: _RequestParams) -> ModelResponse:
|
|
85
|
-
return await self.wrapped.request(
|
|
86
|
+
return await self.wrapped.request(
|
|
87
|
+
params.messages,
|
|
88
|
+
cast(ModelSettings | None, params.model_settings),
|
|
89
|
+
params.model_request_parameters,
|
|
90
|
+
)
|
|
86
91
|
|
|
87
92
|
self.request_activity = request_activity
|
|
88
93
|
|
|
@@ -92,7 +97,10 @@ class TemporalModel(WrapperModel):
|
|
|
92
97
|
|
|
93
98
|
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
|
|
94
99
|
async with self.wrapped.request_stream(
|
|
95
|
-
params.messages,
|
|
100
|
+
params.messages,
|
|
101
|
+
cast(ModelSettings | None, params.model_settings),
|
|
102
|
+
params.model_request_parameters,
|
|
103
|
+
run_context,
|
|
96
104
|
) as streamed_response:
|
|
97
105
|
await self.event_stream_handler(run_context, streamed_response)
|
|
98
106
|
|
|
@@ -124,7 +132,7 @@ class TemporalModel(WrapperModel):
|
|
|
124
132
|
activity=self.request_activity,
|
|
125
133
|
arg=_RequestParams(
|
|
126
134
|
messages=messages,
|
|
127
|
-
model_settings=model_settings,
|
|
135
|
+
model_settings=cast(dict[str, Any] | None, model_settings),
|
|
128
136
|
model_request_parameters=model_request_parameters,
|
|
129
137
|
serialized_run_context=None,
|
|
130
138
|
),
|
|
@@ -161,7 +169,7 @@ class TemporalModel(WrapperModel):
|
|
|
161
169
|
args=[
|
|
162
170
|
_RequestParams(
|
|
163
171
|
messages=messages,
|
|
164
|
-
model_settings=model_settings,
|
|
172
|
+
model_settings=cast(dict[str, Any] | None, model_settings),
|
|
165
173
|
model_request_parameters=model_request_parameters,
|
|
166
174
|
serialized_run_context=serialized_run_context,
|
|
167
175
|
),
|
|
@@ -9,7 +9,7 @@ from pydantic_ai.tools import AgentDepsT, RunContext
|
|
|
9
9
|
class TemporalRunContext(RunContext[AgentDepsT]):
|
|
10
10
|
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
|
|
11
11
|
|
|
12
|
-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
|
|
12
|
+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
|
|
13
13
|
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
|
|
14
14
|
"""
|
|
15
15
|
|
|
@@ -42,6 +42,7 @@ class TemporalRunContext(RunContext[AgentDepsT]):
|
|
|
42
42
|
'tool_name': ctx.tool_name,
|
|
43
43
|
'tool_call_approved': ctx.tool_call_approved,
|
|
44
44
|
'retry': ctx.retry,
|
|
45
|
+
'max_retries': ctx.max_retries,
|
|
45
46
|
'run_step': ctx.run_step,
|
|
46
47
|
}
|
|
47
48
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -668,8 +668,11 @@ class BaseToolReturnPart:
|
|
|
668
668
|
content: Any
|
|
669
669
|
"""The return value."""
|
|
670
670
|
|
|
671
|
-
tool_call_id: str
|
|
672
|
-
"""The tool call identifier, this is used by some models including OpenAI.
|
|
671
|
+
tool_call_id: str = field(default_factory=_generate_tool_call_id)
|
|
672
|
+
"""The tool call identifier, this is used by some models including OpenAI.
|
|
673
|
+
|
|
674
|
+
In case the tool call id is not provided by the model, Pydantic AI will generate a random one.
|
|
675
|
+
"""
|
|
673
676
|
|
|
674
677
|
_: KW_ONLY
|
|
675
678
|
|
|
@@ -708,14 +711,16 @@ class BaseToolReturnPart:
|
|
|
708
711
|
def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
|
|
709
712
|
from .models.instrumented import InstrumentedModel
|
|
710
713
|
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
714
|
+
part = _otel_messages.ToolCallResponsePart(
|
|
715
|
+
type='tool_call_response',
|
|
716
|
+
id=self.tool_call_id,
|
|
717
|
+
name=self.tool_name,
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
if settings.include_content and self.content is not None:
|
|
721
|
+
part['result'] = InstrumentedModel.serialize_any(self.content)
|
|
722
|
+
|
|
723
|
+
return [part]
|
|
719
724
|
|
|
720
725
|
def has_content(self) -> bool:
|
|
721
726
|
"""Return `True` if the tool return has content."""
|
|
@@ -820,14 +825,16 @@ class RetryPromptPart:
|
|
|
820
825
|
if self.tool_name is None:
|
|
821
826
|
return [_otel_messages.TextPart(type='text', content=self.model_response())]
|
|
822
827
|
else:
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
828
|
+
part = _otel_messages.ToolCallResponsePart(
|
|
829
|
+
type='tool_call_response',
|
|
830
|
+
id=self.tool_call_id,
|
|
831
|
+
name=self.tool_name,
|
|
832
|
+
)
|
|
833
|
+
|
|
834
|
+
if settings.include_content:
|
|
835
|
+
part['result'] = self.model_response()
|
|
836
|
+
|
|
837
|
+
return [part]
|
|
831
838
|
|
|
832
839
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
833
840
|
|
|
@@ -1131,8 +1138,10 @@ class ModelResponse:
|
|
|
1131
1138
|
**({'content': part.content} if settings.include_content else {}),
|
|
1132
1139
|
)
|
|
1133
1140
|
)
|
|
1134
|
-
elif isinstance(part,
|
|
1141
|
+
elif isinstance(part, BaseToolCallPart):
|
|
1135
1142
|
call_part = _otel_messages.ToolCallPart(type='tool_call', id=part.tool_call_id, name=part.tool_name)
|
|
1143
|
+
if isinstance(part, BuiltinToolCallPart):
|
|
1144
|
+
call_part['builtin'] = True
|
|
1136
1145
|
if settings.include_content and part.args is not None:
|
|
1137
1146
|
from .models.instrumented import InstrumentedModel
|
|
1138
1147
|
|
|
@@ -1142,6 +1151,23 @@ class ModelResponse:
|
|
|
1142
1151
|
call_part['arguments'] = {k: InstrumentedModel.serialize_any(v) for k, v in part.args.items()}
|
|
1143
1152
|
|
|
1144
1153
|
parts.append(call_part)
|
|
1154
|
+
elif isinstance(part, BuiltinToolReturnPart):
|
|
1155
|
+
return_part = _otel_messages.ToolCallResponsePart(
|
|
1156
|
+
type='tool_call_response',
|
|
1157
|
+
id=part.tool_call_id,
|
|
1158
|
+
name=part.tool_name,
|
|
1159
|
+
builtin=True,
|
|
1160
|
+
)
|
|
1161
|
+
if settings.include_content and part.content is not None: # pragma: no branch
|
|
1162
|
+
from .models.instrumented import InstrumentedModel
|
|
1163
|
+
|
|
1164
|
+
return_part['result'] = (
|
|
1165
|
+
part.content
|
|
1166
|
+
if isinstance(part.content, str)
|
|
1167
|
+
else {k: InstrumentedModel.serialize_any(v) for k, v in part.content.items()}
|
|
1168
|
+
)
|
|
1169
|
+
|
|
1170
|
+
parts.append(return_part)
|
|
1145
1171
|
return parts
|
|
1146
1172
|
|
|
1147
1173
|
@property
|
|
@@ -1299,35 +1325,39 @@ class ToolCallPartDelta:
|
|
|
1299
1325
|
return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
|
|
1300
1326
|
|
|
1301
1327
|
@overload
|
|
1302
|
-
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
1328
|
+
def apply(self, part: ModelResponsePart) -> ToolCallPart | BuiltinToolCallPart: ...
|
|
1303
1329
|
|
|
1304
1330
|
@overload
|
|
1305
|
-
def apply(
|
|
1331
|
+
def apply(
|
|
1332
|
+
self, part: ModelResponsePart | ToolCallPartDelta
|
|
1333
|
+
) -> ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta: ...
|
|
1306
1334
|
|
|
1307
|
-
def apply(
|
|
1335
|
+
def apply(
|
|
1336
|
+
self, part: ModelResponsePart | ToolCallPartDelta
|
|
1337
|
+
) -> ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta:
|
|
1308
1338
|
"""Apply this delta to a part or delta, returning a new part or delta with the changes applied.
|
|
1309
1339
|
|
|
1310
1340
|
Args:
|
|
1311
1341
|
part: The existing model response part or delta to update.
|
|
1312
1342
|
|
|
1313
1343
|
Returns:
|
|
1314
|
-
Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
|
|
1344
|
+
Either a new `ToolCallPart` or `BuiltinToolCallPart`, or an updated `ToolCallPartDelta`.
|
|
1315
1345
|
|
|
1316
1346
|
Raises:
|
|
1317
|
-
ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
|
|
1347
|
+
ValueError: If `part` is neither a `ToolCallPart`, `BuiltinToolCallPart`, nor a `ToolCallPartDelta`.
|
|
1318
1348
|
UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
|
|
1319
1349
|
"""
|
|
1320
|
-
if isinstance(part, ToolCallPart):
|
|
1350
|
+
if isinstance(part, ToolCallPart | BuiltinToolCallPart):
|
|
1321
1351
|
return self._apply_to_part(part)
|
|
1322
1352
|
|
|
1323
1353
|
if isinstance(part, ToolCallPartDelta):
|
|
1324
1354
|
return self._apply_to_delta(part)
|
|
1325
1355
|
|
|
1326
1356
|
raise ValueError( # pragma: no cover
|
|
1327
|
-
f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}'
|
|
1357
|
+
f'Can only apply ToolCallPartDeltas to ToolCallParts, BuiltinToolCallParts, or ToolCallPartDeltas, not {part}'
|
|
1328
1358
|
)
|
|
1329
1359
|
|
|
1330
|
-
def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
1360
|
+
def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | BuiltinToolCallPart | ToolCallPartDelta:
|
|
1331
1361
|
"""Internal helper to apply this delta to another delta."""
|
|
1332
1362
|
if self.tool_name_delta:
|
|
1333
1363
|
# Append incremental text to the existing tool_name_delta
|
|
@@ -1358,8 +1388,8 @@ class ToolCallPartDelta:
|
|
|
1358
1388
|
|
|
1359
1389
|
return delta
|
|
1360
1390
|
|
|
1361
|
-
def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
|
|
1362
|
-
"""Internal helper to apply this delta directly to a `ToolCallPart`."""
|
|
1391
|
+
def _apply_to_part(self, part: ToolCallPart | BuiltinToolCallPart) -> ToolCallPart | BuiltinToolCallPart:
|
|
1392
|
+
"""Internal helper to apply this delta directly to a `ToolCallPart` or `BuiltinToolCallPart`."""
|
|
1363
1393
|
if self.tool_name_delta:
|
|
1364
1394
|
# Append incremental text to the existing tool_name
|
|
1365
1395
|
tool_name = part.tool_name + self.tool_name_delta
|
|
@@ -1491,6 +1521,9 @@ class FunctionToolResultEvent:
|
|
|
1491
1521
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
1492
1522
|
|
|
1493
1523
|
|
|
1524
|
+
@deprecated(
|
|
1525
|
+
'`BuiltinToolCallEvent` is deprecated, look for `PartStartEvent` and `PartDeltaEvent` with `BuiltinToolCallPart` instead.'
|
|
1526
|
+
)
|
|
1494
1527
|
@dataclass(repr=False)
|
|
1495
1528
|
class BuiltinToolCallEvent:
|
|
1496
1529
|
"""An event indicating the start to a call to a built-in tool."""
|
|
@@ -1504,6 +1537,9 @@ class BuiltinToolCallEvent:
|
|
|
1504
1537
|
"""Event type identifier, used as a discriminator."""
|
|
1505
1538
|
|
|
1506
1539
|
|
|
1540
|
+
@deprecated(
|
|
1541
|
+
'`BuiltinToolResultEvent` is deprecated, look for `PartStartEvent` and `PartDeltaEvent` with `BuiltinToolReturnPart` instead.'
|
|
1542
|
+
)
|
|
1507
1543
|
@dataclass(repr=False)
|
|
1508
1544
|
class BuiltinToolResultEvent:
|
|
1509
1545
|
"""An event indicating the result of a built-in tool call."""
|
|
@@ -1518,7 +1554,10 @@ class BuiltinToolResultEvent:
|
|
|
1518
1554
|
|
|
1519
1555
|
|
|
1520
1556
|
HandleResponseEvent = Annotated[
|
|
1521
|
-
FunctionToolCallEvent
|
|
1557
|
+
FunctionToolCallEvent
|
|
1558
|
+
| FunctionToolResultEvent
|
|
1559
|
+
| BuiltinToolCallEvent # pyright: ignore[reportDeprecated]
|
|
1560
|
+
| BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
|
|
1522
1561
|
pydantic.Discriminator('event_kind'),
|
|
1523
1562
|
]
|
|
1524
1563
|
"""An event yielded when handling a model response, indicating tool calls and results."""
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -783,6 +783,8 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
|
|
|
783
783
|
The client is cached based on the provider parameter. If provider is None, it's used for non-provider specific
|
|
784
784
|
requests (like downloading images). Multiple agents and calls can share the same client when they use the same provider.
|
|
785
785
|
|
|
786
|
+
Each client will get its own transport with its own connection pool. The default pool size is defined by `httpx.DEFAULT_LIMITS`.
|
|
787
|
+
|
|
786
788
|
There are good reasons why in production you should use a `httpx.AsyncClient` as an async context manager as
|
|
787
789
|
described in [encode/httpx#2026](https://github.com/encode/httpx/pull/2026), but when experimenting or showing
|
|
788
790
|
examples, it's very useful not to.
|
|
@@ -793,6 +795,8 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
|
|
|
793
795
|
client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
|
|
794
796
|
if client.is_closed:
|
|
795
797
|
# This happens if the context manager is used, so we need to create a new client.
|
|
798
|
+
# Since there is no API from `functools.cache` to clear the cache for a specific
|
|
799
|
+
# key, clear the entire cache here as a workaround.
|
|
796
800
|
_cached_async_http_client.cache_clear()
|
|
797
801
|
client = _cached_async_http_client(provider=provider, timeout=timeout, connect=connect)
|
|
798
802
|
return client
|
|
@@ -801,17 +805,11 @@ def cached_async_http_client(*, provider: str | None = None, timeout: int = 600,
|
|
|
801
805
|
@cache
|
|
802
806
|
def _cached_async_http_client(provider: str | None, timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
|
|
803
807
|
return httpx.AsyncClient(
|
|
804
|
-
transport=_cached_async_http_transport(),
|
|
805
808
|
timeout=httpx.Timeout(timeout=timeout, connect=connect),
|
|
806
809
|
headers={'User-Agent': get_user_agent()},
|
|
807
810
|
)
|
|
808
811
|
|
|
809
812
|
|
|
810
|
-
@cache
|
|
811
|
-
def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
|
|
812
|
-
return httpx.AsyncHTTPTransport()
|
|
813
|
-
|
|
814
|
-
|
|
815
813
|
DataT = TypeVar('DataT', str, bytes)
|
|
816
814
|
|
|
817
815
|
|
pydantic_ai/models/anthropic.py
CHANGED
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from typing import Any, Literal, cast, overload
|
|
9
9
|
|
|
10
|
+
from pydantic import TypeAdapter
|
|
10
11
|
from typing_extensions import assert_never
|
|
11
12
|
|
|
12
13
|
from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
|
|
@@ -60,7 +61,9 @@ try:
|
|
|
60
61
|
BetaCitationsDelta,
|
|
61
62
|
BetaCodeExecutionTool20250522Param,
|
|
62
63
|
BetaCodeExecutionToolResultBlock,
|
|
64
|
+
BetaCodeExecutionToolResultBlockContent,
|
|
63
65
|
BetaCodeExecutionToolResultBlockParam,
|
|
66
|
+
BetaCodeExecutionToolResultBlockParamContentParam,
|
|
64
67
|
BetaContentBlock,
|
|
65
68
|
BetaContentBlockParam,
|
|
66
69
|
BetaImageBlockParam,
|
|
@@ -97,7 +100,9 @@ try:
|
|
|
97
100
|
BetaToolUseBlockParam,
|
|
98
101
|
BetaWebSearchTool20250305Param,
|
|
99
102
|
BetaWebSearchToolResultBlock,
|
|
103
|
+
BetaWebSearchToolResultBlockContent,
|
|
100
104
|
BetaWebSearchToolResultBlockParam,
|
|
105
|
+
BetaWebSearchToolResultBlockParamContentParam,
|
|
101
106
|
)
|
|
102
107
|
from anthropic.types.beta.beta_web_search_tool_20250305_param import UserLocation
|
|
103
108
|
from anthropic.types.model_param import ModelParam
|
|
@@ -302,24 +307,12 @@ class AnthropicModel(Model):
|
|
|
302
307
|
for item in response.content:
|
|
303
308
|
if isinstance(item, BetaTextBlock):
|
|
304
309
|
items.append(TextPart(content=item.text))
|
|
305
|
-
elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
|
|
306
|
-
items.append(
|
|
307
|
-
BuiltinToolReturnPart(
|
|
308
|
-
provider_name=self.system,
|
|
309
|
-
tool_name=item.type,
|
|
310
|
-
content=item.content,
|
|
311
|
-
tool_call_id=item.tool_use_id,
|
|
312
|
-
)
|
|
313
|
-
)
|
|
314
310
|
elif isinstance(item, BetaServerToolUseBlock):
|
|
315
|
-
items.append(
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
tool_call_id=item.id,
|
|
321
|
-
)
|
|
322
|
-
)
|
|
311
|
+
items.append(_map_server_tool_use_block(item, self.system))
|
|
312
|
+
elif isinstance(item, BetaWebSearchToolResultBlock):
|
|
313
|
+
items.append(_map_web_search_tool_result_block(item, self.system))
|
|
314
|
+
elif isinstance(item, BetaCodeExecutionToolResultBlock):
|
|
315
|
+
items.append(_map_code_execution_tool_result_block(item, self.system))
|
|
323
316
|
elif isinstance(item, BetaRedactedThinkingBlock):
|
|
324
317
|
items.append(
|
|
325
318
|
ThinkingPart(id='redacted_thinking', content='', signature=item.data, provider_name=self.system)
|
|
@@ -485,27 +478,54 @@ class AnthropicModel(Model):
|
|
|
485
478
|
)
|
|
486
479
|
elif isinstance(response_part, BuiltinToolCallPart):
|
|
487
480
|
if response_part.provider_name == self.system:
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
481
|
+
tool_use_id = _guard_tool_call_id(t=response_part)
|
|
482
|
+
if response_part.tool_name == WebSearchTool.kind:
|
|
483
|
+
server_tool_use_block_param = BetaServerToolUseBlockParam(
|
|
484
|
+
id=tool_use_id,
|
|
485
|
+
type='server_tool_use',
|
|
486
|
+
name='web_search',
|
|
487
|
+
input=response_part.args_as_dict(),
|
|
488
|
+
)
|
|
489
|
+
assistant_content_params.append(server_tool_use_block_param)
|
|
490
|
+
elif response_part.tool_name == CodeExecutionTool.kind: # pragma: no branch
|
|
491
|
+
server_tool_use_block_param = BetaServerToolUseBlockParam(
|
|
492
|
+
id=tool_use_id,
|
|
493
|
+
type='server_tool_use',
|
|
494
|
+
name='code_execution',
|
|
495
|
+
input=response_part.args_as_dict(),
|
|
496
|
+
)
|
|
497
|
+
assistant_content_params.append(server_tool_use_block_param)
|
|
495
498
|
elif isinstance(response_part, BuiltinToolReturnPart):
|
|
496
499
|
if response_part.provider_name == self.system:
|
|
497
500
|
tool_use_id = _guard_tool_call_id(t=response_part)
|
|
498
|
-
if response_part.tool_name
|
|
499
|
-
|
|
500
|
-
|
|
501
|
+
if response_part.tool_name in (
|
|
502
|
+
WebSearchTool.kind,
|
|
503
|
+
'web_search_tool_result', # Backward compatibility
|
|
504
|
+
) and isinstance(response_part.content, dict | list):
|
|
505
|
+
assistant_content_params.append(
|
|
506
|
+
BetaWebSearchToolResultBlockParam(
|
|
507
|
+
tool_use_id=tool_use_id,
|
|
508
|
+
type='web_search_tool_result',
|
|
509
|
+
content=cast(
|
|
510
|
+
BetaWebSearchToolResultBlockParamContentParam,
|
|
511
|
+
response_part.content, # pyright: ignore[reportUnknownMemberType]
|
|
512
|
+
),
|
|
513
|
+
)
|
|
501
514
|
)
|
|
502
|
-
elif response_part.tool_name
|
|
503
|
-
|
|
504
|
-
|
|
515
|
+
elif response_part.tool_name in ( # pragma: no branch
|
|
516
|
+
CodeExecutionTool.kind,
|
|
517
|
+
'code_execution_tool_result', # Backward compatibility
|
|
518
|
+
) and isinstance(response_part.content, dict):
|
|
519
|
+
assistant_content_params.append(
|
|
520
|
+
BetaCodeExecutionToolResultBlockParam(
|
|
521
|
+
tool_use_id=tool_use_id,
|
|
522
|
+
type='code_execution_tool_result',
|
|
523
|
+
content=cast(
|
|
524
|
+
BetaCodeExecutionToolResultBlockParamContentParam,
|
|
525
|
+
response_part.content, # pyright: ignore[reportUnknownMemberType]
|
|
526
|
+
),
|
|
527
|
+
)
|
|
505
528
|
)
|
|
506
|
-
else:
|
|
507
|
-
raise ValueError(f'Unsupported tool name: {response_part.tool_name}')
|
|
508
|
-
assistant_content_params.append(server_tool_result_block_param)
|
|
509
529
|
else:
|
|
510
530
|
assert_never(response_part)
|
|
511
531
|
if len(assistant_content_params) > 0:
|
|
@@ -646,7 +666,7 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
646
666
|
)
|
|
647
667
|
elif isinstance(current_block, BetaToolUseBlock):
|
|
648
668
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
649
|
-
vendor_part_id=
|
|
669
|
+
vendor_part_id=event.index,
|
|
650
670
|
tool_name=current_block.name,
|
|
651
671
|
args=cast(dict[str, Any], current_block.input) or None,
|
|
652
672
|
tool_call_id=current_block.id,
|
|
@@ -654,7 +674,20 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
654
674
|
if maybe_event is not None: # pragma: no branch
|
|
655
675
|
yield maybe_event
|
|
656
676
|
elif isinstance(current_block, BetaServerToolUseBlock):
|
|
657
|
-
|
|
677
|
+
yield self._parts_manager.handle_builtin_tool_call_part(
|
|
678
|
+
vendor_part_id=event.index,
|
|
679
|
+
part=_map_server_tool_use_block(current_block, self.provider_name),
|
|
680
|
+
)
|
|
681
|
+
elif isinstance(current_block, BetaWebSearchToolResultBlock):
|
|
682
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
683
|
+
vendor_part_id=event.index,
|
|
684
|
+
part=_map_web_search_tool_result_block(current_block, self.provider_name),
|
|
685
|
+
)
|
|
686
|
+
elif isinstance(current_block, BetaCodeExecutionToolResultBlock):
|
|
687
|
+
yield self._parts_manager.handle_builtin_tool_return_part(
|
|
688
|
+
vendor_part_id=event.index,
|
|
689
|
+
part=_map_code_execution_tool_result_block(current_block, self.provider_name),
|
|
690
|
+
)
|
|
658
691
|
|
|
659
692
|
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
|
660
693
|
if isinstance(event.delta, BetaTextDelta):
|
|
@@ -675,21 +708,13 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
675
708
|
signature=event.delta.signature,
|
|
676
709
|
provider_name=self.provider_name,
|
|
677
710
|
)
|
|
678
|
-
elif (
|
|
679
|
-
current_block
|
|
680
|
-
and event.delta.type == 'input_json_delta'
|
|
681
|
-
and isinstance(current_block, BetaToolUseBlock)
|
|
682
|
-
): # pragma: no branch
|
|
711
|
+
elif isinstance(event.delta, BetaInputJSONDelta):
|
|
683
712
|
maybe_event = self._parts_manager.handle_tool_call_delta(
|
|
684
|
-
vendor_part_id=
|
|
685
|
-
tool_name='',
|
|
713
|
+
vendor_part_id=event.index,
|
|
686
714
|
args=event.delta.partial_json,
|
|
687
|
-
tool_call_id=current_block.id,
|
|
688
715
|
)
|
|
689
716
|
if maybe_event is not None: # pragma: no branch
|
|
690
717
|
yield maybe_event
|
|
691
|
-
elif isinstance(event.delta, BetaInputJSONDelta):
|
|
692
|
-
pass
|
|
693
718
|
# TODO(Marcelo): We need to handle citations.
|
|
694
719
|
elif isinstance(event.delta, BetaCitationsDelta):
|
|
695
720
|
pass
|
|
@@ -717,3 +742,52 @@ class AnthropicStreamedResponse(StreamedResponse):
|
|
|
717
742
|
def timestamp(self) -> datetime:
|
|
718
743
|
"""Get the timestamp of the response."""
|
|
719
744
|
return self._timestamp
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _map_server_tool_use_block(item: BetaServerToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
|
|
748
|
+
if item.name == 'web_search':
|
|
749
|
+
return BuiltinToolCallPart(
|
|
750
|
+
provider_name=provider_name,
|
|
751
|
+
tool_name=WebSearchTool.kind,
|
|
752
|
+
args=cast(dict[str, Any], item.input) or None,
|
|
753
|
+
tool_call_id=item.id,
|
|
754
|
+
)
|
|
755
|
+
elif item.name == 'code_execution':
|
|
756
|
+
return BuiltinToolCallPart(
|
|
757
|
+
provider_name=provider_name,
|
|
758
|
+
tool_name=CodeExecutionTool.kind,
|
|
759
|
+
args=cast(dict[str, Any], item.input) or None,
|
|
760
|
+
tool_call_id=item.id,
|
|
761
|
+
)
|
|
762
|
+
else:
|
|
763
|
+
assert_never(item.name)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
web_search_tool_result_content_ta: TypeAdapter[BetaWebSearchToolResultBlockContent] = TypeAdapter(
|
|
767
|
+
BetaWebSearchToolResultBlockContent
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def _map_web_search_tool_result_block(item: BetaWebSearchToolResultBlock, provider_name: str) -> BuiltinToolReturnPart:
|
|
772
|
+
return BuiltinToolReturnPart(
|
|
773
|
+
provider_name=provider_name,
|
|
774
|
+
tool_name=WebSearchTool.kind,
|
|
775
|
+
content=web_search_tool_result_content_ta.dump_python(item.content, mode='json'),
|
|
776
|
+
tool_call_id=item.tool_use_id,
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
code_execution_tool_result_content_ta: TypeAdapter[BetaCodeExecutionToolResultBlockContent] = TypeAdapter(
|
|
781
|
+
BetaCodeExecutionToolResultBlockContent
|
|
782
|
+
)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def _map_code_execution_tool_result_block(
|
|
786
|
+
item: BetaCodeExecutionToolResultBlock, provider_name: str
|
|
787
|
+
) -> BuiltinToolReturnPart:
|
|
788
|
+
return BuiltinToolReturnPart(
|
|
789
|
+
provider_name=provider_name,
|
|
790
|
+
tool_name=CodeExecutionTool.kind,
|
|
791
|
+
content=code_execution_tool_result_content_ta.dump_python(item.content, mode='json'),
|
|
792
|
+
tool_call_id=item.tool_use_id,
|
|
793
|
+
)
|
pydantic_ai/models/function.py
CHANGED
|
@@ -247,18 +247,20 @@ DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
|
247
247
|
DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart]
|
|
248
248
|
"""A mapping of thinking call IDs to incremental changes."""
|
|
249
249
|
|
|
250
|
+
BuiltinToolCallsReturns: TypeAlias = dict[int, BuiltinToolCallPart | BuiltinToolReturnPart]
|
|
251
|
+
|
|
250
252
|
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], ModelResponse | Awaitable[ModelResponse]]
|
|
251
253
|
"""A function used to generate a non-streamed response."""
|
|
252
254
|
|
|
253
255
|
StreamFunctionDef: TypeAlias = Callable[
|
|
254
|
-
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
256
|
+
[list[ModelMessage], AgentInfo], AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns]
|
|
255
257
|
]
|
|
256
258
|
"""A function used to generate a streamed response.
|
|
257
259
|
|
|
258
|
-
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]`, it should
|
|
260
|
+
While this is defined as having return type of `AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinTools]`, it should
|
|
259
261
|
really be considered as `AsyncIterator[str] | AsyncIterator[DeltaToolCalls] | AsyncIterator[DeltaThinkingCalls]`,
|
|
260
262
|
|
|
261
|
-
E.g. you need to yield all text, all `DeltaToolCalls`, or all `
|
|
263
|
+
E.g. you need to yield all text, all `DeltaToolCalls`, all `DeltaThinkingCalls`, or all `BuiltinToolCallsReturns`, not mix them.
|
|
262
264
|
"""
|
|
263
265
|
|
|
264
266
|
|
|
@@ -267,7 +269,7 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
267
269
|
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
268
270
|
|
|
269
271
|
_model_name: str
|
|
270
|
-
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
|
|
272
|
+
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns]
|
|
271
273
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
272
274
|
|
|
273
275
|
def __post_init__(self):
|
|
@@ -305,6 +307,16 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
305
307
|
)
|
|
306
308
|
if maybe_event is not None: # pragma: no branch
|
|
307
309
|
yield maybe_event
|
|
310
|
+
elif isinstance(delta, BuiltinToolCallPart):
|
|
311
|
+
if content := delta.args_as_json_str(): # pragma: no branch
|
|
312
|
+
response_tokens = _estimate_string_tokens(content)
|
|
313
|
+
self._usage += usage.RequestUsage(output_tokens=response_tokens)
|
|
314
|
+
yield self._parts_manager.handle_builtin_tool_call_part(vendor_part_id=dtc_index, part=delta)
|
|
315
|
+
elif isinstance(delta, BuiltinToolReturnPart):
|
|
316
|
+
if content := delta.model_response_str(): # pragma: no branch
|
|
317
|
+
response_tokens = _estimate_string_tokens(content)
|
|
318
|
+
self._usage += usage.RequestUsage(output_tokens=response_tokens)
|
|
319
|
+
yield self._parts_manager.handle_builtin_tool_return_part(vendor_part_id=dtc_index, part=delta)
|
|
308
320
|
else:
|
|
309
321
|
assert_never(delta)
|
|
310
322
|
|
|
@@ -351,11 +363,8 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.RequestUsage:
|
|
|
351
363
|
response_tokens += _estimate_string_tokens(part.content)
|
|
352
364
|
elif isinstance(part, ToolCallPart):
|
|
353
365
|
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
|
|
354
|
-
# TODO(Marcelo): We need to add coverage here.
|
|
355
366
|
elif isinstance(part, BuiltinToolCallPart): # pragma: no cover
|
|
356
|
-
|
|
357
|
-
response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
|
|
358
|
-
# TODO(Marcelo): We need to add coverage here.
|
|
367
|
+
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
|
|
359
368
|
elif isinstance(part, BuiltinToolReturnPart): # pragma: no cover
|
|
360
369
|
response_tokens += _estimate_string_tokens(part.model_response_str())
|
|
361
370
|
else:
|