pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic; consult the registry's advisory page for this release for more details.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_result.py +3 -7
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +34 -30
- pydantic_ai/messages.py +21 -46
- pydantic_ai/models/__init__.py +100 -57
- pydantic_ai/models/anthropic.py +17 -10
- pydantic_ai/models/cohere.py +37 -25
- pydantic_ai/models/gemini.py +20 -6
- pydantic_ai/models/groq.py +19 -17
- pydantic_ai/models/mistral.py +22 -23
- pydantic_ai/models/openai.py +19 -11
- pydantic_ai/models/test.py +37 -22
- pydantic_ai/result.py +1 -1
- pydantic_ai/settings.py +41 -1
- pydantic_ai/tools.py +11 -8
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -123
- pydantic_ai_slim-0.0.20.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/_parts_manager.py
CHANGED
|
@@ -221,7 +221,7 @@ class ModelResponsePartsManager:
|
|
|
221
221
|
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
|
|
222
222
|
has been added to the manager, or replaced an existing part.
|
|
223
223
|
"""
|
|
224
|
-
new_part = ToolCallPart
|
|
224
|
+
new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
|
|
225
225
|
if vendor_part_id is None:
|
|
226
226
|
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
|
|
227
227
|
new_part_index = len(self._parts)
|
pydantic_ai/_result.py
CHANGED
|
@@ -201,14 +201,10 @@ class ResultTool(Generic[ResultDataT]):
|
|
|
201
201
|
"""
|
|
202
202
|
try:
|
|
203
203
|
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
|
|
204
|
-
if isinstance(tool_call.args,
|
|
205
|
-
result = self.type_adapter.validate_json(
|
|
206
|
-
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
|
|
207
|
-
)
|
|
204
|
+
if isinstance(tool_call.args, str):
|
|
205
|
+
result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
|
|
208
206
|
else:
|
|
209
|
-
result = self.type_adapter.validate_python(
|
|
210
|
-
tool_call.args.args_dict, experimental_allow_partial=pyd_allow_partial
|
|
211
|
-
)
|
|
207
|
+
result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
|
|
212
208
|
except ValidationError as e:
|
|
213
209
|
if wrap_validation_errors:
|
|
214
210
|
m = _messages.RetryPromptPart(
|
pydantic_ai/_utils.py
CHANGED
|
@@ -8,7 +8,7 @@ from dataclasses import dataclass, is_dataclass
|
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
9
|
from functools import partial
|
|
10
10
|
from types import GenericAlias
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel
|
|
14
14
|
from pydantic.json_schema import JsonSchemaValue
|
|
@@ -79,61 +79,6 @@ def is_set(t_or_unset: T | Unset) -> TypeGuard[T]:
|
|
|
79
79
|
return t_or_unset is not UNSET
|
|
80
80
|
|
|
81
81
|
|
|
82
|
-
Left = TypeVar('Left')
|
|
83
|
-
Right = TypeVar('Right')
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
class Either(Generic[Left, Right]):
|
|
87
|
-
"""Two member Union that records which member was set, this is analogous to Rust enums with two variants.
|
|
88
|
-
|
|
89
|
-
Usage:
|
|
90
|
-
|
|
91
|
-
```python
|
|
92
|
-
if left_thing := either.left:
|
|
93
|
-
use_left(left_thing.value)
|
|
94
|
-
else:
|
|
95
|
-
use_right(either.right)
|
|
96
|
-
```
|
|
97
|
-
"""
|
|
98
|
-
|
|
99
|
-
__slots__ = '_left', '_right'
|
|
100
|
-
|
|
101
|
-
@overload
|
|
102
|
-
def __init__(self, *, left: Left) -> None: ...
|
|
103
|
-
|
|
104
|
-
@overload
|
|
105
|
-
def __init__(self, *, right: Right) -> None: ...
|
|
106
|
-
|
|
107
|
-
def __init__(self, left: Left | Unset = UNSET, right: Right | Unset = UNSET) -> None:
|
|
108
|
-
if left is not UNSET:
|
|
109
|
-
assert right is UNSET, '`Either` must receive exactly one argument - `left` or `right`'
|
|
110
|
-
self._left: Option[Left] = Some(cast(Left, left))
|
|
111
|
-
else:
|
|
112
|
-
assert right is not UNSET, '`Either` must receive exactly one argument - `left` or `right`'
|
|
113
|
-
self._left = None
|
|
114
|
-
self._right = cast(Right, right)
|
|
115
|
-
|
|
116
|
-
@property
|
|
117
|
-
def left(self) -> Option[Left]:
|
|
118
|
-
return self._left
|
|
119
|
-
|
|
120
|
-
@property
|
|
121
|
-
def right(self) -> Right:
|
|
122
|
-
return self._right
|
|
123
|
-
|
|
124
|
-
def is_left(self) -> bool:
|
|
125
|
-
return self._left is not None
|
|
126
|
-
|
|
127
|
-
def whichever(self) -> Left | Right:
|
|
128
|
-
return self._left.value if self._left is not None else self.right
|
|
129
|
-
|
|
130
|
-
def __repr__(self):
|
|
131
|
-
if left := self._left:
|
|
132
|
-
return f'Either(left={left.value!r})'
|
|
133
|
-
else:
|
|
134
|
-
return f'Either(right={self.right!r})'
|
|
135
|
-
|
|
136
|
-
|
|
137
82
|
@asynccontextmanager
|
|
138
83
|
async def group_by_temporal(
|
|
139
84
|
aiterable: AsyncIterable[T], soft_max_interval: float | None
|
pydantic_ai/agent.py
CHANGED
|
@@ -60,7 +60,7 @@ EndStrategy = Literal['early', 'exhaustive']
|
|
|
60
60
|
- `'early'`: Stop processing other tool calls once a final result is found
|
|
61
61
|
- `'exhaustive'`: Process all tool calls even after finding a final result
|
|
62
62
|
"""
|
|
63
|
-
|
|
63
|
+
RunResultDataT = TypeVar('RunResultDataT')
|
|
64
64
|
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
|
|
65
65
|
|
|
66
66
|
|
|
@@ -214,7 +214,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
214
214
|
self,
|
|
215
215
|
user_prompt: str,
|
|
216
216
|
*,
|
|
217
|
-
result_type: type[
|
|
217
|
+
result_type: type[RunResultDataT],
|
|
218
218
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
219
219
|
model: models.Model | models.KnownModelName | None = None,
|
|
220
220
|
deps: AgentDepsT = None,
|
|
@@ -222,7 +222,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
222
222
|
usage_limits: _usage.UsageLimits | None = None,
|
|
223
223
|
usage: _usage.Usage | None = None,
|
|
224
224
|
infer_name: bool = True,
|
|
225
|
-
) -> result.RunResult[
|
|
225
|
+
) -> result.RunResult[RunResultDataT]: ...
|
|
226
226
|
|
|
227
227
|
async def run(
|
|
228
228
|
self,
|
|
@@ -234,7 +234,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
234
234
|
model_settings: ModelSettings | None = None,
|
|
235
235
|
usage_limits: _usage.UsageLimits | None = None,
|
|
236
236
|
usage: _usage.Usage | None = None,
|
|
237
|
-
result_type: type[
|
|
237
|
+
result_type: type[RunResultDataT] | None = None,
|
|
238
238
|
infer_name: bool = True,
|
|
239
239
|
) -> result.RunResult[Any]:
|
|
240
240
|
"""Run the agent with a user prompt in async mode.
|
|
@@ -352,7 +352,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
352
352
|
self,
|
|
353
353
|
user_prompt: str,
|
|
354
354
|
*,
|
|
355
|
-
result_type: type[
|
|
355
|
+
result_type: type[RunResultDataT] | None,
|
|
356
356
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
357
357
|
model: models.Model | models.KnownModelName | None = None,
|
|
358
358
|
deps: AgentDepsT = None,
|
|
@@ -360,13 +360,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
360
360
|
usage_limits: _usage.UsageLimits | None = None,
|
|
361
361
|
usage: _usage.Usage | None = None,
|
|
362
362
|
infer_name: bool = True,
|
|
363
|
-
) -> result.RunResult[
|
|
363
|
+
) -> result.RunResult[RunResultDataT]: ...
|
|
364
364
|
|
|
365
365
|
def run_sync(
|
|
366
366
|
self,
|
|
367
367
|
user_prompt: str,
|
|
368
368
|
*,
|
|
369
|
-
result_type: type[
|
|
369
|
+
result_type: type[RunResultDataT] | None = None,
|
|
370
370
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
371
371
|
model: models.Model | models.KnownModelName | None = None,
|
|
372
372
|
deps: AgentDepsT = None,
|
|
@@ -442,7 +442,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
442
442
|
self,
|
|
443
443
|
user_prompt: str,
|
|
444
444
|
*,
|
|
445
|
-
result_type: type[
|
|
445
|
+
result_type: type[RunResultDataT],
|
|
446
446
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
447
447
|
model: models.Model | models.KnownModelName | None = None,
|
|
448
448
|
deps: AgentDepsT = None,
|
|
@@ -450,14 +450,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
450
450
|
usage_limits: _usage.UsageLimits | None = None,
|
|
451
451
|
usage: _usage.Usage | None = None,
|
|
452
452
|
infer_name: bool = True,
|
|
453
|
-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT,
|
|
453
|
+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
|
|
454
454
|
|
|
455
455
|
@asynccontextmanager
|
|
456
456
|
async def run_stream(
|
|
457
457
|
self,
|
|
458
458
|
user_prompt: str,
|
|
459
459
|
*,
|
|
460
|
-
result_type: type[
|
|
460
|
+
result_type: type[RunResultDataT] | None = None,
|
|
461
461
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
462
462
|
model: models.Model | models.KnownModelName | None = None,
|
|
463
463
|
deps: AgentDepsT = None,
|
|
@@ -572,7 +572,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
572
572
|
# there are result validators that might convert the result data from an overridden
|
|
573
573
|
# `result_type` to a type that is not valid as such.
|
|
574
574
|
result_validators = cast(
|
|
575
|
-
list[_result.ResultValidator[AgentDepsT,
|
|
575
|
+
list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
|
|
576
576
|
)
|
|
577
577
|
|
|
578
578
|
yield result.StreamedRunResult(
|
|
@@ -999,7 +999,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
999
999
|
return model_
|
|
1000
1000
|
|
|
1001
1001
|
async def _prepare_model(
|
|
1002
|
-
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[
|
|
1002
|
+
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
|
|
1003
1003
|
) -> models.AgentModel:
|
|
1004
1004
|
"""Build tools and create an agent model."""
|
|
1005
1005
|
function_tools: list[ToolDefinition] = []
|
|
@@ -1035,8 +1035,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1035
1035
|
)
|
|
1036
1036
|
|
|
1037
1037
|
def _prepare_result_schema(
|
|
1038
|
-
self, result_type: type[
|
|
1039
|
-
) -> _result.ResultSchema[
|
|
1038
|
+
self, result_type: type[RunResultDataT] | None
|
|
1039
|
+
) -> _result.ResultSchema[RunResultDataT] | None:
|
|
1040
1040
|
if result_type is not None:
|
|
1041
1041
|
if self._result_validators:
|
|
1042
1042
|
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
|
|
@@ -1053,7 +1053,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1053
1053
|
run_context: RunContext[AgentDepsT],
|
|
1054
1054
|
) -> list[_messages.ModelMessage]:
|
|
1055
1055
|
try:
|
|
1056
|
-
ctx_messages =
|
|
1056
|
+
ctx_messages = get_captured_run_messages()
|
|
1057
1057
|
except LookupError:
|
|
1058
1058
|
messages: list[_messages.ModelMessage] = []
|
|
1059
1059
|
else:
|
|
@@ -1080,8 +1080,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1080
1080
|
self,
|
|
1081
1081
|
model_response: _messages.ModelResponse,
|
|
1082
1082
|
run_context: RunContext[AgentDepsT],
|
|
1083
|
-
result_schema: _result.ResultSchema[
|
|
1084
|
-
) -> tuple[_MarkFinalResult[
|
|
1083
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1084
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1085
1085
|
"""Process a non-streamed response from the model.
|
|
1086
1086
|
|
|
1087
1087
|
Returns:
|
|
@@ -1110,11 +1110,11 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1110
1110
|
raise exceptions.UnexpectedModelBehavior('Received empty model response')
|
|
1111
1111
|
|
|
1112
1112
|
async def _handle_text_response(
|
|
1113
|
-
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[
|
|
1114
|
-
) -> tuple[_MarkFinalResult[
|
|
1113
|
+
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
|
|
1114
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1115
1115
|
"""Handle a plain text response from the model for non-streaming responses."""
|
|
1116
1116
|
if self._allow_text_result(result_schema):
|
|
1117
|
-
result_data_input = cast(
|
|
1117
|
+
result_data_input = cast(RunResultDataT, text)
|
|
1118
1118
|
try:
|
|
1119
1119
|
result_data = await self._validate_result(result_data_input, run_context, None)
|
|
1120
1120
|
except _result.ToolRetryError as e:
|
|
@@ -1133,13 +1133,13 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1133
1133
|
self,
|
|
1134
1134
|
tool_calls: list[_messages.ToolCallPart],
|
|
1135
1135
|
run_context: RunContext[AgentDepsT],
|
|
1136
|
-
result_schema: _result.ResultSchema[
|
|
1137
|
-
) -> tuple[_MarkFinalResult[
|
|
1136
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1137
|
+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
|
|
1138
1138
|
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
|
|
1139
1139
|
assert tool_calls, 'Expected at least one tool call'
|
|
1140
1140
|
|
|
1141
1141
|
# first look for the result tool call
|
|
1142
|
-
final_result: _MarkFinalResult[
|
|
1142
|
+
final_result: _MarkFinalResult[RunResultDataT] | None = None
|
|
1143
1143
|
|
|
1144
1144
|
parts: list[_messages.ModelRequestPart] = []
|
|
1145
1145
|
if result_schema is not None:
|
|
@@ -1168,7 +1168,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1168
1168
|
tool_calls: list[_messages.ToolCallPart],
|
|
1169
1169
|
result_tool_name: str | None,
|
|
1170
1170
|
run_context: RunContext[AgentDepsT],
|
|
1171
|
-
result_schema: _result.ResultSchema[
|
|
1171
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1172
1172
|
) -> list[_messages.ModelRequestPart]:
|
|
1173
1173
|
"""Process function (non-result) tool calls in parallel.
|
|
1174
1174
|
|
|
@@ -1227,7 +1227,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1227
1227
|
self,
|
|
1228
1228
|
streamed_response: models.StreamedResponse,
|
|
1229
1229
|
run_context: RunContext[AgentDepsT],
|
|
1230
|
-
result_schema: _result.ResultSchema[
|
|
1230
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1231
1231
|
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
|
|
1232
1232
|
"""Process a streamed response from the model.
|
|
1233
1233
|
|
|
@@ -1282,15 +1282,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1282
1282
|
|
|
1283
1283
|
async def _validate_result(
|
|
1284
1284
|
self,
|
|
1285
|
-
result_data:
|
|
1285
|
+
result_data: RunResultDataT,
|
|
1286
1286
|
run_context: RunContext[AgentDepsT],
|
|
1287
1287
|
tool_call: _messages.ToolCallPart | None,
|
|
1288
|
-
) ->
|
|
1288
|
+
) -> RunResultDataT:
|
|
1289
1289
|
if self._result_validators:
|
|
1290
1290
|
agent_result_data = cast(ResultDataT, result_data)
|
|
1291
1291
|
for validator in self._result_validators:
|
|
1292
1292
|
agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
|
|
1293
|
-
return cast(
|
|
1293
|
+
return cast(RunResultDataT, agent_result_data)
|
|
1294
1294
|
else:
|
|
1295
1295
|
return result_data
|
|
1296
1296
|
|
|
@@ -1315,7 +1315,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1315
1315
|
def _unknown_tool(
|
|
1316
1316
|
self,
|
|
1317
1317
|
tool_name: str,
|
|
1318
|
-
result_schema: _result.ResultSchema[
|
|
1318
|
+
result_schema: _result.ResultSchema[RunResultDataT] | None,
|
|
1319
1319
|
) -> _messages.RetryPromptPart:
|
|
1320
1320
|
names = list(self._function_tools.keys())
|
|
1321
1321
|
if result_schema:
|
|
@@ -1358,7 +1358,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1358
1358
|
return
|
|
1359
1359
|
|
|
1360
1360
|
@staticmethod
|
|
1361
|
-
def _allow_text_result(result_schema: _result.ResultSchema[
|
|
1361
|
+
def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
|
|
1362
1362
|
return result_schema is None or result_schema.allow_text_result
|
|
1363
1363
|
|
|
1364
1364
|
@property
|
|
@@ -1413,6 +1413,10 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
|
|
|
1413
1413
|
_messages_ctx_var.reset(token)
|
|
1414
1414
|
|
|
1415
1415
|
|
|
1416
|
+
def get_captured_run_messages() -> _RunMessages:
|
|
1417
|
+
return _messages_ctx_var.get()
|
|
1418
|
+
|
|
1419
|
+
|
|
1416
1420
|
@dataclasses.dataclass
|
|
1417
1421
|
class _MarkFinalResult(Generic[ResultDataT]):
|
|
1418
1422
|
"""Marker class to indicate that the result is the final result.
|
pydantic_ai/messages.py
CHANGED
|
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload
|
|
|
6
6
|
|
|
7
7
|
import pydantic
|
|
8
8
|
import pydantic_core
|
|
9
|
-
from typing_extensions import Self, assert_never
|
|
10
9
|
|
|
11
10
|
from ._utils import now_utc as _now_utc
|
|
12
11
|
from .exceptions import UnexpectedModelBehavior
|
|
@@ -168,22 +167,6 @@ class TextPart:
|
|
|
168
167
|
return bool(self.content)
|
|
169
168
|
|
|
170
169
|
|
|
171
|
-
@dataclass
|
|
172
|
-
class ArgsJson:
|
|
173
|
-
"""Tool arguments as a JSON string."""
|
|
174
|
-
|
|
175
|
-
args_json: str
|
|
176
|
-
"""A JSON string of arguments."""
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@dataclass
|
|
180
|
-
class ArgsDict:
|
|
181
|
-
"""Tool arguments as a Python dictionary."""
|
|
182
|
-
|
|
183
|
-
args_dict: dict[str, Any]
|
|
184
|
-
"""A python dictionary of arguments."""
|
|
185
|
-
|
|
186
|
-
|
|
187
170
|
@dataclass
|
|
188
171
|
class ToolCallPart:
|
|
189
172
|
"""A tool call from a model."""
|
|
@@ -191,10 +174,10 @@ class ToolCallPart:
|
|
|
191
174
|
tool_name: str
|
|
192
175
|
"""The name of the tool to call."""
|
|
193
176
|
|
|
194
|
-
args:
|
|
177
|
+
args: str | dict[str, Any]
|
|
195
178
|
"""The arguments to pass to the tool.
|
|
196
179
|
|
|
197
|
-
|
|
180
|
+
This is stored either as a JSON string or a Python dictionary depending on how data was received.
|
|
198
181
|
"""
|
|
199
182
|
|
|
200
183
|
tool_call_id: str | None = None
|
|
@@ -203,24 +186,14 @@ class ToolCallPart:
|
|
|
203
186
|
part_kind: Literal['tool-call'] = 'tool-call'
|
|
204
187
|
"""Part type identifier, this is available on all parts as a discriminator."""
|
|
205
188
|
|
|
206
|
-
@classmethod
|
|
207
|
-
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
|
|
208
|
-
"""Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
|
|
209
|
-
if isinstance(args, str):
|
|
210
|
-
return cls(tool_name, ArgsJson(args), tool_call_id)
|
|
211
|
-
elif isinstance(args, dict):
|
|
212
|
-
return cls(tool_name, ArgsDict(args), tool_call_id)
|
|
213
|
-
else:
|
|
214
|
-
assert_never(args)
|
|
215
|
-
|
|
216
189
|
def args_as_dict(self) -> dict[str, Any]:
|
|
217
190
|
"""Return the arguments as a Python dictionary.
|
|
218
191
|
|
|
219
192
|
This is just for convenience with models that require dicts as input.
|
|
220
193
|
"""
|
|
221
|
-
if isinstance(self.args,
|
|
222
|
-
return self.args
|
|
223
|
-
args = pydantic_core.from_json(self.args
|
|
194
|
+
if isinstance(self.args, dict):
|
|
195
|
+
return self.args
|
|
196
|
+
args = pydantic_core.from_json(self.args)
|
|
224
197
|
assert isinstance(args, dict), 'args should be a dict'
|
|
225
198
|
return cast(dict[str, Any], args)
|
|
226
199
|
|
|
@@ -229,16 +202,18 @@ class ToolCallPart:
|
|
|
229
202
|
|
|
230
203
|
This is just for convenience with models that require JSON strings as input.
|
|
231
204
|
"""
|
|
232
|
-
if isinstance(self.args,
|
|
233
|
-
return self.args
|
|
234
|
-
return pydantic_core.to_json(self.args
|
|
205
|
+
if isinstance(self.args, str):
|
|
206
|
+
return self.args
|
|
207
|
+
return pydantic_core.to_json(self.args).decode()
|
|
235
208
|
|
|
236
209
|
def has_content(self) -> bool:
|
|
237
210
|
"""Return `True` if the arguments contain any data."""
|
|
238
|
-
if isinstance(self.args,
|
|
239
|
-
return
|
|
211
|
+
if isinstance(self.args, dict):
|
|
212
|
+
# TODO: This should probably return True if you have the value False, or 0, etc.
|
|
213
|
+
# It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
|
|
214
|
+
return any(self.args.values())
|
|
240
215
|
else:
|
|
241
|
-
return bool(self.args
|
|
216
|
+
return bool(self.args)
|
|
242
217
|
|
|
243
218
|
|
|
244
219
|
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
|
|
@@ -331,7 +306,7 @@ class ToolCallPartDelta:
|
|
|
331
306
|
if self.tool_name_delta is None or self.args_delta is None:
|
|
332
307
|
return None
|
|
333
308
|
|
|
334
|
-
return ToolCallPart
|
|
309
|
+
return ToolCallPart(
|
|
335
310
|
self.tool_name_delta,
|
|
336
311
|
self.args_delta,
|
|
337
312
|
self.tool_call_id,
|
|
@@ -396,7 +371,7 @@ class ToolCallPartDelta:
|
|
|
396
371
|
|
|
397
372
|
# If we now have enough data to create a full ToolCallPart, do so
|
|
398
373
|
if delta.tool_name_delta is not None and delta.args_delta is not None:
|
|
399
|
-
return ToolCallPart
|
|
374
|
+
return ToolCallPart(
|
|
400
375
|
delta.tool_name_delta,
|
|
401
376
|
delta.args_delta,
|
|
402
377
|
delta.tool_call_id,
|
|
@@ -412,15 +387,15 @@ class ToolCallPartDelta:
|
|
|
412
387
|
part = replace(part, tool_name=tool_name)
|
|
413
388
|
|
|
414
389
|
if isinstance(self.args_delta, str):
|
|
415
|
-
if not isinstance(part.args,
|
|
390
|
+
if not isinstance(part.args, str):
|
|
416
391
|
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
|
|
417
|
-
updated_json = part.args
|
|
418
|
-
part = replace(part, args=
|
|
392
|
+
updated_json = part.args + self.args_delta
|
|
393
|
+
part = replace(part, args=updated_json)
|
|
419
394
|
elif isinstance(self.args_delta, dict):
|
|
420
|
-
if not isinstance(part.args,
|
|
395
|
+
if not isinstance(part.args, dict):
|
|
421
396
|
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
|
|
422
|
-
updated_dict = {**(part.args
|
|
423
|
-
part = replace(part, args=
|
|
397
|
+
updated_dict = {**(part.args or {}), **self.args_delta}
|
|
398
|
+
part = replace(part, args=updated_dict)
|
|
424
399
|
|
|
425
400
|
if self.tool_call_id:
|
|
426
401
|
# Replace the tool_call_id entirely if given
|
pydantic_ai/models/__init__.py
CHANGED
|
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
|
|
|
12
12
|
from dataclasses import dataclass, field
|
|
13
13
|
from datetime import datetime
|
|
14
14
|
from functools import cache
|
|
15
|
-
from typing import TYPE_CHECKING
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
16
|
|
|
17
17
|
import httpx
|
|
18
|
+
from typing_extensions import Literal
|
|
18
19
|
|
|
19
20
|
from .._parts_manager import ModelResponsePartsManager
|
|
20
21
|
from ..exceptions import UserError
|
|
@@ -27,58 +28,6 @@ if TYPE_CHECKING:
|
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
KnownModelName = Literal[
|
|
30
|
-
'openai:gpt-4o',
|
|
31
|
-
'openai:gpt-4o-mini',
|
|
32
|
-
'openai:gpt-4-turbo',
|
|
33
|
-
'openai:gpt-4',
|
|
34
|
-
'openai:o1-preview',
|
|
35
|
-
'openai:o1-mini',
|
|
36
|
-
'openai:o1',
|
|
37
|
-
'openai:gpt-3.5-turbo',
|
|
38
|
-
'groq:llama-3.3-70b-versatile',
|
|
39
|
-
'groq:llama-3.1-70b-versatile',
|
|
40
|
-
'groq:llama3-groq-70b-8192-tool-use-preview',
|
|
41
|
-
'groq:llama3-groq-8b-8192-tool-use-preview',
|
|
42
|
-
'groq:llama-3.1-70b-specdec',
|
|
43
|
-
'groq:llama-3.1-8b-instant',
|
|
44
|
-
'groq:llama-3.2-1b-preview',
|
|
45
|
-
'groq:llama-3.2-3b-preview',
|
|
46
|
-
'groq:llama-3.2-11b-vision-preview',
|
|
47
|
-
'groq:llama-3.2-90b-vision-preview',
|
|
48
|
-
'groq:llama3-70b-8192',
|
|
49
|
-
'groq:llama3-8b-8192',
|
|
50
|
-
'groq:mixtral-8x7b-32768',
|
|
51
|
-
'groq:gemma2-9b-it',
|
|
52
|
-
'groq:gemma-7b-it',
|
|
53
|
-
'google-gla:gemini-1.5-flash',
|
|
54
|
-
'google-gla:gemini-1.5-pro',
|
|
55
|
-
'google-gla:gemini-2.0-flash-exp',
|
|
56
|
-
'google-vertex:gemini-1.5-flash',
|
|
57
|
-
'google-vertex:gemini-1.5-pro',
|
|
58
|
-
'google-vertex:gemini-2.0-flash-exp',
|
|
59
|
-
'mistral:mistral-small-latest',
|
|
60
|
-
'mistral:mistral-large-latest',
|
|
61
|
-
'mistral:codestral-latest',
|
|
62
|
-
'mistral:mistral-moderation-latest',
|
|
63
|
-
'ollama:codellama',
|
|
64
|
-
'ollama:deepseek-r1',
|
|
65
|
-
'ollama:gemma',
|
|
66
|
-
'ollama:gemma2',
|
|
67
|
-
'ollama:llama3',
|
|
68
|
-
'ollama:llama3.1',
|
|
69
|
-
'ollama:llama3.2',
|
|
70
|
-
'ollama:llama3.2-vision',
|
|
71
|
-
'ollama:llama3.3',
|
|
72
|
-
'ollama:mistral',
|
|
73
|
-
'ollama:mistral-nemo',
|
|
74
|
-
'ollama:mixtral',
|
|
75
|
-
'ollama:phi3',
|
|
76
|
-
'ollama:phi4',
|
|
77
|
-
'ollama:qwq',
|
|
78
|
-
'ollama:qwen',
|
|
79
|
-
'ollama:qwen2',
|
|
80
|
-
'ollama:qwen2.5',
|
|
81
|
-
'ollama:starcoder2',
|
|
82
31
|
'anthropic:claude-3-5-haiku-latest',
|
|
83
32
|
'anthropic:claude-3-5-sonnet-latest',
|
|
84
33
|
'anthropic:claude-3-opus-latest',
|
|
@@ -98,6 +47,104 @@ KnownModelName = Literal[
|
|
|
98
47
|
'cohere:command-r-plus-04-2024',
|
|
99
48
|
'cohere:command-r-plus-08-2024',
|
|
100
49
|
'cohere:command-r7b-12-2024',
|
|
50
|
+
'google-gla:gemini-1.0-pro',
|
|
51
|
+
'google-gla:gemini-1.5-flash',
|
|
52
|
+
'google-gla:gemini-1.5-flash-8b',
|
|
53
|
+
'google-gla:gemini-1.5-pro',
|
|
54
|
+
'google-gla:gemini-2.0-flash-exp',
|
|
55
|
+
'google-vertex:gemini-1.0-pro',
|
|
56
|
+
'google-vertex:gemini-1.5-flash',
|
|
57
|
+
'google-vertex:gemini-1.5-flash-8b',
|
|
58
|
+
'google-vertex:gemini-1.5-pro',
|
|
59
|
+
'google-vertex:gemini-2.0-flash-exp',
|
|
60
|
+
'gpt-3.5-turbo',
|
|
61
|
+
'gpt-3.5-turbo-0125',
|
|
62
|
+
'gpt-3.5-turbo-0301',
|
|
63
|
+
'gpt-3.5-turbo-0613',
|
|
64
|
+
'gpt-3.5-turbo-1106',
|
|
65
|
+
'gpt-3.5-turbo-16k',
|
|
66
|
+
'gpt-3.5-turbo-16k-0613',
|
|
67
|
+
'gpt-4',
|
|
68
|
+
'gpt-4-0125-preview',
|
|
69
|
+
'gpt-4-0314',
|
|
70
|
+
'gpt-4-0613',
|
|
71
|
+
'gpt-4-1106-preview',
|
|
72
|
+
'gpt-4-32k',
|
|
73
|
+
'gpt-4-32k-0314',
|
|
74
|
+
'gpt-4-32k-0613',
|
|
75
|
+
'gpt-4-turbo',
|
|
76
|
+
'gpt-4-turbo-2024-04-09',
|
|
77
|
+
'gpt-4-turbo-preview',
|
|
78
|
+
'gpt-4-vision-preview',
|
|
79
|
+
'gpt-4o',
|
|
80
|
+
'gpt-4o-2024-05-13',
|
|
81
|
+
'gpt-4o-2024-08-06',
|
|
82
|
+
'gpt-4o-2024-11-20',
|
|
83
|
+
'gpt-4o-audio-preview',
|
|
84
|
+
'gpt-4o-audio-preview-2024-10-01',
|
|
85
|
+
'gpt-4o-audio-preview-2024-12-17',
|
|
86
|
+
'gpt-4o-mini',
|
|
87
|
+
'gpt-4o-mini-2024-07-18',
|
|
88
|
+
'gpt-4o-mini-audio-preview',
|
|
89
|
+
'gpt-4o-mini-audio-preview-2024-12-17',
|
|
90
|
+
'groq:gemma2-9b-it',
|
|
91
|
+
'groq:llama-3.1-8b-instant',
|
|
92
|
+
'groq:llama-3.2-11b-vision-preview',
|
|
93
|
+
'groq:llama-3.2-1b-preview',
|
|
94
|
+
'groq:llama-3.2-3b-preview',
|
|
95
|
+
'groq:llama-3.2-90b-vision-preview',
|
|
96
|
+
'groq:llama-3.3-70b-specdec',
|
|
97
|
+
'groq:llama-3.3-70b-versatile',
|
|
98
|
+
'groq:llama3-70b-8192',
|
|
99
|
+
'groq:llama3-8b-8192',
|
|
100
|
+
'groq:mixtral-8x7b-32768',
|
|
101
|
+
'mistral:codestral-latest',
|
|
102
|
+
'mistral:mistral-large-latest',
|
|
103
|
+
'mistral:mistral-moderation-latest',
|
|
104
|
+
'mistral:mistral-small-latest',
|
|
105
|
+
'o1',
|
|
106
|
+
'o1-2024-12-17',
|
|
107
|
+
'o1-mini',
|
|
108
|
+
'o1-mini-2024-09-12',
|
|
109
|
+
'o1-preview',
|
|
110
|
+
'o1-preview-2024-09-12',
|
|
111
|
+
'openai:chatgpt-4o-latest',
|
|
112
|
+
'openai:gpt-3.5-turbo',
|
|
113
|
+
'openai:gpt-3.5-turbo-0125',
|
|
114
|
+
'openai:gpt-3.5-turbo-0301',
|
|
115
|
+
'openai:gpt-3.5-turbo-0613',
|
|
116
|
+
'openai:gpt-3.5-turbo-1106',
|
|
117
|
+
'openai:gpt-3.5-turbo-16k',
|
|
118
|
+
'openai:gpt-3.5-turbo-16k-0613',
|
|
119
|
+
'openai:gpt-4',
|
|
120
|
+
'openai:gpt-4-0125-preview',
|
|
121
|
+
'openai:gpt-4-0314',
|
|
122
|
+
'openai:gpt-4-0613',
|
|
123
|
+
'openai:gpt-4-1106-preview',
|
|
124
|
+
'openai:gpt-4-32k',
|
|
125
|
+
'openai:gpt-4-32k-0314',
|
|
126
|
+
'openai:gpt-4-32k-0613',
|
|
127
|
+
'openai:gpt-4-turbo',
|
|
128
|
+
'openai:gpt-4-turbo-2024-04-09',
|
|
129
|
+
'openai:gpt-4-turbo-preview',
|
|
130
|
+
'openai:gpt-4-vision-preview',
|
|
131
|
+
'openai:gpt-4o',
|
|
132
|
+
'openai:gpt-4o-2024-05-13',
|
|
133
|
+
'openai:gpt-4o-2024-08-06',
|
|
134
|
+
'openai:gpt-4o-2024-11-20',
|
|
135
|
+
'openai:gpt-4o-audio-preview',
|
|
136
|
+
'openai:gpt-4o-audio-preview-2024-10-01',
|
|
137
|
+
'openai:gpt-4o-audio-preview-2024-12-17',
|
|
138
|
+
'openai:gpt-4o-mini',
|
|
139
|
+
'openai:gpt-4o-mini-2024-07-18',
|
|
140
|
+
'openai:gpt-4o-mini-audio-preview',
|
|
141
|
+
'openai:gpt-4o-mini-audio-preview-2024-12-17',
|
|
142
|
+
'openai:o1',
|
|
143
|
+
'openai:o1-2024-12-17',
|
|
144
|
+
'openai:o1-mini',
|
|
145
|
+
'openai:o1-mini-2024-09-12',
|
|
146
|
+
'openai:o1-preview',
|
|
147
|
+
'openai:o1-preview-2024-09-12',
|
|
101
148
|
'test',
|
|
102
149
|
]
|
|
103
150
|
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
|
|
@@ -291,10 +338,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
291
338
|
from .mistral import MistralModel
|
|
292
339
|
|
|
293
340
|
return MistralModel(model[8:])
|
|
294
|
-
elif model.startswith('ollama:'):
|
|
295
|
-
from .ollama import OllamaModel
|
|
296
|
-
|
|
297
|
-
return OllamaModel(model[7:])
|
|
298
341
|
elif model.startswith('anthropic'):
|
|
299
342
|
from .anthropic import AnthropicModel
|
|
300
343
|
|