pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_pydantic.py +13 -29
- pydantic_ai/_result.py +52 -38
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +20 -8
- pydantic_ai/agent.py +431 -167
- pydantic_ai/messages.py +90 -48
- pydantic_ai/models/__init__.py +59 -42
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +66 -44
- pydantic_ai/models/gemini.py +160 -117
- pydantic_ai/models/groq.py +125 -108
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +116 -0
- pydantic_ai/models/openai.py +145 -114
- pydantic_ai/models/test.py +109 -77
- pydantic_ai/models/vertexai.py +14 -9
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +140 -45
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.11.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.11.dist-info/RECORD +0 -22
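The most user-visible changes in this release are in pydantic_ai/agent.py together with the new pydantic_ai/settings.py: agents gain a `model_settings` default that is merged with per-run settings (the runtime value wins), and an `end_strategy` that controls what happens to other tool calls once a final result is found. A minimal sketch of that surface, assuming the 0.0.13 API shown in the agent.py hunks below (the model name, prompt, and the `temperature` key are illustrative, not taken from the diff):

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings  # new module in 0.0.13

# Agent-wide defaults; ModelSettings is assumed here to accept a `temperature` key.
agent = Agent(
    'openai:gpt-4o',
    model_settings=ModelSettings(temperature=0.0),
    end_strategy='early',  # stop processing other tool calls once a final result is found
)

# Per-run settings are merged over the agent defaults, with the runtime value taking priority.
result = agent.run_sync(
    'What is the capital of Italy?',
    model_settings=ModelSettings(temperature=1.0),
)
print(result.data)
```

Per the `EndStrategy` docstring below, `'early'` (the default) stubs out remaining tool calls once a final result arrives, while `'exhaustive'` still runs them.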
pydantic_ai/agent.py
CHANGED
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from types import FrameType
-from typing import Any, Callable, Generic, cast, final, overload
+from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
 from typing_extensions import assert_never
@@ -22,13 +22,30 @@ from . import (
     result,
 )
 from .result import ResultData
-from .
+from .settings import ModelSettings, merge_model_settings
+from .tools import (
+    AgentDeps,
+    RunContext,
+    Tool,
+    ToolDefinition,
+    ToolFuncContext,
+    ToolFuncEither,
+    ToolFuncPlain,
+    ToolParams,
+    ToolPrepareFunc,
+)

 __all__ = ('Agent',)

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 NoneType = type(None)
+EndStrategy = Literal['early', 'exhaustive']
+"""The strategy for handling multiple tool calls when a final result is found.
+
+- `'early'`: Stop processing other tool calls once a final result is found
+- `'exhaustive'`: Process all tool calls even after finding a final result
+"""


 @final
@@ -43,7 +60,7 @@ class Agent(Generic[AgentDeps, ResultData]):

    Minimal usage example:

-    ```
+    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')
@@ -53,14 +70,31 @@ class Agent(Generic[AgentDeps, ResultData]):
    ```
    """

-    # dataclass fields
+    # we use dataclass fields in order to conveniently know what attributes are available
    model: models.Model | models.KnownModelName | None
    """The default model configured for this agent."""
+
    name: str | None
    """The name of the agent, used for logging.

    If `None`, we try to infer the agent name from the call frame when the agent is first run.
    """
+    end_strategy: EndStrategy
+    """Strategy for handling tool calls when a final result is found."""
+
+    model_settings: ModelSettings | None
+    """Optional model request settings to use for this agent's runs, by default.
+
+    Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
+    be merged with this value, with the runtime argument taking priority.
+    """
+
+    last_run_messages: list[_messages.ModelMessage] | None
+    """The messages from the last run, useful when a run raised an exception.
+
+    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
+    """
+
    _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
    _allow_text_result: bool = field(repr=False)
@@ -73,11 +107,6 @@ class Agent(Generic[AgentDeps, ResultData]):
    _current_result_retry: int = field(repr=False)
    _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
    _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
-    last_run_messages: list[_messages.Message] | None = None
-    """The messages from the last run, useful when a run raised an exception.
-
-    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
-    """

    def __init__(
        self,
@@ -87,18 +116,20 @@ class Agent(Generic[AgentDeps, ResultData]):
        system_prompt: str | Sequence[str] = (),
        deps_type: type[AgentDeps] = NoneType,
        name: str | None = None,
+        model_settings: ModelSettings | None = None,
        retries: int = 1,
        result_tool_name: str = 'final_result',
        result_tool_description: str | None = None,
        result_retries: int | None = None,
        tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
        defer_model_check: bool = False,
+        end_strategy: EndStrategy = 'early',
    ):
        """Create an agent.

        Args:
            model: The default model to use for this agent, if not provided,
-                you must provide the model when calling
+                you must provide the model when calling it.
            result_type: The type of the result data, used to validate the result data, defaults to `str`.
            system_prompt: Static system prompts to use for this agent, you can also register system
                prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -108,6 +139,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                or add a type hint `: Agent[None, <return type>]`.
            name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                when the agent is first run.
+            model_settings: Optional model request settings to use for this agent's runs, by default.
            retries: The default number of retries to allow before raising an error.
            result_tool_name: The name of the tool to use for the final result.
            result_tool_description: The description of the final result tool.
@@ -119,13 +151,18 @@ class Agent(Generic[AgentDeps, ResultData]):
                which checks for the necessary environment variables. Set this to `false`
                to defer the evaluation until the first run. Useful if you want to
                [override the model][pydantic_ai.Agent.override] for testing.
+            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
+                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
        """
        if model is None or defer_model_check:
            self.model = model
        else:
            self.model = models.infer_model(model)

+        self.end_strategy = end_strategy
        self.name = name
+        self.model_settings = model_settings
+        self.last_run_messages = None
        self._result_schema = _result.ResultSchema[result_type].build(
            result_type, result_tool_name, result_tool_description
        )
@@ -136,7 +173,10 @@ class Agent(Generic[AgentDeps, ResultData]):
        self._function_tools = {}
        self._default_retries = retries
        for tool in tools:
-
+            if isinstance(tool, Tool):
+                self._register_tool(tool)
+            else:
+                self._register_tool(Tool(tool))
        self._deps_type = deps_type
        self._system_prompt_functions = []
        self._max_result_retries = result_retries if result_retries is not None else retries
@@ -147,63 +187,84 @@ class Agent(Generic[AgentDeps, ResultData]):
        self,
        user_prompt: str,
        *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[ResultData]:
        """Run the agent with a user prompt in async mode.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
+        ```
+
        Args:
            user_prompt: User input to start/continue the conversation.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

        Returns:
            The result of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
-        model_used,
+        model_used, mode_selection = await self._get_model(model)

        deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

        with _logfire.span(
            '{agent_name} run {prompt=}',
            prompt=user_prompt,
            agent=self,
-
+            mode_selection=mode_selection,
            model_name=model_used.name(),
            agent_name=self.name or 'agent',
        ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)

            for tool in self._function_tools.values():
-                tool.
+                tool.current_retry = 0

            cost = result.Cost()

+            model_settings = merge_model_settings(self.model_settings, model_settings)
+
            run_step = 0
            while True:
                run_step += 1
-                with _logfire.span('model
-
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps, messages)
+
+                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                    model_response, request_cost = await agent_model.request(messages, model_settings)
                    model_req_span.set_attribute('response', model_response)
                    model_req_span.set_attribute('cost', request_cost)
-                    model_req_span.message = f'model request -> {model_response.role}'

                messages.append(model_response)
                cost += request_cost

-                with _logfire.span('handle model response') as handle_span:
-
+                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                    final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
+
+                    if tool_responses:
+                        # Add parts to the conversation as a new message
+                        messages.append(_messages.ModelRequest(tool_responses))

-                    if
-
-                    result_data =
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_data = final_result.data
                        run_span.set_attribute('all_messages', messages)
                        run_span.set_attribute('cost', cost)
                        handle_span.set_attribute('result', result_data)
@@ -211,24 +272,36 @@ class Agent(Generic[AgentDeps, ResultData]):
                        return result.RunResult(messages, new_message_index, result_data, cost)
                    else:
                        # continue the conversation
-                        tool_responses = either
                        handle_span.set_attribute('tool_responses', tool_responses)
-
-                        handle_span.message = f'handle model response -> {
-                        messages.extend(tool_responses)
+                        tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                        handle_span.message = f'handle model response -> {tool_responses_str}'

    def run_sync(
        self,
        user_prompt: str,
        *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
        infer_name: bool = True,
    ) -> result.RunResult[ResultData]:
        """Run the agent with a user prompt synchronously.

-        This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
+        This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
+        ```

        Args:
            user_prompt: User input to start/continue the conversation.
@@ -236,15 +309,22 @@ class Agent(Generic[AgentDeps, ResultData]):
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

        Returns:
            The result of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
-
-
-
+        return asyncio.get_event_loop().run_until_complete(
+            self.run(
+                user_prompt,
+                message_history=message_history,
+                model=model,
+                deps=deps,
+                infer_name=False,
+                model_settings=model_settings,
+            )
        )

    @asynccontextmanager
@@ -252,19 +332,33 @@ class Agent(Generic[AgentDeps, ResultData]):
        self,
        user_prompt: str,
        *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
        model: models.Model | models.KnownModelName | None = None,
        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
        """Run the agent with a user prompt in async mode, returning a streamed response.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            async with agent.run_stream('What is the capital of the UK?') as response:
+                print(await response.get_data())
+                #> London
+        ```
+
        Args:
            user_prompt: User input to start/continue the conversation.
            message_history: History of the conversation so far.
            model: Optional model to use for this run, required if `model` was not set when creating the agent.
            deps: Optional dependencies to use for this run.
            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

        Returns:
            The result of the run.
@@ -273,44 +367,70 @@ class Agent(Generic[AgentDeps, ResultData]):
        # f_back because `asynccontextmanager` adds one frame
        if frame := inspect.currentframe():  # pragma: no branch
            self._infer_name(frame.f_back)
-        model_used,
+        model_used, mode_selection = await self._get_model(model)

        deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

        with _logfire.span(
            '{agent_name} run stream {prompt=}',
            prompt=user_prompt,
            agent=self,
-
+            mode_selection=mode_selection,
            model_name=model_used.name(),
            agent_name=self.name or 'agent',
        ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)

            for tool in self._function_tools.values():
-                tool.
+                tool.current_retry = 0

            cost = result.Cost()
+            model_settings = merge_model_settings(self.model_settings, model_settings)

            run_step = 0
            while True:
                run_step += 1
+
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps, messages)
+
                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                    async with agent_model.request_stream(messages) as model_response:
+                    async with agent_model.request_stream(messages, model_settings) as model_response:
                        model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                        # We want to end the "model request" span here, but we can't exit the context manager
                        # in the traditional way
                        model_req_span.__exit__(None, None, None)

                        with _logfire.span('handle model response') as handle_span:
-
-
-
-
-
-
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, deps, messages
+                            )
+
+                            # Check if we got a final result
+                            if isinstance(maybe_final_result, _MarkFinalResult):
+                                result_stream = maybe_final_result.data
+                                result_tool_name = maybe_final_result.tool_name
                                handle_span.message = 'handle model response -> final result'
+
+                                async def on_complete():
+                                    """Called when the stream has completed.
+
+                                    The model response will have been added to messages by now
+                                    by `StreamedRunResult._marked_completed`.
+                                    """
+                                    last_message = messages[-1]
+                                    assert isinstance(last_message, _messages.ModelResponse)
+                                    tool_calls = [
+                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+                                    ]
+                                    parts = await self._process_function_tools(
+                                        tool_calls, result_tool_name, deps, messages
+                                    )
+                                    if parts:
+                                        messages.append(_messages.ModelRequest(parts))
+                                    run_span.set_attribute('all_messages', messages)
+
                                yield result.StreamedRunResult(
                                    messages,
                                    new_message_index,
@@ -319,15 +439,22 @@ class Agent(Generic[AgentDeps, ResultData]):
                                    self._result_schema,
                                    deps,
                                    self._result_validators,
-
+                                    result_tool_name,
+                                    on_complete,
                                )
                                return
                            else:
-
+                                # continue the conversation
+                                model_response_msg, tool_responses = maybe_final_result
+                                # if we got a model response add that to messages
+                                messages.append(model_response_msg)
+                                if tool_responses:
+                                    # if we got one or more tool response parts, add a model request message
+                                    messages.append(_messages.ModelRequest(tool_responses))
+
                                handle_span.set_attribute('tool_responses', tool_responses)
-
-                                handle_span.message = f'handle model response -> {
-                                messages.extend(tool_responses)
+                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                                handle_span.message = f'handle model response -> {tool_responses_str}'
                        # the model_response should have been fully streamed by now, we can add its cost
                        cost += model_response.cost()

@@ -341,6 +468,7 @@ class Agent(Generic[AgentDeps, ResultData]):
        """Context manager to temporarily override agent dependencies and model.

        This is particularly useful when testing.
+        You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

        Args:
            deps: The dependencies to use instead of the dependencies passed to the agent run.
@@ -389,14 +517,14 @@ class Agent(Generic[AgentDeps, ResultData]):
    ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
        """Decorator to register a system prompt function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
        Can decorate sync or async functions.

        Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
        the type of the function, see `tests/typed_agent.py` for tests.

        Example:
-        ```
+        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent('test', deps_type=str)
@@ -440,14 +568,14 @@ class Agent(Generic[AgentDeps, ResultData]):
    ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
        """Decorator to register a result validator function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
        Can decorate sync or async functions.

        Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
        the type of the function, see `tests/typed_agent.py` for tests.

        Example:
-        ```
+        ```python
        from pydantic_ai import Agent, ModelRetry, RunContext

        agent = Agent('test', deps_type=str)
@@ -477,7 +605,11 @@ class Agent(Generic[AgentDeps, ResultData]):

    @overload
    def tool(
-        self,
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
    ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...

    def tool(
@@ -486,20 +618,20 @@ class Agent(Generic[AgentDeps, ResultData]):
        /,
        *,
        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
    ) -> Any:
-        """Decorator to register a tool function which takes
-        [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
+        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

        Can decorate sync or async functions.

        The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../
+        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool` is obscured.

        Example:
-        ```
+        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent('test', deps_type=int)
@@ -521,20 +653,23 @@ class Agent(Generic[AgentDeps, ResultData]):
            func: The tool function to register.
            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
-        """
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+        """
        if func is None:

            def tool_decorator(
                func_: ToolFuncContext[AgentDeps, ToolParams],
            ) -> ToolFuncContext[AgentDeps, ToolParams]:
                # noinspection PyTypeChecker
-                self._register_function(func_, True, retries)
+                self._register_function(func_, True, retries, prepare)
                return func_

            return tool_decorator
        else:
            # noinspection PyTypeChecker
-            self._register_function(func, True, retries)
+            self._register_function(func, True, retries, prepare)
            return func

    @overload
@@ -542,22 +677,33 @@ class Agent(Generic[AgentDeps, ResultData]):

    @overload
    def tool_plain(
-        self,
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
    ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

-    def tool_plain(
+    def tool_plain(
+        self,
+        func: ToolFuncPlain[ToolParams] | None = None,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+    ) -> Any:
        """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

        Can decorate sync or async functions.

        The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../
+        [learn more](../tools.md#function-tools-and-schema).

        We can't add overloads for every possible signature of tool, since the return type is a recursive union
        so the signature of functions decorated with `@agent.tool` is obscured.

        Example:
-        ```
+        ```python
        from pydantic_ai import Agent, RunContext

        agent = Agent('test')
@@ -579,30 +725,38 @@ class Agent(Generic[AgentDeps, ResultData]):
            func: The tool function to register.
            retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                which defaults to 1.
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
        """
        if func is None:

            def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                # noinspection PyTypeChecker
-                self._register_function(func_, False, retries)
+                self._register_function(func_, False, retries, prepare)
                return func_

            return tool_decorator
        else:
-            self._register_function(func, False, retries)
+            self._register_function(func, False, retries, prepare)
            return func

    def _register_function(
-        self,
+        self,
+        func: ToolFuncEither[AgentDeps, ToolParams],
+        takes_ctx: bool,
+        retries: int | None,
+        prepare: ToolPrepareFunc[AgentDeps] | None,
    ) -> None:
        """Private utility to register a function as a tool."""
        retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx, max_retries=retries_)
+        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
        self._register_tool(tool)

    def _register_tool(self, tool: Tool[AgentDeps]) -> None:
        """Private utility to register a tool instance."""
        if tool.max_retries is None:
+            # noinspection PyTypeChecker
            tool = dataclasses.replace(tool, max_retries=self._default_retries)

        if tool.name in self._function_tools:
@@ -613,16 +767,14 @@ class Agent(Generic[AgentDeps, ResultData]):

        self._function_tools[tool.name] = tool

-    async def
-        self, model: models.Model | models.KnownModelName | None
-    ) -> tuple[models.Model, models.Model | None, models.AgentModel]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
        """Create a model configured for this agent.

        Args:
            model: model to use for this run, required if `model` was not set when creating the agent.

        Returns:
-            a tuple of `(model used,
+            a tuple of `(model used, how the model was selected)`
        """
        model_: models.Model
        if some_model := self._override_model:
@@ -633,160 +785,267 @@ class Agent(Generic[AgentDeps, ResultData]):
                '(Even when `override(model=...)` is customizing the model that will actually be called)'
            )
            model_ = some_model.value
-
+            mode_selection = 'override-model'
        elif model is not None:
-
+            model_ = models.infer_model(model)
+            mode_selection = 'custom'
        elif self.model is not None:
            # noinspection PyTypeChecker
            model_ = self.model = models.infer_model(self.model)
-
+            mode_selection = 'from-agent'
        else:
            raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

-
-
-
+        return model_, mode_selection
+
+    async def _prepare_model(
+        self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
+    ) -> models.AgentModel:
+        """Build tools and create an agent model."""
+        function_tools: list[ToolDefinition] = []
+
+        async def add_tool(tool: Tool[AgentDeps]) -> None:
+            ctx = RunContext(deps, tool.current_retry, messages, tool.name)
+            if tool_def := await tool.prepare_tool_def(ctx):
+                function_tools.append(tool_def)
+
+        await asyncio.gather(*map(add_tool, self._function_tools.values()))
+
+        return await model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=self._allow_text_result,
+            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+        )

    async def _prepare_messages(
-        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
-    ) -> tuple[int, list[_messages.Message]]:
-
-        if message_history and any(m.role == 'system' for m in message_history):
+        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
+    ) -> list[_messages.ModelMessage]:
+        if message_history:
            # shallow copy messages
            messages = message_history.copy()
+            messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
        else:
-
-
-
+            parts = await self._sys_parts(deps)
+            parts.append(_messages.UserPromptPart(user_prompt))
+            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]

-
-        messages.append(_messages.UserPrompt(user_prompt))
-        return new_message_index, messages
+        return messages

    async def _handle_model_response(
-        self, model_response: _messages.
-    ) -> _MarkFinalResult[ResultData] | list[_messages.
+        self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
        """Process a non-streamed response from the model.

        Returns:
-
+            A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
        """
-
-
-
-
+        texts: list[str] = []
+        tool_calls: list[_messages.ToolCallPart] = []
+        for item in model_response.parts:
+            if isinstance(item, _messages.TextPart):
+                texts.append(item.content)
+            else:
+                tool_calls.append(item)
+
+        if texts:
+            text = '\n\n'.join(texts)
+            return await self._handle_text_response(text, deps, conv_messages)
+        elif tool_calls:
+            return await self._handle_structured_response(tool_calls, deps, conv_messages)
+        else:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+    async def _handle_text_response(
+        self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a plain text response from the model for non-streaming responses."""
+        if self._allow_text_result:
+            result_data_input = cast(ResultData, text)
+            try:
+                result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
+            except _result.ToolRetryError as e:
+                self._incr_result_retry()
+                return None, [e.tool_retry]
+            else:
+                return _MarkFinalResult(result_data, None), []
+        else:
+            self._incr_result_retry()
+            response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return None, [response]
+
+    async def _handle_structured_response(
+        self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a structured response containing tool calls from the model for non-streaming responses."""
+        assert tool_calls, 'Expected at least one tool call'
+
+        # first look for the result tool call
+        final_result: _MarkFinalResult[ResultData] | None = None
+
+        parts: list[_messages.ModelRequestPart] = []
+        if result_schema := self._result_schema:
+            if match := result_schema.find_tool(tool_calls):
+                call, result_tool = match
                try:
-                    result_data =
+                    result_data = result_tool.validate(call)
+                    result_data = await self._validate_result(result_data, deps, call, conv_messages)
                except _result.ToolRetryError as e:
                    self._incr_result_retry()
-
+                    parts.append(e.tool_retry)
                else:
-
-            else:
-                self._incr_result_retry()
-                response = _messages.RetryPrompt(
-                    content='Plain text responses are not permitted, please call one of the functions instead.',
-                )
-                return [response]
-        elif model_response.role == 'model-structured-response':
-            if self._result_schema is not None:
-                # if there's a result schema, and any of the calls match one of its tools, return the result
-                # NOTE: this means we ignore any other tools called here
-                if match := self._result_schema.find_tool(model_response):
-                    call, result_tool = match
-                    try:
-                        result_data = result_tool.validate(call)
-                        result_data = await self._validate_result(result_data, deps, call)
-                    except _result.ToolRetryError as e:
-                        self._incr_result_retry()
-                        return [e.tool_retry]
-                    else:
-                        return _MarkFinalResult(result_data)
+                    final_result = _MarkFinalResult(result_data, call.tool_name)

-
-
+        # Then build the other request parts based on end strategy
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, deps, conv_messages
+        )
+
+        return final_result, parts

-
-
-
-
-
-
+    async def _process_function_tools(
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        result_tool_name: str | None,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> list[_messages.ModelRequestPart]:
+        """Process function (non-result) tool calls in parallel.
+
+        Also add stub return parts for any other tools that need it.
+        """
+        parts: list[_messages.ModelRequestPart] = []
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+
+        stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
+
+        # we rely on the fact that if we found a result, it's the first result tool in the last
+        found_used_result_tool = False
+        for call in tool_calls:
+            if call.tool_name == result_tool_name and not found_used_result_tool:
+                found_used_result_tool = True
+                parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
+                )
+            elif tool := self._function_tools.get(call.tool_name):
+                if stub_function_tools:
+                    parts.append(
+                        _messages.ToolReturnPart(
+                            tool_name=call.tool_name,
+                            content='Tool not executed - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
+                        )
+                    )
                else:
-
+                    tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+                # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
+                # validation, we don't add another part here
+                if result_tool_name is not None:
+                    parts.append(
+                        _messages.ToolReturnPart(
+                            tool_name=call.tool_name,
+                            content='Result tool not used - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
+                        )
+                    )
+                else:
+                    parts.append(self._unknown_tool(call.tool_name))

+        # Run all tool tasks in parallel
+        if tasks:
            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-
-
-
-        assert_never(model_response)
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+        return parts

    async def _handle_streamed_model_response(
-        self,
-
+        self,
+        model_response: models.EitherStreamedResponse,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> (
+        _MarkFinalResult[models.EitherStreamedResponse]
+        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
+    ):
        """Process a streamed response from the model.

-        TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
-        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
-
        Returns:
-
+            Either a final result or a tuple of the model response and the tool responses for the next request.
+            If a final result is returned, the conversation should end.
        """
        if isinstance(model_response, models.StreamTextResponse):
            # plain string response
            if self._allow_text_result:
-                return _MarkFinalResult(model_response)
+                return _MarkFinalResult(model_response, None)
            else:
                self._incr_result_retry()
-                response = _messages.
+                response = _messages.RetryPromptPart(
                    content='Plain text responses are not permitted, please call one of the functions instead.',
                )
                # stream the response, so cost is correct
                async for _ in model_response:
                    pass

-
-
-
+                text = ''.join(model_response.get(final=True))
+                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
+        elif isinstance(model_response, models.StreamStructuredResponse):
            if self._result_schema is not None:
                # if there's a result schema, iterate over the stream until we find at least one tool
                # NOTE: this means we ignore any other tools called here
                structured_msg = model_response.get()
-                while not structured_msg.
+                while not structured_msg.parts:
                    try:
                        await model_response.__anext__()
                    except StopAsyncIteration:
                        break
                    structured_msg = model_response.get()

-                if self._result_schema.find_tool(structured_msg):
-
+                if match := self._result_schema.find_tool(structured_msg.parts):
+                    call, _ = match
+                    return _MarkFinalResult(model_response, call.tool_name)

            # the model is calling a tool function, consume the response to get the next message
            async for _ in model_response:
                pass
-
-            if not
+            model_response_msg = model_response.get()
+            if not model_response_msg.parts:
                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-            messages: list[_messages.Message] = [structured_msg]

            # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.
-
-
-
-
-
+            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+            parts: list[_messages.ModelRequestPart] = []
+            for item in model_response_msg.parts:
+                if isinstance(item, _messages.ToolCallPart):
+                    call = item
+                    if tool := self._function_tools.get(call.tool_name):
+                        tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+                    else:
+                        parts.append(self._unknown_tool(call.tool_name))

            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-
-
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+            return model_response_msg, parts
+        else:
+            assert_never(model_response)

    async def _validate_result(
-        self,
+        self,
+        result_data: ResultData,
+        deps: AgentDeps,
+        tool_call: _messages.ToolCallPart | None,
+        conv_messages: list[_messages.ModelMessage],
    ) -> ResultData:
        for validator in self._result_validators:
-            result_data = await validator.validate(
+            result_data = await validator.validate(
+                result_data, deps, self._current_result_retry, tool_call, conv_messages
+            )
        return result_data

    def _incr_result_retry(self) -> None:
@@ -796,15 +1055,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
            )

-    async def
+    async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
        """Build the initial messages for the conversation."""
-        messages: list[_messages.
+        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
        for sys_prompt_runner in self._system_prompt_functions:
            prompt = await sys_prompt_runner.run(deps)
-            messages.append(_messages.
+            messages.append(_messages.SystemPromptPart(prompt))
        return messages

-    def _unknown_tool(self, tool_name: str) -> _messages.
+    def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
        self._incr_result_retry()
        names = list(self._function_tools.keys())
        if self._result_schema:
@@ -813,7 +1072,7 @@ class Agent(Generic[AgentDeps, ResultData]):
            msg = f'Available tools: {", ".join(names)}'
        else:
            msg = 'No tools available.'
-        return _messages.
+        return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')

    def _get_deps(self, deps: AgentDeps) -> AgentDeps:
        """Get deps for a run.
@@ -852,6 +1111,11 @@ class _MarkFinalResult(Generic[ResultData]):
    """Marker class to indicate that the result is the final result.

    This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+
+    It also avoids problems in the case where the result type is itself `None`, but is set.
    """

    data: ResultData
+    """The final result data."""
+    tool_name: str | None
+    """Name of the final result tool, None if the result is a string."""