pydantic-ai-slim 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_griffe.py +13 -1
- pydantic_ai/_system_prompt.py +1 -0
- pydantic_ai/agent.py +272 -68
- pydantic_ai/format_as_xml.py +115 -0
- pydantic_ai/messages.py +6 -0
- pydantic_ai/models/__init__.py +28 -10
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/gemini.py +22 -24
- pydantic_ai/models/vertexai.py +2 -2
- pydantic_ai/result.py +92 -69
- pydantic_ai/settings.py +1 -61
- pydantic_ai/tools.py +7 -7
- pydantic_ai/usage.py +114 -0
- {pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.18.dist-info}/METADATA +1 -1
- pydantic_ai_slim-0.0.18.dist-info/RECORD +28 -0
- pydantic_ai_slim-0.0.16.dist-info/RECORD +0 -26
- {pydantic_ai_slim-0.0.16.dist-info → pydantic_ai_slim-0.0.18.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
@@ -4,13 +4,13 @@ import asyncio
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never, deprecated
+from typing_extensions import TypeVar, assert_never, deprecated

 from . import (
     _result,
@@ -20,9 +20,10 @@ from . import (
     messages as _messages,
     models,
     result,
+    usage as _usage,
 )
 from .result import ResultData
-from .settings import ModelSettings,
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -56,6 +57,8 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'early'`: Stop processing other tool calls once a final result is found
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
+RunResultData = TypeVar('RunResultData')
+"""Type variable for the result data of a run where `result_type` was customized on the run call."""


 @final
@@ -98,14 +101,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
     be merged with this value, with the runtime argument taking priority.
     """
-
+    _result_tool_name: str = dataclasses.field(repr=False)
+    _result_tool_description: str | None = dataclasses.field(repr=False)
     _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
-    _allow_text_result: bool = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
     _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
+        repr=False
+    )
     _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
@@ -165,11 +171,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
+        self._result_tool_name = result_tool_name
+        self._result_tool_description = result_tool_description
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
-        # if the result tool is None, or its schema allows `str`, we allow plain text results
-        self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._function_tools = {}
@@ -181,21 +187,53 @@ class Agent(Generic[AgentDeps, ResultData]):
                 self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
+        self._system_prompt_dynamic_functions = {}
         self._max_result_retries = result_retries if result_retries is not None else retries
         self._result_validators = []

+    @overload
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     async def run(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage:
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[
+    ) -> result.RunResult[RunResultData]: ...
+
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        result_type: type[RunResultData] | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt in async mode.

         Example:
@@ -210,6 +248,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
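The per-run `result_type` shown above lets a single agent return differently structured data on individual calls without rebuilding the agent. A minimal sketch of the intended usage; the `CityLocation` model and prompts are illustrative, not taken from the package:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
    city: str
    country: str


agent = Agent('openai:gpt-4o')  # result_type defaults to str

# Per the docstring above, this is only allowed when the agent has no result
# validators; otherwise _prepare_result_schema raises a UserError.
result = agent.run_sync('Where were the 2012 Olympics held?', result_type=CityLocation)
print(result.data)
#> city='London' country='United Kingdom'
```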
@@ -228,6 +268,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -236,7 +277,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

@@ -244,14 +285,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0

             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

             while True:
                 usage_limits.check_before_request(run_context.usage)

                 run_context.run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
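`Usage` and `UsageLimits` now come from the new `pydantic_ai.usage` module (`usage.py` in the file list above) rather than `settings.py`. A minimal sketch of passing limits to a run, assuming the `request_limit` and `total_tokens_limit` field names used elsewhere in the library:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('openai:gpt-4o')

# check_before_request() runs before each model call and check_tokens() after each
# response, so the run stops with a usage error once either limit would be exceeded.
result = agent.run_sync(
    'What is the capital of France?',
    usage_limits=UsageLimits(request_limit=3, total_tokens_limit=2_000),
)
print(result.data)
```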
@@ -263,7 +304,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_tokens(run_context.usage)

                 with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(
+                    final_result, tool_responses = await self._handle_model_response(
+                        model_response, run_context, result_schema
+                    )

                     if tool_responses:
                         # Add parts to the conversation as a new message
@@ -272,29 +315,62 @@ class Agent(Generic[AgentDeps, ResultData]):
                     # Check if we got a final result
                     if final_result is not None:
                         result_data = final_result.data
+                        result_tool_name = final_result.tool_name
                         run_span.set_attribute('all_messages', messages)
                         run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(
+                        return result.RunResult(
+                            messages, new_message_index, result_data, result_tool_name, run_context.usage
+                        )
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
                         tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                         handle_span.message = f'handle model response -> {tool_responses_str}'

+    @overload
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     def run_sync(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage:
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[
+    ) -> result.RunResult[RunResultData]: ...
+
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt synchronously.

         This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -313,6 +389,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -330,6 +408,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         return asyncio.get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
+                result_type=result_type,
                 message_history=message_history,
                 model=model,
                 deps=deps,
@@ -340,19 +419,50 @@ class Agent(Generic[AgentDeps, ResultData]):
             )
         )

+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
+
+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData],
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
+
     @asynccontextmanager
     async def run_stream(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage:
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps,
+    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

         Example:
@@ -368,6 +478,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -388,6 +500,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -396,7 +509,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

@@ -404,14 +517,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0

             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

             while True:
                 run_context.run_step += 1
                 usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -422,7 +535,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, run_context, result_schema
+                            )

                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -442,7 +557,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                 ]
                                 parts = await self._process_function_tools(
-                                    tool_calls, result_tool_name, run_context
+                                    tool_calls, result_tool_name, run_context, result_schema
                                 )
                                 if parts:
                                     messages.append(_messages.ModelRequest(parts))
@@ -453,7 +568,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     new_message_index,
                                     usage_limits,
                                     result_stream,
-
+                                    result_schema,
                                     run_context,
                                     self._result_validators,
                                     result_tool_name,
@@ -531,17 +646,37 @@ class Agent(Generic[AgentDeps, ResultData]):
     @overload
     def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

+    @overload
+    def system_prompt(
+        self, /, *, dynamic: bool = False
+    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
+
     def system_prompt(
-        self,
-
+        self,
+        func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
+        /,
+        *,
+        dynamic: bool = False,
+    ) -> (
+        Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
+        | _system_prompt.SystemPromptFunc[AgentDeps]
+    ):
         """Decorator to register a system prompt function.

         Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

+        The decorator can be used either bare (`agent.system_prompt`) or as a function call
+        (`agent.system_prompt(...)`), see the examples below.
+
         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

+        Args:
+            func: The function to decorate
+            dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
+                see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
+
         Example:
         ```python
         from pydantic_ai import Agent, RunContext
@@ -552,17 +687,27 @@ class Agent(Generic[AgentDeps, ResultData]):
         def simple_system_prompt() -> str:
             return 'foobar'

-        @agent.system_prompt
+        @agent.system_prompt(dynamic=True)
         async def async_system_prompt(ctx: RunContext[str]) -> str:
             return f'{ctx.deps} is the best'
-
-        result = agent.run_sync('foobar', deps='spam')
-        print(result.data)
-        #> success (no tool calls)
         ```
         """
-
-
+        if func is None:
+
+            def decorator(
+                func_: _system_prompt.SystemPromptFunc[AgentDeps],
+            ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+                runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
+                self._system_prompt_functions.append(runner)
+                if dynamic:
+                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
+                return func_
+
+            return decorator
+        else:
+            assert not dynamic, "dynamic can't be True in this case"
+            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+            return func

     @overload
     def result_validator(
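With `dynamic=True` the runner is registered under the function's `__qualname__`, so `_reevaluate_dynamic_prompts` can re-run it when `message_history` is passed back in, instead of replaying the stored prompt verbatim. A minimal sketch; the deps values and prompts are illustrative:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:gpt-4o', deps_type=str)


@agent.system_prompt(dynamic=True)
def user_name(ctx: RunContext[str]) -> str:
    return f"The user's name is {ctx.deps}."


result1 = agent.run_sync('Hi!', deps='Alice')

# Because the SystemPromptPart carries a dynamic_ref, it is reevaluated with the new
# deps ('Bob') rather than reusing the prompt that was generated for 'Alice'.
result2 = agent.run_sync('Who am I?', deps='Bob', message_history=result1.all_messages())
```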
@@ -814,7 +959,9 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_

-    async def _prepare_model(
+    async def _prepare_model(
+        self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

@@ -827,10 +974,39 @@ class Agent(Generic[AgentDeps, ResultData]):

         return await run_context.model.agent_model(
             function_tools=function_tools,
-            allow_text_result=self._allow_text_result,
-            result_tools=
+            allow_text_result=self._allow_text_result(result_schema),
+            result_tools=result_schema.tool_defs() if result_schema is not None else [],
         )

+    async def _reevaluate_dynamic_prompts(
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
+    ) -> None:
+        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
+        # Only proceed if there's at least one dynamic runner.
+        if self._system_prompt_dynamic_functions:
+            for msg in messages:
+                if isinstance(msg, _messages.ModelRequest):
+                    for i, part in enumerate(msg.parts):
+                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
+                            # Look up the runner by its ref
+                            if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
+                                updated_part_content = await runner.run(run_context)
+                                msg.parts[i] = _messages.SystemPromptPart(
+                                    updated_part_content, dynamic_ref=part.dynamic_ref
+                                )
+
+    def _prepare_result_schema(
+        self, result_type: type[RunResultData] | None
+    ) -> _result.ResultSchema[RunResultData] | None:
+        if result_type is not None:
+            if self._result_validators:
+                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+            return _result.ResultSchema[result_type].build(
+                result_type, self._result_tool_name, self._result_tool_description
+            )
+        else:
+            return self._result_schema  # pyright: ignore[reportReturnType]
+
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
@@ -846,8 +1022,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                 ctx_messages.used = True

         if message_history:
-            #
+            # Shallow copy messages
             messages.extend(message_history)
+            # Reevaluate any dynamic system prompt parts
+            await self._reevaluate_dynamic_prompts(messages, run_context)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
@@ -857,8 +1035,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         return messages

     async def _handle_model_response(
-        self,
-
+        self,
+        model_response: _messages.ModelResponse,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
@@ -879,19 +1060,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         # This accounts for cases like anthropic returns that might contain a text response
         # and a tool call response, where the text response just indicates the tool call will happen.
         if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
+            return await self._handle_structured_response(tool_calls, run_context, result_schema)
         elif texts:
             text = '\n\n'.join(texts)
-            return await self._handle_text_response(text, run_context)
+            return await self._handle_text_response(text, run_context, result_schema)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')

     async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[
+        self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
-        if self._allow_text_result:
-            result_data_input = cast(
+        if self._allow_text_result(result_schema):
+            result_data_input = cast(RunResultData, text)
             try:
                 result_data = await self._validate_result(result_data_input, run_context, None)
             except _result.ToolRetryError as e:
@@ -907,16 +1088,19 @@ class Agent(Generic[AgentDeps, ResultData]):
             return None, [response]

     async def _handle_structured_response(
-        self,
-
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'

         # first look for the result tool call
-        final_result: _MarkFinalResult[
+        final_result: _MarkFinalResult[RunResultData] | None = None

         parts: list[_messages.ModelRequestPart] = []
-        if result_schema :=
+        if result_schema := result_schema:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
@@ -929,7 +1113,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                     final_result = _MarkFinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, run_context, result_schema
+        )

         return final_result, parts

@@ -938,6 +1124,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.

@@ -971,7 +1158,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                    )
                 else:
                     tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-            elif
+            elif result_schema is not None and call.tool_name in result_schema.tools:
                 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
                 # validation, we don't add another part here
                 if result_tool_name is not None:
@@ -983,7 +1170,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                        )
                    )
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context))
+                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         # Run all tool tasks in parallel
         if tasks:
@@ -996,6 +1183,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         model_response: models.EitherStreamedResponse,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> (
         _MarkFinalResult[models.EitherStreamedResponse]
         | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1008,7 +1196,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
-            if self._allow_text_result:
+            if self._allow_text_result(result_schema):
                 return _MarkFinalResult(model_response, None)
             else:
                 self._incr_result_retry(run_context)
@@ -1022,7 +1210,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 text = ''.join(model_response.get(final=True))
                 return _messages.ModelResponse([_messages.TextPart(text)]), [response]
         elif isinstance(model_response, models.StreamStructuredResponse):
-            if
+            if result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
@@ -1033,7 +1221,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                        break
                    structured_msg = model_response.get()

-                if match :=
+                if match := result_schema.find_tool(structured_msg.parts):
                     call, _ = match
                     return _MarkFinalResult(model_response, call.tool_name)

@@ -1053,7 +1241,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 if tool := self._function_tools.get(call.tool_name):
                     tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
                 else:
-                    parts.append(self._unknown_tool(call.tool_name, run_context))
+                    parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
             task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1064,13 +1252,17 @@ class Agent(Generic[AgentDeps, ResultData]):

     async def _validate_result(
         self,
-        result_data:
+        result_data: RunResultData,
         run_context: RunContext[AgentDeps],
         tool_call: _messages.ToolCallPart | None,
-    ) ->
-
-
-
+    ) -> RunResultData:
+        if self._result_validators:
+            agent_result_data = cast(ResultData, result_data)
+            for validator in self._result_validators:
+                agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
+            return cast(RunResultData, agent_result_data)
+        else:
+            return result_data

     def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
         run_context.retry += 1
@@ -1084,14 +1276,22 @@ class Agent(Generic[AgentDeps, ResultData]):
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(run_context)
-
+            if sys_prompt_runner.dynamic:
+                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
+            else:
+                messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(
+    def _unknown_tool(
+        self,
+        tool_name: str,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> _messages.RetryPromptPart:
         self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
-        if
-            names.extend(
+        if result_schema:
+            names.extend(result_schema.tool_names())
         if names:
             msg = f'Available tools: {", ".join(names)}'
         else:
@@ -1129,6 +1329,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                         self.name = name
                         return

+    @staticmethod
+    def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
+        return result_schema is None or result_schema.allow_text_result
+
     @property
     @deprecated(
         'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None