pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/__init__.py +12 -2
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +33 -18
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +366 -171
- pydantic_ai/exceptions.py +20 -2
- pydantic_ai/messages.py +111 -50
- pydantic_ai/models/__init__.py +39 -14
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +62 -40
- pydantic_ai/models/gemini.py +164 -124
- pydantic_ai/models/groq.py +112 -94
- pydantic_ai/models/mistral.py +668 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +120 -96
- pydantic_ai/models/test.py +78 -61
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +96 -68
- pydantic_ai/settings.py +137 -0
- pydantic_ai/tools.py +46 -26
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.14.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.14.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
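The most visible API changes in this release are the new `pydantic_ai/settings.py` module and the extra `Agent` parameters it enables: per-agent and per-run `model_settings`, per-run `usage_limits`, and an `end_strategy` option controlling how tool calls are handled once a final result is found (see the `agent.py` diff below). A rough sketch of how these might be used, based only on the signatures visible in this diff; the `ModelSettings` shape and the `UsageLimits` field name used here are assumptions, not confirmed by the diff:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import UsageLimits  # settings module is new in 0.0.14

# end_strategy and model_settings are new Agent.__init__ parameters
agent = Agent(
    'openai:gpt-4o',
    end_strategy='early',  # stop handling other tool calls once a final result is found
    model_settings={'temperature': 0.0},  # assumption: ModelSettings is a TypedDict-style mapping
)

# run/run_sync/run_stream now accept per-run model_settings and usage_limits;
# runtime model_settings are merged over the agent-level defaults
result = agent.run_sync(
    'What is the capital of Italy?',
    usage_limits=UsageLimits(request_limit=5),  # assumption: field name for the request-count limit
)
print(result.data)
```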
pydantic_ai/agent.py
CHANGED
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from types import FrameType
-from typing import Any, Callable, Generic, cast, final, overload
+from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
 from typing_extensions import assert_never
@@ -22,6 +22,7 @@ from . import (
     result,
 )
 from .result import ResultData
+from .settings import ModelSettings, UsageLimits, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -39,6 +40,12 @@ __all__ = ('Agent',)
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 NoneType = type(None)
+EndStrategy = Literal['early', 'exhaustive']
+"""The strategy for handling multiple tool calls when a final result is found.
+
+- `'early'`: Stop processing other tool calls once a final result is found
+- `'exhaustive'`: Process all tool calls even after finding a final result
+"""


 @final
@@ -53,7 +60,7 @@ class Agent(Generic[AgentDeps, ResultData]):

     Minimal usage example:

-    ```
+    ```python
     from pydantic_ai import Agent

     agent = Agent('openai:gpt-4o')
@@ -63,14 +70,31 @@ class Agent(Generic[AgentDeps, ResultData]):
     ```
     """

-    # dataclass fields
+    # we use dataclass fields in order to conveniently know what attributes are available
     model: models.Model | models.KnownModelName | None
     """The default model configured for this agent."""
+
     name: str | None
     """The name of the agent, used for logging.

     If `None`, we try to infer the agent name from the call frame when the agent is first run.
     """
+    end_strategy: EndStrategy
+    """Strategy for handling tool calls when a final result is found."""
+
+    model_settings: ModelSettings | None
+    """Optional model request settings to use for this agents's runs, by default.
+
+    Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
+    be merged with this value, with the runtime argument taking priority.
+    """
+
+    last_run_messages: list[_messages.ModelMessage] | None
+    """The messages from the last run, useful when a run raised an exception.
+
+    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
+    """
+
     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
     _allow_text_result: bool = field(repr=False)
@@ -80,14 +104,8 @@ class Agent(Generic[AgentDeps, ResultData]):
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
     _deps_type: type[AgentDeps] = field(repr=False)
     _max_result_retries: int = field(repr=False)
-    _current_result_retry: int = field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
-    last_run_messages: list[_messages.Message] | None = None
-    """The messages from the last run, useful when a run raised an exception.
-
-    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
-    """

     def __init__(
         self,
@@ -97,18 +115,20 @@ class Agent(Generic[AgentDeps, ResultData]):
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDeps] = NoneType,
         name: str | None = None,
+        model_settings: ModelSettings | None = None,
         retries: int = 1,
         result_tool_name: str = 'final_result',
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
         defer_model_check: bool = False,
+        end_strategy: EndStrategy = 'early',
     ):
         """Create an agent.

         Args:
             model: The default model to use for this agent, if not provide,
-                you must provide the model when calling
+                you must provide the model when calling it.
             result_type: The type of the result data, used to validate the result data, defaults to `str`.
             system_prompt: Static system prompts to use for this agent, you can also register system
                 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -118,6 +138,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 or add a type hint `: Agent[None, <return type>]`.
             name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                 when the agent is first run.
+            model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow before raising an error.
             result_tool_name: The name of the tool to use for the final result.
             result_tool_description: The description of the final result tool.
@@ -129,13 +150,18 @@ class Agent(Generic[AgentDeps, ResultData]):
                 which checks for the necessary environment variables. Set this to `false`
                 to defer the evaluation until the first run. Useful if you want to
                 [override the model][pydantic_ai.Agent.override] for testing.
+            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
+                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
         """
         if model is None or defer_model_check:
             self.model = model
         else:
             self.model = models.infer_model(model)

+        self.end_strategy = end_strategy
         self.name = name
+        self.model_settings = model_settings
+        self.last_run_messages = None
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
@@ -153,25 +179,39 @@ class Agent(Generic[AgentDeps, ResultData]):
         self._deps_type = deps_type
         self._system_prompt_functions = []
         self._max_result_retries = result_retries if result_retries is not None else retries
-        self._current_result_retry = 0
         self._result_validators = []

     async def run(
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -182,6 +222,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)

         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -191,67 +232,91 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-
-            self.
+            run_context = RunContext(deps, 0, [], None, model_used)
+            messages = await self._prepare_messages(user_prompt, message_history, run_context)
+            self.last_run_messages = run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-
+            usage = result.Usage(requests=0)
+            model_settings = merge_model_settings(self.model_settings, model_settings)
+            usage_limits = usage_limits or UsageLimits()

             run_step = 0
             while True:
+                usage_limits.check_before_request(usage)
+
                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(
+                    agent_model = await self._prepare_model(run_context)

                 with _logfire.span('model request', run_step=run_step) as model_req_span:
-                    model_response,
+                    model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
-                    model_req_span.set_attribute('
-                    model_req_span.message = f'model request -> {model_response.role}'
+                    model_req_span.set_attribute('usage', request_usage)

                 messages.append(model_response)
-
+                usage += request_usage
+                usage.requests += 1
+                usage_limits.check_tokens(request_usage)

                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    final_result,
+                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)

-
-
+                    if tool_responses:
+                        # Add parts to the conversation as a new message
+                        messages.append(_messages.ModelRequest(tool_responses))

                     # Check if we got a final result
                     if final_result is not None:
                         result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('
+                        run_span.set_attribute('usage', usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data,
+                        return result.RunResult(messages, new_message_index, result_data, usage)
                     else:
                         # continue the conversation
-                        handle_span.set_attribute('tool_responses',
-
-                        handle_span.message = f'handle model response -> {
+                        handle_span.set_attribute('tool_responses', tool_responses)
+                        tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                        handle_span.message = f'handle model response -> {tool_responses_str}'

     def run_sync(
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.

-        This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
+        This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
+        ```

         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -259,9 +324,16 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-
-
-
+        return asyncio.get_event_loop().run_until_complete(
+            self.run(
+                user_prompt,
+                message_history=message_history,
+                model=model,
+                deps=deps,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                infer_name=False,
+            )
         )

     @asynccontextmanager
@@ -269,18 +341,34 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: UsageLimits | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            async with agent.run_stream('What is the capital of the UK?') as response:
+                print(await response.get_data())
+                #> London
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -293,6 +381,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)

         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -302,60 +391,89 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-
-            self.
+            run_context = RunContext(deps, 0, [], None, model_used)
+            messages = await self._prepare_messages(user_prompt, message_history, run_context)
+            self.last_run_messages = run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-
+            usage = result.Usage()
+            model_settings = merge_model_settings(self.model_settings, model_settings)
+            usage_limits = usage_limits or UsageLimits()

             run_step = 0
             while True:
                 run_step += 1
+                usage_limits.check_before_request(usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(
+                    agent_model = await self._prepare_model(run_context)

                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                    async with agent_model.request_stream(messages) as model_response:
+                    async with agent_model.request_stream(messages, model_settings) as model_response:
+                        usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-
-                            model_response, deps
-                            )
-
-                            # Add all messages to the conversation
-                            messages.extend(response_messages)
+                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)

                             # Check if we got a final result
-                            if
-                                result_stream =
-
-                                handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+                            if isinstance(maybe_final_result, _MarkFinalResult):
+                                result_stream = maybe_final_result.data
+                                result_tool_name = maybe_final_result.tool_name
                                 handle_span.message = 'handle model response -> final result'
+
+                                async def on_complete():
+                                    """Called when the stream has completed.
+
+                                    The model response will have been added to messages by now
+                                    by `StreamedRunResult._marked_completed`.
+                                    """
+                                    last_message = messages[-1]
+                                    assert isinstance(last_message, _messages.ModelResponse)
+                                    tool_calls = [
+                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+                                    ]
+                                    parts = await self._process_function_tools(
+                                        tool_calls, result_tool_name, run_context
+                                    )
+                                    if parts:
+                                        messages.append(_messages.ModelRequest(parts))
+                                    run_span.set_attribute('all_messages', messages)
+
                                 yield result.StreamedRunResult(
                                     messages,
                                     new_message_index,
-
+                                    usage,
+                                    usage_limits,
                                     result_stream,
                                     self._result_schema,
-
+                                    run_context,
                                     self._result_validators,
-
+                                    result_tool_name,
+                                    on_complete,
                                 )
                                 return
                             else:
                                 # continue the conversation
-
-
-
-
-
+                                model_response_msg, tool_responses = maybe_final_result
+                                # if we got a model response add that to messages
+                                messages.append(model_response_msg)
+                                if tool_responses:
+                                    # if we got one or more tool response parts, add a model request message
+                                    messages.append(_messages.ModelRequest(tool_responses))
+
+                                handle_span.set_attribute('tool_responses', tool_responses)
+                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                                handle_span.message = f'handle model response -> {tool_responses_str}'
+                                # the model_response should have been fully streamed by now, we can add its usage
+                                model_response_usage = model_response.usage()
+                                usage += model_response_usage
+                                usage_limits.check_tokens(usage)

     @contextmanager
     def override(
@@ -367,6 +485,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """Context manager to temporarily override agent dependencies and model.

         This is particularly useful when testing.
+        You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

         Args:
             deps: The dependencies to use instead of the dependencies passed to the agent run.
@@ -415,14 +534,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
         """Decorator to register a system prompt function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test', deps_type=str)
@@ -466,14 +585,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
         """Decorator to register a result validator function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
         Can decorate a sync or async functions.

         Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, ModelRetry, RunContext

         agent = Agent('test', deps_type=str)
@@ -523,13 +642,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate a sync or async functions.

         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../
+        [learn more](../tools.md#function-tools-and-schema).

         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test', deps_type=int)
@@ -595,13 +714,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate a sync or async functions.

         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../
+        [learn more](../tools.md#function-tools-and-schema).

         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test')
@@ -696,193 +815,266 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_, mode_selection

-    async def _prepare_model(self,
-        """
+    async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+        """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

         async def add_tool(tool: Tool[AgentDeps]) -> None:
-            ctx =
+            ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
             if tool_def := await tool.prepare_tool_def(ctx):
                 function_tools.append(tool_def)

         await asyncio.gather(*map(add_tool, self._function_tools.values()))

-        return await model.agent_model(
+        return await run_context.model.agent_model(
             function_tools=function_tools,
             allow_text_result=self._allow_text_result,
             result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
         )

     async def _prepare_messages(
-        self,
-    ) ->
-
-        if message_history and any(m.role == 'system' for m in message_history):
+        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
+    ) -> list[_messages.ModelMessage]:
+        if message_history:
             # shallow copy messages
             messages = message_history.copy()
+            messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
-
-
-
+            parts = await self._sys_parts(run_context)
+            parts.append(_messages.UserPromptPart(user_prompt))
+            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]

-
-        messages.append(_messages.UserPrompt(user_prompt))
-        return new_message_index, messages
+        return messages

     async def _handle_model_response(
-        self, model_response: _messages.
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.
+        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
-            A tuple of `(final_result,
+            A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
         """
-
-
-
-
+        texts: list[str] = []
+        tool_calls: list[_messages.ToolCallPart] = []
+        for part in model_response.parts:
+            if isinstance(part, _messages.TextPart):
+                # ignore empty content for text parts, see #437
+                if part.content:
+                    texts.append(part.content)
+            else:
+                tool_calls.append(part)
+
+        if texts:
+            text = '\n\n'.join(texts)
+            return await self._handle_text_response(text, run_context)
+        elif tool_calls:
+            return await self._handle_structured_response(tool_calls, run_context)
+        else:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+    async def _handle_text_response(
+        self, text: str, run_context: RunContext[AgentDeps]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a plain text response from the model for non-streaming responses."""
+        if self._allow_text_result:
+            result_data_input = cast(ResultData, text)
+            try:
+                result_data = await self._validate_result(result_data_input, run_context, None)
+            except _result.ToolRetryError as e:
+                self._incr_result_retry(run_context)
+                return None, [e.tool_retry]
+            else:
+                return _MarkFinalResult(result_data, None), []
+        else:
+            self._incr_result_retry(run_context)
+            response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return None, [response]
+
+    async def _handle_structured_response(
+        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a structured response containing tool calls from the model for non-streaming responses."""
+        assert tool_calls, 'Expected at least one tool call'
+
+        # first look for the result tool call
+        final_result: _MarkFinalResult[ResultData] | None = None
+
+        parts: list[_messages.ModelRequestPart] = []
+        if result_schema := self._result_schema:
+            if match := result_schema.find_tool(tool_calls):
+                call, result_tool = match
                 try:
-                    result_data =
+                    result_data = result_tool.validate(call)
+                    result_data = await self._validate_result(result_data, run_context, call)
                 except _result.ToolRetryError as e:
-                    self._incr_result_retry()
-
+                    self._incr_result_retry(run_context)
+                    parts.append(e.tool_retry)
                 else:
-
-
-
-
-
+                    final_result = _MarkFinalResult(result_data, call.tool_name)
+
+        # Then build the other request parts based on end strategy
+        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+
+        return final_result, parts
+
+    async def _process_function_tools(
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        result_tool_name: str | None,
+        run_context: RunContext[AgentDeps],
+    ) -> list[_messages.ModelRequestPart]:
+        """Process function (non-result) tool calls in parallel.
+
+        Also add stub return parts for any other tools that need it.
+        """
+        parts: list[_messages.ModelRequestPart] = []
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+
+        stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
+
+        # we rely on the fact that if we found a result, it's the first result tool in the last
+        found_used_result_tool = False
+        for call in tool_calls:
+            if call.tool_name == result_tool_name and not found_used_result_tool:
+                found_used_result_tool = True
+                parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-
-
-
-
-                # NOTE: this means we ignore any other tools called here
-                if match := self._result_schema.find_tool(model_response):
-                    call, result_tool = match
-                    try:
-                        result_data = result_tool.validate(call)
-                        result_data = await self._validate_result(result_data, deps, call)
-                    except _result.ToolRetryError as e:
-                        self._incr_result_retry()
-                        return None, [e.tool_retry]
-                    else:
-                        # Add a ToolReturn message for the schema tool call
-                        tool_return = _messages.ToolReturn(
+            elif tool := self._function_tools.get(call.tool_name):
+                if stub_function_tools:
+                    parts.append(
+                        _messages.ToolReturnPart(
                             tool_name=call.tool_name,
-                            content='
-
+                            content='Tool not executed - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
                         )
-
-
-        if not model_response.calls:
-            raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
-        # otherwise we run all tool functions in parallel
-        messages: list[_messages.Message] = []
-        tasks: list[asyncio.Task[_messages.Message]] = []
-        for call in model_response.calls:
-            if tool := self._function_tools.get(call.tool_name):
-                tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
+                    )
                 else:
-
+                    tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+                # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
+                # validation, we don't add another part here
+                if result_tool_name is not None:
+                    parts.append(
+                        _messages.ToolReturnPart(
+                            tool_name=call.tool_name,
+                            content='Result tool not used - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
+                        )
+                    )
+            else:
+                parts.append(self._unknown_tool(call.tool_name, run_context))

+        # Run all tool tasks in parallel
+        if tasks:
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.
-
-
-        else:
-            assert_never(model_response)
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+        return parts

     async def _handle_streamed_model_response(
-        self,
-
+        self,
+        model_response: models.EitherStreamedResponse,
+        run_context: RunContext[AgentDeps],
+    ) -> (
+        _MarkFinalResult[models.EitherStreamedResponse]
+        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
+    ):
         """Process a streamed response from the model.

         Returns:
-
+            Either a final result or a tuple of the model response and the tool responses for the next request.
+            If a final result is returned, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response
+                return _MarkFinalResult(model_response, None)
             else:
-                self._incr_result_retry()
-                response = _messages.
+                self._incr_result_retry(run_context)
+                response = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                # stream the response, so
+                # stream the response, so usage is correct
                 async for _ in model_response:
                     pass

-
-
-
+                text = ''.join(model_response.get(final=True))
+                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
+        elif isinstance(model_response, models.StreamStructuredResponse):
             if self._result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
-                while not structured_msg.
+                while not structured_msg.parts:
                     try:
                         await model_response.__anext__()
                     except StopAsyncIteration:
                         break
                     structured_msg = model_response.get()

-                if match := self._result_schema.find_tool(structured_msg):
+                if match := self._result_schema.find_tool(structured_msg.parts):
                     call, _ = match
-
-                        tool_name=call.tool_name,
-                        content='Final result processed.',
-                        tool_id=call.tool_id,
-                    )
-                    return _MarkFinalResult(model_response), [tool_return]
+                    return _MarkFinalResult(model_response, call.tool_name)

             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
                 pass
-
-            if not
+            model_response_msg = model_response.get()
+            if not model_response_msg.parts:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-            messages: list[_messages.Message] = [structured_msg]

             # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.
-
-
-
-
-
+            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+            parts: list[_messages.ModelRequestPart] = []
+            for item in model_response_msg.parts:
+                if isinstance(item, _messages.ToolCallPart):
+                    call = item
+                    if tool := self._function_tools.get(call.tool_name):
+                        tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+                    else:
+                        parts.append(self._unknown_tool(call.tool_name, run_context))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.
-
-            return
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+            return model_response_msg, parts
+        else:
+            assert_never(model_response)

     async def _validate_result(
-        self,
+        self,
+        result_data: ResultData,
+        run_context: RunContext[AgentDeps],
+        tool_call: _messages.ToolCallPart | None,
     ) -> ResultData:
         for validator in self._result_validators:
-            result_data = await validator.validate(result_data,
+            result_data = await validator.validate(result_data, tool_call, run_context)
         return result_data

-    def _incr_result_retry(self) -> None:
-
-        if
+    def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
+        run_context.retry += 1
+        if run_context.retry > self._max_result_retries:
             raise exceptions.UnexpectedModelBehavior(
                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
             )

-    async def
+    async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
-        messages: list[_messages.
+        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
-            prompt = await sys_prompt_runner.run(
-            messages.append(_messages.
+            prompt = await sys_prompt_runner.run(run_context)
+            messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(self, tool_name: str) -> _messages.
-        self._incr_result_retry()
+    def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+        self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
         if self._result_schema:
             names.extend(self._result_schema.tool_names())
|
|
|
890
1082
|
msg = f'Available tools: {", ".join(names)}'
|
|
891
1083
|
else:
|
|
892
1084
|
msg = 'No tools available.'
|
|
893
|
-
return _messages.
|
|
1085
|
+
return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
|
|
894
1086
|
|
|
895
1087
|
def _get_deps(self, deps: AgentDeps) -> AgentDeps:
|
|
896
1088
|
"""Get deps for a run.
|
|
@@ -934,3 +1126,6 @@ class _MarkFinalResult(Generic[ResultData]):
     """

     data: ResultData
+    """The final result data."""
+    tool_name: str | None
+    """Name of the final result tool, None if the result is a string."""