pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim was flagged as potentially problematic.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/agent.py
CHANGED
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from types import FrameType
-from typing import Any, Callable, Generic, cast, final, overload
+from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
 from typing_extensions import assert_never
@@ -22,6 +22,7 @@ from . import (
     result,
 )
 from .result import ResultData
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -39,6 +40,12 @@ __all__ = ('Agent',)
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

 NoneType = type(None)
+EndStrategy = Literal['early', 'exhaustive']
+"""The strategy for handling multiple tool calls when a final result is found.
+
+- `'early'`: Stop processing other tool calls once a final result is found
+- `'exhaustive'`: Process all tool calls even after finding a final result
+"""


 @final
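
The new `EndStrategy` is wired through `Agent.__init__` below. A minimal sketch of opting into the exhaustive strategy; the model name is just a placeholder:

```python
from pydantic_ai import Agent

# with end_strategy='exhaustive', function tools called alongside the final
# result tool still run; with the default 'early' they are stubbed out with
# "Tool not executed" returns instead
agent = Agent('openai:gpt-4o', end_strategy='exhaustive')
```
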
@@ -53,7 +60,7 @@ class Agent(Generic[AgentDeps, ResultData]):

     Minimal usage example:

-    ```
+    ```python
     from pydantic_ai import Agent

     agent = Agent('openai:gpt-4o')
@@ -63,14 +70,31 @@ class Agent(Generic[AgentDeps, ResultData]):
     ```
     """

-    # dataclass fields
+    # we use dataclass fields in order to conveniently know what attributes are available
     model: models.Model | models.KnownModelName | None
     """The default model configured for this agent."""
+
     name: str | None
     """The name of the agent, used for logging.

     If `None`, we try to infer the agent name from the call frame when the agent is first run.
     """
+    end_strategy: EndStrategy
+    """Strategy for handling tool calls when a final result is found."""
+
+    model_settings: ModelSettings | None
+    """Optional model request settings to use for this agents's runs, by default.
+
+    Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
+    be merged with this value, with the runtime argument taking priority.
+    """
+
+    last_run_messages: list[_messages.ModelMessage] | None
+    """The messages from the last run, useful when a run raised an exception.
+
+    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
+    """
+
     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
     _allow_text_result: bool = field(repr=False)
@@ -83,11 +107,6 @@ class Agent(Generic[AgentDeps, ResultData]):
     _current_result_retry: int = field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
-    last_run_messages: list[_messages.Message] | None = None
-    """The messages from the last run, useful when a run raised an exception.
-
-    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
-    """

     def __init__(
         self,
@@ -97,18 +116,20 @@ class Agent(Generic[AgentDeps, ResultData]):
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDeps] = NoneType,
         name: str | None = None,
+        model_settings: ModelSettings | None = None,
         retries: int = 1,
         result_tool_name: str = 'final_result',
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
         defer_model_check: bool = False,
+        end_strategy: EndStrategy = 'early',
     ):
         """Create an agent.

         Args:
             model: The default model to use for this agent, if not provide,
-                you must provide the model when calling …
+                you must provide the model when calling it.
             result_type: The type of the result data, used to validate the result data, defaults to `str`.
             system_prompt: Static system prompts to use for this agent, you can also register system
                 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -118,6 +139,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 or add a type hint `: Agent[None, <return type>]`.
             name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                 when the agent is first run.
+            model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow before raising an error.
             result_tool_name: The name of the tool to use for the final result.
             result_tool_description: The description of the final result tool.
@@ -129,13 +151,18 @@ class Agent(Generic[AgentDeps, ResultData]):
                 which checks for the necessary environment variables. Set this to `false`
                 to defer the evaluation until the first run. Useful if you want to
                 [override the model][pydantic_ai.Agent.override] for testing.
+            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
+                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
         """
         if model is None or defer_model_check:
             self.model = model
         else:
             self.model = models.infer_model(model)

+        self.end_strategy = end_strategy
         self.name = name
+        self.model_settings = model_settings
+        self.last_run_messages = None
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
@@ -160,19 +187,32 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

         Returns:
             The result of the run.
@@ -182,6 +222,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)

         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -191,34 +232,35 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)

             for tool in self._function_tools.values():
                 tool.current_retry = 0

             cost = result.Cost()

+            model_settings = merge_model_settings(self.model_settings, model_settings)
+
             run_step = 0
             while True:
                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(model_used, deps)
+                    agent_model = await self._prepare_model(model_used, deps, messages)

                 with _logfire.span('model request', run_step=run_step) as model_req_span:
-                    model_response, request_cost = await agent_model.request(messages)
+                    model_response, request_cost = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('cost', request_cost)
-                    model_req_span.message = f'model request -> {model_response.role}'

                 messages.append(model_response)
                 cost += request_cost

                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    final_result, response_messages = await self._handle_model_response(model_response, deps)
+                    final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)

-                    # Add all messages to the conversation
-                    messages.extend(response_messages)
+                    if tool_responses:
+                        # Add parts to the conversation as a new message
+                        messages.append(_messages.ModelRequest(tool_responses))

                 # Check if we got a final result
                 if final_result is not None:
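
`merge_model_settings` comes from the new `pydantic_ai/settings.py` added in this release. Per the field docs above, per-run settings take priority over the agent's defaults, field by field; a rough sketch of the expected semantics, with illustrative field values:

```python
from pydantic_ai.settings import ModelSettings, merge_model_settings

# hypothetical values, assuming ModelSettings is a TypedDict-style mapping
defaults: ModelSettings = {'temperature': 0.2, 'max_tokens': 1024}
per_run: ModelSettings = {'temperature': 0.9}

# expected to behave like a right-biased dict merge: the runtime
# 'temperature' wins, 'max_tokens' falls through from the defaults
merged = merge_model_settings(defaults, per_run)
```
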
@@ -230,22 +272,36 @@ class Agent(Generic[AgentDeps, ResultData]):
                     return result.RunResult(messages, new_message_index, result_data, cost)
                 else:
                     # continue the conversation
-                    handle_span.set_attribute('tool_responses', …
-                    …
-                    handle_span.message = f'handle model response -> {…
+                    handle_span.set_attribute('tool_responses', tool_responses)
+                    tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                    handle_span.message = f'handle model response -> {tool_responses_str}'

     def run_sync(
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.

-        This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
+        This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
+        ```

         Args:
             user_prompt: User input to start/continue the conversation.
@@ -253,15 +309,22 @@ class Agent(Generic[AgentDeps, ResultData]):
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

         Returns:
             The result of the run.
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        …
-        …
-        …
+        return asyncio.get_event_loop().run_until_complete(
+            self.run(
+                user_prompt,
+                message_history=message_history,
+                model=model,
+                deps=deps,
+                infer_name=False,
+                model_settings=model_settings,
+            )
         )

     @asynccontextmanager
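
Since `run_sync` drives the event loop itself via `run_until_complete`, it can only be called where no loop is already running. A short sketch of the distinction, using the built-in `'test'` model so no API key is needed:

```python
from pydantic_ai import Agent

agent = Agent('test')

# fine from plain synchronous code:
result = agent.run_sync('hello')
print(result.data)

# inside `async def` code an event loop is already running, so run_sync
# would fail; use `await agent.run('hello')` there instead
```
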
@@ -269,19 +332,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            async with agent.run_stream('What is the capital of the UK?') as response:
+                print(await response.get_data())
+                #> London
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.

         Returns:
             The result of the run.
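
The docstring example waits for the whole answer with `get_data()`; `StreamedRunResult` (see `pydantic_ai/result.py` in this diff) can also be consumed incrementally. A sketch assuming a `stream_text()` method, which is not shown in this hunk:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

async def main():
    async with agent.run_stream('Tell me a short story.') as response:
        # assumed API: stream_text() yields the text as it arrives
        async for text in response.stream_text():
            print(text)
```
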
@@ -293,6 +370,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)

         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -302,42 +380,57 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)

             for tool in self._function_tools.values():
                 tool.current_retry = 0

             cost = result.Cost()
+            model_settings = merge_model_settings(self.model_settings, model_settings)

             run_step = 0
             while True:
                 run_step += 1

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(model_used, deps)
+                    agent_model = await self._prepare_model(model_used, deps, messages)

                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                    async with agent_model.request_stream(messages) as model_response:
+                    async with agent_model.request_stream(messages, model_settings) as model_response:
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(
-                                model_response, deps
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, deps, messages
                             )

-                            # Add all messages to the conversation
-                            messages.extend(response_messages)
-
                             # Check if we got a final result
-                            if …
-                            result_stream = …
-                            …
-                            handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+                            if isinstance(maybe_final_result, _MarkFinalResult):
+                                result_stream = maybe_final_result.data
+                                result_tool_name = maybe_final_result.tool_name
                                 handle_span.message = 'handle model response -> final result'
+
+                                async def on_complete():
+                                    """Called when the stream has completed.
+
+                                    The model response will have been added to messages by now
+                                    by `StreamedRunResult._marked_completed`.
+                                    """
+                                    last_message = messages[-1]
+                                    assert isinstance(last_message, _messages.ModelResponse)
+                                    tool_calls = [
+                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+                                    ]
+                                    parts = await self._process_function_tools(
+                                        tool_calls, result_tool_name, deps, messages
+                                    )
+                                    if parts:
+                                        messages.append(_messages.ModelRequest(parts))
+                                    run_span.set_attribute('all_messages', messages)
+
                                 yield result.StreamedRunResult(
                                     messages,
                                     new_message_index,
@@ -346,14 +439,22 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     self._result_schema,
                                     deps,
                                     self._result_validators,
-                                    …
+                                    result_tool_name,
+                                    on_complete,
                                 )
                                 return
                             else:
                                 # continue the conversation
-                                …
-                                …
-                                …
+                                model_response_msg, tool_responses = maybe_final_result
+                                # if we got a model response add that to messages
+                                messages.append(model_response_msg)
+                                if tool_responses:
+                                    # if we got one or more tool response parts, add a model request message
+                                    messages.append(_messages.ModelRequest(tool_responses))
+
+                                handle_span.set_attribute('tool_responses', tool_responses)
+                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                                handle_span.message = f'handle model response -> {tool_responses_str}'
                                 # the model_response should have been fully streamed by now, we can add it's cost
                                 cost += model_response.cost()

@@ -367,6 +468,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """Context manager to temporarily override agent dependencies and model.

         This is particularly useful when testing.
+        You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

         Args:
             deps: The dependencies to use instead of the dependencies passed to the agent run.
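
The linked page describes overriding the model from a pytest fixture; a minimal sketch of the pattern using `TestModel` from `pydantic_ai/models/test.py` (also touched in this release):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent('openai:gpt-4o')

def test_agent_logic():
    # inside the block, runs hit TestModel instead of the real API
    with agent.override(model=TestModel()):
        result = agent.run_sync('What will the weather be like?')
        assert result.data is not None
```
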
@@ -415,14 +517,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
         """Decorator to register a system prompt function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as …
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test', deps_type=str)
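
The docstring example is cut off at the hunk boundary; a sketch of how the decorator is typically used, in the same spirit as the truncated example:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)

@agent.system_prompt
def add_the_users_name(ctx: RunContext[str]) -> str:
    # ctx.deps is the `str` dependency passed to the run
    return f"The user's name is {ctx.deps}."

result = agent.run_sync('What is my name?', deps='Frank')
```
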
@@ -466,14 +568,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
         """Decorator to register a result validator function.

-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as …
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
         Can decorate a sync or async functions.

         Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, ModelRetry, RunContext

         agent = Agent('test', deps_type=str)
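
Again the example is truncated by the hunk; a sketch of a validator that sends the model back for another attempt via `ModelRetry` (both names are imported in the docstring above):

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('test', deps_type=str)

@agent.result_validator
def validate_result(ctx: RunContext[str], data: str) -> str:
    # raising ModelRetry turns into a retry prompt sent back to the model
    if 'wrong' in data:
        raise ModelRetry('That looks wrong, please try again.')
    return data
```
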
@@ -523,13 +625,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate a sync or async functions.

         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../…
+        [learn more](../tools.md#function-tools-and-schema).

         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test', deps_type=int)
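
A sketch of an `@agent.tool` registration matching the truncated example's setup (`deps_type=int`); the tool body is illustrative:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=int)

@agent.tool
def add_to_deps(ctx: RunContext[int], x: int) -> int:
    """Add `x` to the integer dependency.

    Args:
        ctx: the run context, carrying the `int` dependency as `ctx.deps`
        x: the amount to add
    """
    return ctx.deps + x
```
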
@@ -595,13 +697,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate a sync or async functions.

         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../…
+        [learn more](../tools.md#function-tools-and-schema).

         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.

         Example:
-        ```
+        ```python
         from pydantic_ai import Agent, RunContext

         agent = Agent('test')
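
`tool_plain` registers a tool that doesn't take the run context; a sketch in the same style:

```python
import random

from pydantic_ai import Agent

agent = Agent('test')

@agent.tool_plain
def roll_die() -> str:
    """Roll a six-sided die and return the result."""
    return str(random.randint(1, 6))
```
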
@@ -696,12 +798,14 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_, mode_selection

-    async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
+    async def _prepare_model(
+        self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
+    ) -> models.AgentModel:
         """Create building tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

         async def add_tool(tool: Tool[AgentDeps]) -> None:
-            ctx = RunContext(deps, tool.current_retry, tool.name)
+            ctx = RunContext(deps, tool.current_retry, messages, tool.name)
             if tool_def := await tool.prepare_tool_def(ctx):
                 function_tools.append(tool_def)

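
`RunContext` is now constructed with the conversation messages (positionally: deps, retry, messages, tool name), so tools and `prepare_tool_def` hooks can inspect the run so far. A sketch assuming the attribute is exposed as `ctx.messages`, matching that constructor order:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test')

@agent.tool
async def message_count(ctx: RunContext[None]) -> str:
    # assumed attribute name, per the positional construction above
    return f'{len(ctx.messages)} messages in the conversation so far'
```
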
@@ -714,156 +818,234 @@ class Agent(Generic[AgentDeps, ResultData]):
         )

     async def _prepare_messages(
-        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
-    ) -> tuple[int, list[_messages.Message]]:
-        …
-        if message_history and any(m.role == 'system' for m in message_history):
+        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
+    ) -> list[_messages.ModelMessage]:
+        if message_history:
             # shallow copy messages
             messages = message_history.copy()
+            messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
-            …
-            …
-            …
+            parts = await self._sys_parts(deps)
+            parts.append(_messages.UserPromptPart(user_prompt))
+            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]

-        …
-        messages.append(_messages.UserPrompt(user_prompt))
-        return new_message_index, messages
+        return messages

     async def _handle_model_response(
-        self, model_response: _messages.…
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.…
+        self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
-            A tuple of `(final_result, …
+            A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
         """
-        …
-        …
-        …
-        …
+        texts: list[str] = []
+        tool_calls: list[_messages.ToolCallPart] = []
+        for item in model_response.parts:
+            if isinstance(item, _messages.TextPart):
+                texts.append(item.content)
+            else:
+                tool_calls.append(item)
+
+        if texts:
+            text = '\n\n'.join(texts)
+            return await self._handle_text_response(text, deps, conv_messages)
+        elif tool_calls:
+            return await self._handle_structured_response(tool_calls, deps, conv_messages)
+        else:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+    async def _handle_text_response(
+        self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a plain text response from the model for non-streaming responses."""
+        if self._allow_text_result:
+            result_data_input = cast(ResultData, text)
+            try:
+                result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
+            except _result.ToolRetryError as e:
+                self._incr_result_retry()
+                return None, [e.tool_retry]
+            else:
+                return _MarkFinalResult(result_data, None), []
+        else:
+            self._incr_result_retry()
+            response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return None, [response]
+
+    async def _handle_structured_response(
+        self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a structured response containing tool calls from the model for non-streaming responses."""
+        assert tool_calls, 'Expected at least one tool call'
+
+        # first look for the result tool call
+        final_result: _MarkFinalResult[ResultData] | None = None
+
+        parts: list[_messages.ModelRequestPart] = []
+        if result_schema := self._result_schema:
+            if match := result_schema.find_tool(tool_calls):
+                call, result_tool = match
                 try:
-                    result_data = …
+                    result_data = result_tool.validate(call)
+                    result_data = await self._validate_result(result_data, deps, call, conv_messages)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    …
+                    parts.append(e.tool_retry)
                 else:
-                    …
-                    …
-                    …
-                    …
-                    …
+                    final_result = _MarkFinalResult(result_data, call.tool_name)
+
+        # Then build the other request parts based on end strategy
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, deps, conv_messages
+        )
+
+        return final_result, parts
+
+    async def _process_function_tools(
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        result_tool_name: str | None,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> list[_messages.ModelRequestPart]:
+        """Process function (non-result) tool calls in parallel.
+
+        Also add stub return parts for any other tools that need it.
+        """
+        parts: list[_messages.ModelRequestPart] = []
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+
+        stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
+
+        # we rely on the fact that if we found a result, it's the first result tool in the last
+        found_used_result_tool = False
+        for call in tool_calls:
+            if call.tool_name == result_tool_name and not found_used_result_tool:
+                found_used_result_tool = True
+                parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-            …
-            …
-            …
-            …
-            # NOTE: this means we ignore any other tools called here
-            if match := self._result_schema.find_tool(model_response):
-                call, result_tool = match
-                try:
-                    result_data = result_tool.validate(call)
-                    result_data = await self._validate_result(result_data, deps, call)
-                except _result.ToolRetryError as e:
-                    self._incr_result_retry()
-                    return None, [e.tool_retry]
-                else:
-                    # Add a ToolReturn message for the schema tool call
-                    tool_return = _messages.ToolReturn(
+            elif tool := self._function_tools.get(call.tool_name):
+                if stub_function_tools:
+                    parts.append(
+                        _messages.ToolReturnPart(
                             tool_name=call.tool_name,
-                        content='…
-                        …
+                            content='Tool not executed - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
                         )
+                    )
                 else:
-                …
+                    tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+                # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
+                # validation, we don't add another part here
+                if result_tool_name is not None:
+                    parts.append(
+                        _messages.ToolReturnPart(
+                            tool_name=call.tool_name,
+                            content='Result tool not used - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
+                        )
+                    )
+            else:
+                parts.append(self._unknown_tool(call.tool_name))
-            …
-            …
-            if not model_response.calls:
-                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-            …
-            # otherwise we run all tool functions in parallel
-            messages: list[_messages.Message] = []
-            tasks: list[asyncio.Task[_messages.Message]] = []
-            for call in model_response.calls:
-                if tool := self._function_tools.get(call.tool_name):
-                    tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))

+        # Run all tool tasks in parallel
+        if tasks:
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.…
-                …
-                …
-        else:
-            assert_never(model_response)
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+        return parts

     async def _handle_streamed_model_response(
-        self, …
-        …
+        self,
+        model_response: models.EitherStreamedResponse,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> (
+        _MarkFinalResult[models.EitherStreamedResponse]
+        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
+    ):
         """Process a streamed response from the model.

         Returns:
-            …
+            Either a final result or a tuple of the model response and the tool responses for the next request.
+            If a final result is returned, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response…
+                return _MarkFinalResult(model_response, None)
             else:
                 self._incr_result_retry()
-                response = _messages.…
+                response = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
                 # stream the response, so cost is correct
                 async for _ in model_response:
                     pass

-                …
-                …
-                …
+                text = ''.join(model_response.get(final=True))
+                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
+        elif isinstance(model_response, models.StreamStructuredResponse):
             if self._result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
-                while not structured_msg.…
+                while not structured_msg.parts:
                     try:
                         await model_response.__anext__()
                     except StopAsyncIteration:
                         break
                     structured_msg = model_response.get()

-                if match := self._result_schema.find_tool(structured_msg):
+                if match := self._result_schema.find_tool(structured_msg.parts):
                     call, _ = match
-                    …
-                        tool_name=call.tool_name,
-                        content='Final result processed.',
-                        tool_id=call.tool_id,
-                    )
-                    return _MarkFinalResult(model_response), [tool_return]
+                    return _MarkFinalResult(model_response, call.tool_name)

             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
                 pass
-            …
-            if not …
+            model_response_msg = model_response.get()
+            if not model_response_msg.parts:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-            messages: list[_messages.Message] = [structured_msg]

             # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.…
-            …
-            …
-            …
-            …
-            …
+            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+            parts: list[_messages.ModelRequestPart] = []
+            for item in model_response_msg.parts:
+                if isinstance(item, _messages.ToolCallPart):
+                    call = item
+                    if tool := self._function_tools.get(call.tool_name):
+                        tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+                    else:
+                        parts.append(self._unknown_tool(call.tool_name))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.…
-                …
-            return …
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+            return model_response_msg, parts
+        else:
+            assert_never(model_response)

     async def _validate_result(
-        self, …
+        self,
+        result_data: ResultData,
+        deps: AgentDeps,
+        tool_call: _messages.ToolCallPart | None,
+        conv_messages: list[_messages.ModelMessage],
     ) -> ResultData:
         for validator in self._result_validators:
-            result_data = await validator.validate(…
+            result_data = await validator.validate(
+                result_data, deps, self._current_result_retry, tool_call, conv_messages
+            )
         return result_data

     def _incr_result_retry(self) -> None:
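
`_process_function_tools` above is where `end_strategy` takes effect: with `'early'`, function tools requested alongside the result tool get the stub `'Tool not executed - a final result was already processed.'` return instead of running, while with `'exhaustive'` they all run in parallel via `asyncio.gather`. A sketch of what that means for a caller; the tool and prompt are illustrative:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', end_strategy='early')

@agent.tool_plain
def log_event(event: str) -> str:
    return f'logged {event}'

# if the model calls `log_event` in the same response as the final result
# tool, `log_event` is never executed under 'early'; a stub ToolReturnPart
# is recorded in the message history instead
```
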
@@ -873,15 +1055,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
             )

-    async def …
+    async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
-        messages: list[_messages.…
+        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(deps)
-            messages.append(_messages.…
+            messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(self, tool_name: str) -> _messages.…
+    def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
         self._incr_result_retry()
         names = list(self._function_tools.keys())
         if self._result_schema:
@@ -890,7 +1072,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             msg = f'Available tools: {", ".join(names)}'
         else:
             msg = 'No tools available.'
-        return _messages.…
+        return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')

     def _get_deps(self, deps: AgentDeps) -> AgentDeps:
         """Get deps for a run.
@@ -934,3 +1116,6 @@ class _MarkFinalResult(Generic[ResultData]):
     """

     data: ResultData
+    """The final result data."""
+    tool_name: str | None
+    """Name of the final result tool, None if the result is a string."""