pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_griffe.py +10 -3
- pydantic_ai/_parts_manager.py +239 -0
- pydantic_ai/_pydantic.py +16 -3
- pydantic_ai/_utils.py +80 -17
- pydantic_ai/agent.py +82 -74
- pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai/messages.py +218 -9
- pydantic_ai/models/__init__.py +31 -72
- pydantic_ai/models/anthropic.py +21 -21
- pydantic_ai/models/function.py +47 -79
- pydantic_ai/models/gemini.py +76 -122
- pydantic_ai/models/groq.py +53 -125
- pydantic_ai/models/mistral.py +75 -137
- pydantic_ai/models/ollama.py +1 -0
- pydantic_ai/models/openai.py +50 -125
- pydantic_ai/models/test.py +40 -73
- pydantic_ai/result.py +91 -92
- pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/METADATA +3 -1
- pydantic_ai_slim-0.0.19.dist-info/RECORD +29 -0
- pydantic_ai_slim-0.0.18.dist-info/RECORD +0 -28
- {pydantic_ai_slim-0.0.18.dist-info → pydantic_ai_slim-0.0.19.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -26,6 +26,7 @@ from .result import ResultData
|
|
|
26
26
|
from .settings import ModelSettings, merge_model_settings
|
|
27
27
|
from .tools import (
|
|
28
28
|
AgentDeps,
|
|
29
|
+
DocstringFormat,
|
|
29
30
|
RunContext,
|
|
30
31
|
Tool,
|
|
31
32
|
ToolDefinition,
|
|
@@ -242,9 +243,10 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
242
243
|
|
|
243
244
|
agent = Agent('openai:gpt-4o')
|
|
244
245
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
246
|
+
async def main():
|
|
247
|
+
result = await agent.run('What is the capital of France?')
|
|
248
|
+
print(result.data)
|
|
249
|
+
#> Paris
|
|
248
250
|
```
|
|
249
251
|
|
|
250
252
|
Args:
|
|
@@ -382,10 +384,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
382
384
|
|
|
383
385
|
agent = Agent('openai:gpt-4o')
|
|
384
386
|
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
#> Paris
|
|
387
|
+
result_sync = agent.run_sync('What is the capital of Italy?')
|
|
388
|
+
print(result_sync.data)
|
|
389
|
+
#> Rome
|
|
389
390
|
```
|
|
390
391
|
|
|
391
392
|
Args:
|
|
@@ -535,7 +536,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
535
536
|
model_req_span.__exit__(None, None, None)
|
|
536
537
|
|
|
537
538
|
with _logfire.span('handle model response') as handle_span:
|
|
538
|
-
maybe_final_result = await self.
|
|
539
|
+
maybe_final_result = await self._handle_streamed_response(
|
|
539
540
|
model_response, run_context, result_schema
|
|
540
541
|
)
|
|
541
542
|
|
|
@@ -774,6 +775,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
774
775
|
*,
|
|
775
776
|
retries: int | None = None,
|
|
776
777
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
778
|
+
docstring_format: DocstringFormat = 'auto',
|
|
779
|
+
require_parameter_descriptions: bool = False,
|
|
777
780
|
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
|
|
778
781
|
|
|
779
782
|
def tool(
|
|
@@ -783,6 +786,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
783
786
|
*,
|
|
784
787
|
retries: int | None = None,
|
|
785
788
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
789
|
+
docstring_format: DocstringFormat = 'auto',
|
|
790
|
+
require_parameter_descriptions: bool = False,
|
|
786
791
|
) -> Any:
|
|
787
792
|
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
788
793
|
|
|
@@ -820,6 +825,9 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
820
825
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
821
826
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
822
827
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
828
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
829
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
830
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
823
831
|
"""
|
|
824
832
|
if func is None:
|
|
825
833
|
|
|
@@ -827,13 +835,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
827
835
|
func_: ToolFuncContext[AgentDeps, ToolParams],
|
|
828
836
|
) -> ToolFuncContext[AgentDeps, ToolParams]:
|
|
829
837
|
# noinspection PyTypeChecker
|
|
830
|
-
self._register_function(func_, True, retries, prepare)
|
|
838
|
+
self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
831
839
|
return func_
|
|
832
840
|
|
|
833
841
|
return tool_decorator
|
|
834
842
|
else:
|
|
835
843
|
# noinspection PyTypeChecker
|
|
836
|
-
self._register_function(func, True, retries, prepare)
|
|
844
|
+
self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
837
845
|
return func
|
|
838
846
|
|
|
839
847
|
@overload
|
|
@@ -846,6 +854,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
846
854
|
*,
|
|
847
855
|
retries: int | None = None,
|
|
848
856
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
857
|
+
docstring_format: DocstringFormat = 'auto',
|
|
858
|
+
require_parameter_descriptions: bool = False,
|
|
849
859
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
850
860
|
|
|
851
861
|
def tool_plain(
|
|
@@ -855,6 +865,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
855
865
|
*,
|
|
856
866
|
retries: int | None = None,
|
|
857
867
|
prepare: ToolPrepareFunc[AgentDeps] | None = None,
|
|
868
|
+
docstring_format: DocstringFormat = 'auto',
|
|
869
|
+
require_parameter_descriptions: bool = False,
|
|
858
870
|
) -> Any:
|
|
859
871
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
860
872
|
|
|
@@ -892,17 +904,22 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
892
904
|
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
|
|
893
905
|
tool from a given step. This is useful if you want to customise a tool at call time,
|
|
894
906
|
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
|
|
907
|
+
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
|
|
908
|
+
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
909
|
+
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
895
910
|
"""
|
|
896
911
|
if func is None:
|
|
897
912
|
|
|
898
913
|
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
|
|
899
914
|
# noinspection PyTypeChecker
|
|
900
|
-
self._register_function(
|
|
915
|
+
self._register_function(
|
|
916
|
+
func_, False, retries, prepare, docstring_format, require_parameter_descriptions
|
|
917
|
+
)
|
|
901
918
|
return func_
|
|
902
919
|
|
|
903
920
|
return tool_decorator
|
|
904
921
|
else:
|
|
905
|
-
self._register_function(func, False, retries, prepare)
|
|
922
|
+
self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
|
|
906
923
|
return func
|
|
907
924
|
|
|
908
925
|
def _register_function(
|
|
@@ -911,10 +928,19 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
911
928
|
takes_ctx: bool,
|
|
912
929
|
retries: int | None,
|
|
913
930
|
prepare: ToolPrepareFunc[AgentDeps] | None,
|
|
931
|
+
docstring_format: DocstringFormat,
|
|
932
|
+
require_parameter_descriptions: bool,
|
|
914
933
|
) -> None:
|
|
915
934
|
"""Private utility to register a function as a tool."""
|
|
916
935
|
retries_ = retries if retries is not None else self._default_retries
|
|
917
|
-
tool = Tool(
|
|
936
|
+
tool = Tool(
|
|
937
|
+
func,
|
|
938
|
+
takes_ctx=takes_ctx,
|
|
939
|
+
max_retries=retries_,
|
|
940
|
+
prepare=prepare,
|
|
941
|
+
docstring_format=docstring_format,
|
|
942
|
+
require_parameter_descriptions=require_parameter_descriptions,
|
|
943
|
+
)
|
|
918
944
|
self._register_tool(tool)
|
|
919
945
|
|
|
920
946
|
def _register_tool(self, tool: Tool[AgentDeps]) -> None:
|
|
@@ -1100,7 +1126,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1100
1126
|
final_result: _MarkFinalResult[RunResultData] | None = None
|
|
1101
1127
|
|
|
1102
1128
|
parts: list[_messages.ModelRequestPart] = []
|
|
1103
|
-
if result_schema
|
|
1129
|
+
if result_schema is not None:
|
|
1104
1130
|
if match := result_schema.find_tool(tool_calls):
|
|
1105
1131
|
call, result_tool = match
|
|
1106
1132
|
try:
|
|
@@ -1179,76 +1205,58 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
1179
1205
|
parts.extend(task_results)
|
|
1180
1206
|
return parts
|
|
1181
1207
|
|
|
1182
|
-
async def
|
|
1208
|
+
async def _handle_streamed_response(
|
|
1183
1209
|
self,
|
|
1184
|
-
|
|
1210
|
+
streamed_response: models.StreamedResponse,
|
|
1185
1211
|
run_context: RunContext[AgentDeps],
|
|
1186
1212
|
result_schema: _result.ResultSchema[RunResultData] | None,
|
|
1187
|
-
) ->
|
|
1188
|
-
_MarkFinalResult[models.EitherStreamedResponse]
|
|
1189
|
-
| tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
|
|
1190
|
-
):
|
|
1213
|
+
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
|
|
1191
1214
|
"""Process a streamed response from the model.
|
|
1192
1215
|
|
|
1193
1216
|
Returns:
|
|
1194
1217
|
Either a final result or a tuple of the model response and the tool responses for the next request.
|
|
1195
1218
|
If a final result is returned, the conversation should end.
|
|
1196
1219
|
"""
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
elif isinstance(model_response, models.StreamStructuredResponse):
|
|
1213
|
-
if result_schema is not None:
|
|
1214
|
-
# if there's a result schema, iterate over the stream until we find at least one tool
|
|
1215
|
-
# NOTE: this means we ignore any other tools called here
|
|
1216
|
-
structured_msg = model_response.get()
|
|
1217
|
-
while not structured_msg.parts:
|
|
1218
|
-
try:
|
|
1219
|
-
await model_response.__anext__()
|
|
1220
|
-
except StopAsyncIteration:
|
|
1221
|
-
break
|
|
1222
|
-
structured_msg = model_response.get()
|
|
1223
|
-
|
|
1224
|
-
if match := result_schema.find_tool(structured_msg.parts):
|
|
1225
|
-
call, _ = match
|
|
1226
|
-
return _MarkFinalResult(model_response, call.tool_name)
|
|
1227
|
-
|
|
1228
|
-
# the model is calling a tool function, consume the response to get the next message
|
|
1229
|
-
async for _ in model_response:
|
|
1230
|
-
pass
|
|
1231
|
-
model_response_msg = model_response.get()
|
|
1232
|
-
if not model_response_msg.parts:
|
|
1233
|
-
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
1234
|
-
|
|
1235
|
-
# we now run all tool functions in parallel
|
|
1236
|
-
tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
|
|
1237
|
-
parts: list[_messages.ModelRequestPart] = []
|
|
1238
|
-
for item in model_response_msg.parts:
|
|
1239
|
-
if isinstance(item, _messages.ToolCallPart):
|
|
1240
|
-
call = item
|
|
1241
|
-
if tool := self._function_tools.get(call.tool_name):
|
|
1242
|
-
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
|
|
1243
|
-
else:
|
|
1244
|
-
parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
|
|
1220
|
+
received_text = False
|
|
1221
|
+
|
|
1222
|
+
async for maybe_part_event in streamed_response:
|
|
1223
|
+
if isinstance(maybe_part_event, _messages.PartStartEvent):
|
|
1224
|
+
new_part = maybe_part_event.part
|
|
1225
|
+
if isinstance(new_part, _messages.TextPart):
|
|
1226
|
+
received_text = True
|
|
1227
|
+
if self._allow_text_result(result_schema):
|
|
1228
|
+
return _MarkFinalResult(streamed_response, None)
|
|
1229
|
+
elif isinstance(new_part, _messages.ToolCallPart):
|
|
1230
|
+
if result_schema is not None and (match := result_schema.find_tool([new_part])):
|
|
1231
|
+
call, _ = match
|
|
1232
|
+
return _MarkFinalResult(streamed_response, call.tool_name)
|
|
1233
|
+
else:
|
|
1234
|
+
assert_never(new_part)
|
|
1245
1235
|
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1236
|
+
tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
|
|
1237
|
+
parts: list[_messages.ModelRequestPart] = []
|
|
1238
|
+
model_response = streamed_response.get()
|
|
1239
|
+
if not model_response.parts:
|
|
1240
|
+
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
1241
|
+
for p in model_response.parts:
|
|
1242
|
+
if isinstance(p, _messages.ToolCallPart):
|
|
1243
|
+
if tool := self._function_tools.get(p.tool_name):
|
|
1244
|
+
tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
|
|
1245
|
+
else:
|
|
1246
|
+
parts.append(self._unknown_tool(p.tool_name, run_context, result_schema))
|
|
1247
|
+
|
|
1248
|
+
if received_text and not tasks and not parts:
|
|
1249
|
+
# Can only get here if self._allow_text_result returns `False` for the provided result_schema
|
|
1250
|
+
self._incr_result_retry(run_context)
|
|
1251
|
+
model_response = _messages.RetryPromptPart(
|
|
1252
|
+
content='Plain text responses are not permitted, please call one of the functions instead.',
|
|
1253
|
+
)
|
|
1254
|
+
return streamed_response.get(), [model_response]
|
|
1255
|
+
|
|
1256
|
+
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
|
|
1257
|
+
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
|
|
1258
|
+
parts.extend(task_results)
|
|
1259
|
+
return model_response, parts
|
|
1252
1260
|
|
|
1253
1261
|
async def _validate_result(
|
|
1254
1262
|
self,
|
pydantic_ai/format_as_xml.py
CHANGED
|
@@ -37,7 +37,8 @@ def format_as_xml(
|
|
|
37
37
|
none_str: String to use for `None` values.
|
|
38
38
|
indent: Indentation string to use for pretty printing.
|
|
39
39
|
|
|
40
|
-
Returns:
|
|
40
|
+
Returns:
|
|
41
|
+
XML representation of the object.
|
|
41
42
|
|
|
42
43
|
Example:
|
|
43
44
|
```python {title="format_as_xml_example.py" lint="skip"}
|
pydantic_ai/messages.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass, field
|
|
3
|
+
from dataclasses import dataclass, field, replace
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import Annotated, Any, Literal, Union, cast
|
|
5
|
+
from typing import Annotated, Any, Literal, Union, cast, overload
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
import pydantic_core
|
|
9
9
|
from typing_extensions import Self, assert_never
|
|
10
10
|
|
|
11
11
|
from ._utils import now_utc as _now_utc
|
|
12
|
+
from .exceptions import UnexpectedModelBehavior
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
@dataclass
|
|
@@ -72,12 +73,14 @@ class ToolReturnPart:
|
|
|
72
73
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
73
74
|
|
|
74
75
|
def model_response_str(self) -> str:
|
|
76
|
+
"""Return a string representation of the content for the model."""
|
|
75
77
|
if isinstance(self.content, str):
|
|
76
78
|
return self.content
|
|
77
79
|
else:
|
|
78
80
|
return tool_return_ta.dump_json(self.content).decode()
|
|
79
81
|
|
|
80
82
|
def model_response_object(self) -> dict[str, Any]:
|
|
83
|
+
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
|
|
81
84
|
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
|
|
82
85
|
if isinstance(self.content, dict):
|
|
83
86
|
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
|
|
@@ -124,6 +127,7 @@ class RetryPromptPart:
|
|
|
124
127
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
125
128
|
|
|
126
129
|
def model_response(self) -> str:
|
|
130
|
+
"""Return a string message describing why the retry is requested."""
|
|
127
131
|
if isinstance(self.content, str):
|
|
128
132
|
description = self.content
|
|
129
133
|
else:
|
|
@@ -159,6 +163,10 @@ class TextPart:
|
|
|
159
163
|
part_kind: Literal['text'] = 'text'
|
|
160
164
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
161
165
|
|
|
166
|
+
def has_content(self) -> bool:
|
|
167
|
+
"""Return `True` if the text content is non-empty."""
|
|
168
|
+
return bool(self.content)
|
|
169
|
+
|
|
162
170
|
|
|
163
171
|
@dataclass
|
|
164
172
|
class ArgsJson:
|
|
@@ -197,7 +205,7 @@ class ToolCallPart:
|
|
|
197
205
|
|
|
198
206
|
@classmethod
|
|
199
207
|
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
|
|
200
|
-
"""Create a `ToolCallPart` from raw arguments
|
|
208
|
+
"""Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
|
|
201
209
|
if isinstance(args, str):
|
|
202
210
|
return cls(tool_name, ArgsJson(args), tool_call_id)
|
|
203
211
|
elif isinstance(args, dict):
|
|
@@ -226,6 +234,7 @@ class ToolCallPart:
|
|
|
226
234
|
return pydantic_core.to_json(self.args.args_dict).decode()
|
|
227
235
|
|
|
228
236
|
def has_content(self) -> bool:
|
|
237
|
+
"""Return `True` if the arguments contain any data."""
|
|
229
238
|
if isinstance(self.args, ArgsDict):
|
|
230
239
|
return any(self.args.args_dict.values())
|
|
231
240
|
else:
|
|
@@ -254,17 +263,217 @@ class ModelResponse:
|
|
|
254
263
|
|
|
255
264
|
@classmethod
|
|
256
265
|
def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
|
|
257
|
-
|
|
266
|
+
"""Create a `ModelResponse` containing a single `TextPart`."""
|
|
267
|
+
return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
|
|
258
268
|
|
|
259
269
|
@classmethod
|
|
260
270
|
def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
|
|
271
|
+
"""Create a `ModelResponse` containing a single `ToolCallPart`."""
|
|
261
272
|
return cls([tool_call])
|
|
262
273
|
|
|
263
274
|
|
|
264
|
-
ModelMessage = Union[ModelRequest, ModelResponse]
|
|
265
|
-
"""Any message
|
|
275
|
+
ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
|
|
276
|
+
"""Any message sent to or returned by a model."""
|
|
266
277
|
|
|
267
|
-
ModelMessagesTypeAdapter = pydantic.TypeAdapter(
|
|
268
|
-
list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
|
|
269
|
-
)
|
|
278
|
+
ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
|
|
270
279
|
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
|
|
283
|
+
class TextPartDelta:
|
|
284
|
+
"""A partial update (delta) for a `TextPart` to append new text content."""
|
|
285
|
+
|
|
286
|
+
content_delta: str
|
|
287
|
+
"""The incremental text content to add to the existing `TextPart` content."""
|
|
288
|
+
|
|
289
|
+
part_delta_kind: Literal['text'] = 'text'
|
|
290
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
291
|
+
|
|
292
|
+
def apply(self, part: ModelResponsePart) -> TextPart:
|
|
293
|
+
"""Apply this text delta to an existing `TextPart`.
|
|
294
|
+
|
|
295
|
+
Args:
|
|
296
|
+
part: The existing model response part, which must be a `TextPart`.
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
A new `TextPart` with updated text content.
|
|
300
|
+
|
|
301
|
+
Raises:
|
|
302
|
+
ValueError: If `part` is not a `TextPart`.
|
|
303
|
+
"""
|
|
304
|
+
if not isinstance(part, TextPart):
|
|
305
|
+
raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
|
|
306
|
+
return replace(part, content=part.content + self.content_delta)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@dataclass
|
|
310
|
+
class ToolCallPartDelta:
|
|
311
|
+
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
|
|
312
|
+
|
|
313
|
+
tool_name_delta: str | None = None
|
|
314
|
+
"""Incremental text to add to the existing tool name, if any."""
|
|
315
|
+
|
|
316
|
+
args_delta: str | dict[str, Any] | None = None
|
|
317
|
+
"""Incremental data to add to the tool arguments.
|
|
318
|
+
|
|
319
|
+
If this is a string, it will be appended to existing JSON arguments.
|
|
320
|
+
If this is a dict, it will be merged with existing dict arguments.
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
tool_call_id: str | None = None
|
|
324
|
+
"""Optional tool call identifier, this is used by some models including OpenAI.
|
|
325
|
+
|
|
326
|
+
Note this is never treated as a delta — it can replace None, but otherwise if a
|
|
327
|
+
non-matching value is provided an error will be raised."""
|
|
328
|
+
|
|
329
|
+
part_delta_kind: Literal['tool_call'] = 'tool_call'
|
|
330
|
+
"""Part delta type identifier, used as a discriminator."""
|
|
331
|
+
|
|
332
|
+
def as_part(self) -> ToolCallPart | None:
|
|
333
|
+
"""Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
|
|
334
|
+
|
|
335
|
+
Returns:
|
|
336
|
+
A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
|
|
337
|
+
"""
|
|
338
|
+
if self.tool_name_delta is None or self.args_delta is None:
|
|
339
|
+
return None
|
|
340
|
+
|
|
341
|
+
return ToolCallPart.from_raw_args(
|
|
342
|
+
self.tool_name_delta,
|
|
343
|
+
self.args_delta,
|
|
344
|
+
self.tool_call_id,
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
@overload
|
|
348
|
+
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
|
|
349
|
+
|
|
350
|
+
@overload
|
|
351
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
|
|
352
|
+
|
|
353
|
+
def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
354
|
+
"""Apply this delta to a part or delta, returning a new part or delta with the changes applied.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
part: The existing model response part or delta to update.
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
|
|
361
|
+
|
|
362
|
+
Raises:
|
|
363
|
+
ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
|
|
364
|
+
UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
|
|
365
|
+
"""
|
|
366
|
+
if isinstance(part, ToolCallPart):
|
|
367
|
+
return self._apply_to_part(part)
|
|
368
|
+
|
|
369
|
+
if isinstance(part, ToolCallPartDelta):
|
|
370
|
+
return self._apply_to_delta(part)
|
|
371
|
+
|
|
372
|
+
raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
|
|
373
|
+
|
|
374
|
+
def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
|
|
375
|
+
"""Internal helper to apply this delta to another delta."""
|
|
376
|
+
if self.tool_name_delta:
|
|
377
|
+
# Append incremental text to the existing tool_name_delta
|
|
378
|
+
updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
|
|
379
|
+
delta = replace(delta, tool_name_delta=updated_tool_name_delta)
|
|
380
|
+
|
|
381
|
+
if isinstance(self.args_delta, str):
|
|
382
|
+
if isinstance(delta.args_delta, dict):
|
|
383
|
+
raise UnexpectedModelBehavior(
|
|
384
|
+
f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
|
|
385
|
+
)
|
|
386
|
+
updated_args_delta = (delta.args_delta or '') + self.args_delta
|
|
387
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
388
|
+
elif isinstance(self.args_delta, dict):
|
|
389
|
+
if isinstance(delta.args_delta, str):
|
|
390
|
+
raise UnexpectedModelBehavior(
|
|
391
|
+
f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
|
|
392
|
+
)
|
|
393
|
+
updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
|
|
394
|
+
delta = replace(delta, args_delta=updated_args_delta)
|
|
395
|
+
|
|
396
|
+
if self.tool_call_id:
|
|
397
|
+
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
|
|
398
|
+
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
|
|
399
|
+
raise UnexpectedModelBehavior(
|
|
400
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
|
|
401
|
+
)
|
|
402
|
+
delta = replace(delta, tool_call_id=self.tool_call_id)
|
|
403
|
+
|
|
404
|
+
# If we now have enough data to create a full ToolCallPart, do so
|
|
405
|
+
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
406
|
+
return ToolCallPart.from_raw_args(
|
|
407
|
+
delta.tool_name_delta,
|
|
408
|
+
delta.args_delta,
|
|
409
|
+
delta.tool_call_id,
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
return delta
|
|
413
|
+
|
|
414
|
+
def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
|
|
415
|
+
"""Internal helper to apply this delta directly to a `ToolCallPart`."""
|
|
416
|
+
if self.tool_name_delta:
|
|
417
|
+
# Append incremental text to the existing tool_name
|
|
418
|
+
tool_name = part.tool_name + self.tool_name_delta
|
|
419
|
+
part = replace(part, tool_name=tool_name)
|
|
420
|
+
|
|
421
|
+
if isinstance(self.args_delta, str):
|
|
422
|
+
if not isinstance(part.args, ArgsJson):
|
|
423
|
+
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
|
|
424
|
+
updated_json = part.args.args_json + self.args_delta
|
|
425
|
+
part = replace(part, args=ArgsJson(updated_json))
|
|
426
|
+
elif isinstance(self.args_delta, dict):
|
|
427
|
+
if not isinstance(part.args, ArgsDict):
|
|
428
|
+
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
|
|
429
|
+
updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
|
|
430
|
+
part = replace(part, args=ArgsDict(updated_dict))
|
|
431
|
+
|
|
432
|
+
if self.tool_call_id:
|
|
433
|
+
# Replace the tool_call_id entirely if given
|
|
434
|
+
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
|
|
435
|
+
raise UnexpectedModelBehavior(
|
|
436
|
+
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
|
|
437
|
+
)
|
|
438
|
+
part = replace(part, tool_call_id=self.tool_call_id)
|
|
439
|
+
return part
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
|
|
443
|
+
"""A partial update (delta) for any model response part."""
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@dataclass
|
|
447
|
+
class PartStartEvent:
|
|
448
|
+
"""An event indicating that a new part has started.
|
|
449
|
+
|
|
450
|
+
If multiple `PartStartEvent`s are received with the same index,
|
|
451
|
+
the new one should fully replace the old one.
|
|
452
|
+
"""
|
|
453
|
+
|
|
454
|
+
index: int
|
|
455
|
+
"""The index of the part within the overall response parts list."""
|
|
456
|
+
|
|
457
|
+
part: ModelResponsePart
|
|
458
|
+
"""The newly started `ModelResponsePart`."""
|
|
459
|
+
|
|
460
|
+
event_kind: Literal['part_start'] = 'part_start'
|
|
461
|
+
"""Event type identifier, used as a discriminator."""
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass
|
|
465
|
+
class PartDeltaEvent:
|
|
466
|
+
"""An event indicating a delta update for an existing part."""
|
|
467
|
+
|
|
468
|
+
index: int
|
|
469
|
+
"""The index of the part within the overall response parts list."""
|
|
470
|
+
|
|
471
|
+
delta: ModelResponsePartDelta
|
|
472
|
+
"""The delta to apply to the specified part."""
|
|
473
|
+
|
|
474
|
+
event_kind: Literal['part_delta'] = 'part_delta'
|
|
475
|
+
"""Event type identifier, used as a discriminator."""
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
|
|
479
|
+
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""
|