pydantic-ai-slim 0.0.55__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +10 -3
- pydantic_ai/_agent_graph.py +67 -55
- pydantic_ai/_cli.py +1 -2
- pydantic_ai/{_result.py → _output.py} +69 -47
- pydantic_ai/_utils.py +20 -0
- pydantic_ai/agent.py +501 -161
- pydantic_ai/format_as_xml.py +6 -113
- pydantic_ai/format_prompt.py +116 -0
- pydantic_ai/messages.py +104 -21
- pydantic_ai/models/__init__.py +24 -4
- pydantic_ai/models/_json_schema.py +156 -0
- pydantic_ai/models/anthropic.py +5 -3
- pydantic_ai/models/bedrock.py +100 -22
- pydantic_ai/models/cohere.py +48 -44
- pydantic_ai/models/fallback.py +2 -1
- pydantic_ai/models/function.py +8 -8
- pydantic_ai/models/gemini.py +65 -75
- pydantic_ai/models/groq.py +32 -28
- pydantic_ai/models/instrumented.py +4 -4
- pydantic_ai/models/mistral.py +62 -58
- pydantic_ai/models/openai.py +110 -158
- pydantic_ai/models/test.py +45 -46
- pydantic_ai/result.py +203 -90
- pydantic_ai/tools.py +3 -3
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.0.dist-info}/METADATA +5 -5
- pydantic_ai_slim-0.1.0.dist-info/RECORD +53 -0
- pydantic_ai_slim-0.0.55.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.55.dist-info → pydantic_ai_slim-0.1.0.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/test.py
CHANGED
|
@@ -22,9 +22,9 @@ from ..messages import (
|
|
|
22
22
|
ToolCallPart,
|
|
23
23
|
ToolReturnPart,
|
|
24
24
|
)
|
|
25
|
-
from ..result import Usage
|
|
26
25
|
from ..settings import ModelSettings
|
|
27
26
|
from ..tools import ToolDefinition
|
|
27
|
+
from ..usage import Usage
|
|
28
28
|
from . import (
|
|
29
29
|
Model,
|
|
30
30
|
ModelRequestParameters,
|
|
@@ -34,15 +34,15 @@ from .function import _estimate_string_tokens, _estimate_usage # pyright: ignor
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
@dataclass
|
|
37
|
-
class
|
|
38
|
-
"""A private wrapper class to tag
|
|
37
|
+
class _WrappedTextOutput:
|
|
38
|
+
"""A private wrapper class to tag an output that came from the custom_output_text field."""
|
|
39
39
|
|
|
40
40
|
value: str | None
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
@dataclass
|
|
44
|
-
class
|
|
45
|
-
"""A wrapper class to tag
|
|
44
|
+
class _WrappedToolOutput:
|
|
45
|
+
"""A wrapper class to tag an output that came from the custom_output_args field."""
|
|
46
46
|
|
|
47
47
|
value: Any | None
|
|
48
48
|
|
|
@@ -65,16 +65,16 @@ class TestModel(Model):
|
|
|
65
65
|
|
|
66
66
|
call_tools: list[str] | Literal['all'] = 'all'
|
|
67
67
|
"""List of tools to call. If `'all'`, all tools will be called."""
|
|
68
|
-
|
|
69
|
-
"""If set, this text is returned as the final
|
|
70
|
-
|
|
71
|
-
"""If set, these args will be passed to the
|
|
68
|
+
custom_output_text: str | None = None
|
|
69
|
+
"""If set, this text is returned as the final output."""
|
|
70
|
+
custom_output_args: Any | None = None
|
|
71
|
+
"""If set, these args will be passed to the output tool."""
|
|
72
72
|
seed: int = 0
|
|
73
73
|
"""Seed for generating random data."""
|
|
74
74
|
last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
|
|
75
75
|
"""The last ModelRequestParameters passed to the model in a request.
|
|
76
76
|
|
|
77
|
-
The ModelRequestParameters contains information about the function and
|
|
77
|
+
The ModelRequestParameters contains information about the function and output tools available during request handling.
|
|
78
78
|
|
|
79
79
|
This is set when a request is made, so will reflect the function tools from the last step of the last run.
|
|
80
80
|
"""
|
|
@@ -88,7 +88,6 @@ class TestModel(Model):
|
|
|
88
88
|
model_request_parameters: ModelRequestParameters,
|
|
89
89
|
) -> tuple[ModelResponse, Usage]:
|
|
90
90
|
self.last_model_request_parameters = model_request_parameters
|
|
91
|
-
|
|
92
91
|
model_response = self._request(messages, model_settings, model_request_parameters)
|
|
93
92
|
usage = _estimate_usage([*messages, model_response])
|
|
94
93
|
return model_response, usage
|
|
@@ -128,29 +127,29 @@ class TestModel(Model):
|
|
|
128
127
|
tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
|
|
129
128
|
return [(r.name, r) for r in tools_to_call]
|
|
130
129
|
|
|
131
|
-
def
|
|
132
|
-
if self.
|
|
133
|
-
assert model_request_parameters.
|
|
134
|
-
'Plain response not allowed, but `
|
|
130
|
+
def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput:
|
|
131
|
+
if self.custom_output_text is not None:
|
|
132
|
+
assert model_request_parameters.allow_text_output, (
|
|
133
|
+
'Plain response not allowed, but `custom_output_text` is set.'
|
|
135
134
|
)
|
|
136
|
-
assert self.
|
|
137
|
-
return
|
|
138
|
-
elif self.
|
|
139
|
-
assert model_request_parameters.
|
|
140
|
-
'No
|
|
135
|
+
assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.'
|
|
136
|
+
return _WrappedTextOutput(self.custom_output_text)
|
|
137
|
+
elif self.custom_output_args is not None:
|
|
138
|
+
assert model_request_parameters.output_tools is not None, (
|
|
139
|
+
'No output tools provided, but `custom_output_args` is set.'
|
|
141
140
|
)
|
|
142
|
-
|
|
141
|
+
output_tool = model_request_parameters.output_tools[0]
|
|
143
142
|
|
|
144
|
-
if k :=
|
|
145
|
-
return
|
|
143
|
+
if k := output_tool.outer_typed_dict_key:
|
|
144
|
+
return _WrappedToolOutput({k: self.custom_output_args})
|
|
146
145
|
else:
|
|
147
|
-
return
|
|
148
|
-
elif model_request_parameters.
|
|
149
|
-
return
|
|
150
|
-
elif model_request_parameters.
|
|
151
|
-
return
|
|
146
|
+
return _WrappedToolOutput(self.custom_output_args)
|
|
147
|
+
elif model_request_parameters.allow_text_output:
|
|
148
|
+
return _WrappedTextOutput(None)
|
|
149
|
+
elif model_request_parameters.output_tools:
|
|
150
|
+
return _WrappedToolOutput(None)
|
|
152
151
|
else:
|
|
153
|
-
return
|
|
152
|
+
return _WrappedTextOutput(None) # pragma: no cover
|
|
154
153
|
|
|
155
154
|
def _request(
|
|
156
155
|
self,
|
|
@@ -159,8 +158,8 @@ class TestModel(Model):
|
|
|
159
158
|
model_request_parameters: ModelRequestParameters,
|
|
160
159
|
) -> ModelResponse:
|
|
161
160
|
tool_calls = self._get_tool_calls(model_request_parameters)
|
|
162
|
-
|
|
163
|
-
|
|
161
|
+
output_wrapper = self._get_output(model_request_parameters)
|
|
162
|
+
output_tools = model_request_parameters.output_tools
|
|
164
163
|
|
|
165
164
|
# if there are tools, the first thing we want to do is call all of them
|
|
166
165
|
if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
|
|
@@ -176,29 +175,29 @@ class TestModel(Model):
|
|
|
176
175
|
# check if there are any retry prompts, if so retry them
|
|
177
176
|
new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
|
|
178
177
|
if new_retry_names:
|
|
179
|
-
# Handle retries for both function tools and
|
|
178
|
+
# Handle retries for both function tools and output tools
|
|
180
179
|
# Check function tools first
|
|
181
180
|
retry_parts: list[ModelResponsePart] = [
|
|
182
181
|
ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
|
|
183
182
|
]
|
|
184
|
-
# Check
|
|
185
|
-
if
|
|
183
|
+
# Check output tools
|
|
184
|
+
if output_tools:
|
|
186
185
|
retry_parts.extend(
|
|
187
186
|
[
|
|
188
187
|
ToolCallPart(
|
|
189
188
|
tool.name,
|
|
190
|
-
|
|
191
|
-
if isinstance(
|
|
189
|
+
output_wrapper.value
|
|
190
|
+
if isinstance(output_wrapper, _WrappedToolOutput) and output_wrapper.value is not None
|
|
192
191
|
else self.gen_tool_args(tool),
|
|
193
192
|
)
|
|
194
|
-
for tool in
|
|
193
|
+
for tool in output_tools
|
|
195
194
|
if tool.name in new_retry_names
|
|
196
195
|
]
|
|
197
196
|
)
|
|
198
197
|
return ModelResponse(parts=retry_parts, model_name=self._model_name)
|
|
199
198
|
|
|
200
|
-
if isinstance(
|
|
201
|
-
if (response_text :=
|
|
199
|
+
if isinstance(output_wrapper, _WrappedTextOutput):
|
|
200
|
+
if (response_text := output_wrapper.value) is None:
|
|
202
201
|
# build up details of tool responses
|
|
203
202
|
output: dict[str, Any] = {}
|
|
204
203
|
for message in messages:
|
|
@@ -215,16 +214,16 @@ class TestModel(Model):
|
|
|
215
214
|
else:
|
|
216
215
|
return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
|
|
217
216
|
else:
|
|
218
|
-
assert
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
if
|
|
217
|
+
assert output_tools, 'No output tools provided'
|
|
218
|
+
custom_output_args = output_wrapper.value
|
|
219
|
+
output_tool = output_tools[self.seed % len(output_tools)]
|
|
220
|
+
if custom_output_args is not None:
|
|
222
221
|
return ModelResponse(
|
|
223
|
-
parts=[ToolCallPart(
|
|
222
|
+
parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
|
|
224
223
|
)
|
|
225
224
|
else:
|
|
226
|
-
response_args = self.gen_tool_args(
|
|
227
|
-
return ModelResponse(parts=[ToolCallPart(
|
|
225
|
+
response_args = self.gen_tool_args(output_tool)
|
|
226
|
+
return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name)
|
|
228
227
|
|
|
229
228
|
|
|
230
229
|
@dataclass
|