pydantic-ai-slim 0.0.6a3__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_pydantic.py +10 -10
- pydantic_ai/_result.py +4 -4
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/{_retriever.py → _tool.py} +13 -15
- pydantic_ai/_utils.py +9 -5
- pydantic_ai/agent.py +129 -128
- pydantic_ai/dependencies.py +20 -20
- pydantic_ai/exceptions.py +1 -1
- pydantic_ai/messages.py +12 -12
- pydantic_ai/models/__init__.py +3 -3
- pydantic_ai/models/function.py +10 -14
- pydantic_ai/models/gemini.py +8 -30
- pydantic_ai/models/groq.py +2 -2
- pydantic_ai/models/openai.py +2 -2
- pydantic_ai/models/test.py +30 -28
- pydantic_ai/models/vertexai.py +2 -59
- {pydantic_ai_slim-0.0.6a3.dist-info → pydantic_ai_slim-0.0.7.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.7.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a3.dist-info/RECORD +0 -23
- {pydantic_ai_slim-0.0.6a3.dist-info → pydantic_ai_slim-0.0.7.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -11,15 +11,15 @@ from typing_extensions import assert_never
|
|
|
11
11
|
|
|
12
12
|
from . import (
|
|
13
13
|
_result,
|
|
14
|
-
_retriever as _r,
|
|
15
14
|
_system_prompt,
|
|
15
|
+
_tool as _r,
|
|
16
16
|
_utils,
|
|
17
17
|
exceptions,
|
|
18
18
|
messages as _messages,
|
|
19
19
|
models,
|
|
20
20
|
result,
|
|
21
21
|
)
|
|
22
|
-
from .dependencies import AgentDeps,
|
|
22
|
+
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
23
23
|
from .result import ResultData
|
|
24
24
|
|
|
25
25
|
__all__ = ('Agent',)
|
|
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
58
58
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
59
59
|
_allow_text_result: bool = field(repr=False)
|
|
60
60
|
_system_prompts: tuple[str, ...] = field(repr=False)
|
|
61
|
-
|
|
61
|
+
_function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
|
|
62
62
|
_default_retries: int = field(repr=False)
|
|
63
63
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
64
64
|
_deps_type: type[AgentDeps] = field(repr=False)
|
|
@@ -75,8 +75,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
75
75
|
def __init__(
|
|
76
76
|
self,
|
|
77
77
|
model: models.Model | models.KnownModelName | None = None,
|
|
78
|
-
result_type: type[ResultData] = str,
|
|
79
78
|
*,
|
|
79
|
+
result_type: type[ResultData] = str,
|
|
80
80
|
system_prompt: str | Sequence[str] = (),
|
|
81
81
|
deps_type: type[AgentDeps] = NoneType,
|
|
82
82
|
retries: int = 1,
|
|
@@ -105,7 +105,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
105
105
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
106
106
|
which checks for the necessary environment variables. Set this to `false`
|
|
107
107
|
to defer the evaluation until the first run. Useful if you want to
|
|
108
|
-
[override the model][pydantic_ai.Agent.
|
|
108
|
+
[override the model][pydantic_ai.Agent.override] for testing.
|
|
109
109
|
"""
|
|
110
110
|
if model is None or defer_model_check:
|
|
111
111
|
self.model = model
|
|
@@ -119,7 +119,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
119
119
|
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
|
|
120
120
|
|
|
121
121
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
122
|
-
self.
|
|
122
|
+
self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
|
|
123
123
|
self._deps_type = deps_type
|
|
124
124
|
self._default_retries = retries
|
|
125
125
|
self._system_prompt_functions = []
|
|
@@ -150,14 +150,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
150
150
|
|
|
151
151
|
deps = self._get_deps(deps)
|
|
152
152
|
|
|
153
|
-
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
154
|
-
self.last_run_messages = messages
|
|
155
|
-
|
|
156
|
-
for retriever in self._retrievers.values():
|
|
157
|
-
retriever.reset()
|
|
158
|
-
|
|
159
|
-
cost = result.Cost()
|
|
160
|
-
|
|
161
153
|
with _logfire.span(
|
|
162
154
|
'agent run {prompt=}',
|
|
163
155
|
prompt=user_prompt,
|
|
@@ -165,6 +157,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
165
157
|
custom_model=custom_model,
|
|
166
158
|
model_name=model_used.name(),
|
|
167
159
|
) as run_span:
|
|
160
|
+
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
161
|
+
self.last_run_messages = messages
|
|
162
|
+
|
|
163
|
+
for tool in self._function_tools.values():
|
|
164
|
+
tool.reset()
|
|
165
|
+
|
|
166
|
+
cost = result.Cost()
|
|
167
|
+
|
|
168
168
|
run_step = 0
|
|
169
169
|
while True:
|
|
170
170
|
run_step += 1
|
|
@@ -243,14 +243,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
243
243
|
|
|
244
244
|
deps = self._get_deps(deps)
|
|
245
245
|
|
|
246
|
-
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
247
|
-
self.last_run_messages = messages
|
|
248
|
-
|
|
249
|
-
for retriever in self._retrievers.values():
|
|
250
|
-
retriever.reset()
|
|
251
|
-
|
|
252
|
-
cost = result.Cost()
|
|
253
|
-
|
|
254
246
|
with _logfire.span(
|
|
255
247
|
'agent run stream {prompt=}',
|
|
256
248
|
prompt=user_prompt,
|
|
@@ -258,6 +250,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
258
250
|
custom_model=custom_model,
|
|
259
251
|
model_name=model_used.name(),
|
|
260
252
|
) as run_span:
|
|
253
|
+
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
254
|
+
self.last_run_messages = messages
|
|
255
|
+
|
|
256
|
+
for tool in self._function_tools.values():
|
|
257
|
+
tool.reset()
|
|
258
|
+
|
|
259
|
+
cost = result.Cost()
|
|
260
|
+
|
|
261
261
|
run_step = 0
|
|
262
262
|
while True:
|
|
263
263
|
run_step += 1
|
|
@@ -296,42 +296,51 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
296
296
|
cost += model_response.cost()
|
|
297
297
|
|
|
298
298
|
@contextmanager
|
|
299
|
-
def
|
|
300
|
-
|
|
299
|
+
def override(
|
|
300
|
+
self,
|
|
301
|
+
*,
|
|
302
|
+
deps: AgentDeps | _utils.Unset = _utils.UNSET,
|
|
303
|
+
model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
|
|
304
|
+
) -> Iterator[None]:
|
|
305
|
+
"""Context manager to temporarily override agent dependencies and model.
|
|
306
|
+
|
|
307
|
+
This is particularly useful when testing.
|
|
301
308
|
|
|
302
309
|
Args:
|
|
303
|
-
|
|
310
|
+
deps: The dependencies to use instead of the dependencies passed to the agent run.
|
|
311
|
+
model: The model to use instead of the model passed to the agent run.
|
|
304
312
|
"""
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
self._override_deps = override_deps_before
|
|
313
|
+
if _utils.is_set(deps):
|
|
314
|
+
override_deps_before = self._override_deps
|
|
315
|
+
self._override_deps = _utils.Some(deps)
|
|
316
|
+
else:
|
|
317
|
+
override_deps_before = _utils.UNSET
|
|
311
318
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
319
|
+
# noinspection PyTypeChecker
|
|
320
|
+
if _utils.is_set(model):
|
|
321
|
+
override_model_before = self._override_model
|
|
322
|
+
# noinspection PyTypeChecker
|
|
323
|
+
self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
|
|
324
|
+
else:
|
|
325
|
+
override_model_before = _utils.UNSET
|
|
315
326
|
|
|
316
|
-
Args:
|
|
317
|
-
overriding_model: The model to use instead of the model passed to the agent run.
|
|
318
|
-
"""
|
|
319
|
-
override_model_before = self._override_model
|
|
320
|
-
self._override_model = _utils.Some(models.infer_model(overriding_model))
|
|
321
327
|
try:
|
|
322
328
|
yield
|
|
323
329
|
finally:
|
|
324
|
-
|
|
330
|
+
if _utils.is_set(override_deps_before):
|
|
331
|
+
self._override_deps = override_deps_before
|
|
332
|
+
if _utils.is_set(override_model_before):
|
|
333
|
+
self._override_model = override_model_before
|
|
325
334
|
|
|
326
335
|
@overload
|
|
327
336
|
def system_prompt(
|
|
328
|
-
self, func: Callable[[
|
|
329
|
-
) -> Callable[[
|
|
337
|
+
self, func: Callable[[RunContext[AgentDeps]], str], /
|
|
338
|
+
) -> Callable[[RunContext[AgentDeps]], str]: ...
|
|
330
339
|
|
|
331
340
|
@overload
|
|
332
341
|
def system_prompt(
|
|
333
|
-
self, func: Callable[[
|
|
334
|
-
) -> Callable[[
|
|
342
|
+
self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
|
|
343
|
+
) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
|
|
335
344
|
|
|
336
345
|
@overload
|
|
337
346
|
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
@@ -344,7 +353,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
344
353
|
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
|
|
345
354
|
"""Decorator to register a system prompt function.
|
|
346
355
|
|
|
347
|
-
Optionally takes [`
|
|
356
|
+
Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
|
|
348
357
|
Can decorate a sync or async functions.
|
|
349
358
|
|
|
350
359
|
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
|
|
@@ -352,7 +361,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
352
361
|
|
|
353
362
|
Example:
|
|
354
363
|
```py
|
|
355
|
-
from pydantic_ai import Agent,
|
|
364
|
+
from pydantic_ai import Agent, RunContext
|
|
356
365
|
|
|
357
366
|
agent = Agent('test', deps_type=str)
|
|
358
367
|
|
|
@@ -361,12 +370,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
361
370
|
return 'foobar'
|
|
362
371
|
|
|
363
372
|
@agent.system_prompt
|
|
364
|
-
async def async_system_prompt(ctx:
|
|
373
|
+
async def async_system_prompt(ctx: RunContext[str]) -> str:
|
|
365
374
|
return f'{ctx.deps} is the best'
|
|
366
375
|
|
|
367
376
|
result = agent.run_sync('foobar', deps='spam')
|
|
368
377
|
print(result.data)
|
|
369
|
-
#> success (no
|
|
378
|
+
#> success (no tool calls)
|
|
370
379
|
```
|
|
371
380
|
"""
|
|
372
381
|
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
|
|
@@ -374,13 +383,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
374
383
|
|
|
375
384
|
@overload
|
|
376
385
|
def result_validator(
|
|
377
|
-
self, func: Callable[[
|
|
378
|
-
) -> Callable[[
|
|
386
|
+
self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
|
|
387
|
+
) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
|
|
379
388
|
|
|
380
389
|
@overload
|
|
381
390
|
def result_validator(
|
|
382
|
-
self, func: Callable[[
|
|
383
|
-
) -> Callable[[
|
|
391
|
+
self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
|
|
392
|
+
) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
|
|
384
393
|
|
|
385
394
|
@overload
|
|
386
395
|
def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
|
|
@@ -395,7 +404,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
395
404
|
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
|
|
396
405
|
"""Decorator to register a result validator function.
|
|
397
406
|
|
|
398
|
-
Optionally takes [`
|
|
407
|
+
Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
|
|
399
408
|
Can decorate a sync or async functions.
|
|
400
409
|
|
|
401
410
|
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
|
|
@@ -403,7 +412,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
403
412
|
|
|
404
413
|
Example:
|
|
405
414
|
```py
|
|
406
|
-
from pydantic_ai import Agent,
|
|
415
|
+
from pydantic_ai import Agent, ModelRetry, RunContext
|
|
407
416
|
|
|
408
417
|
agent = Agent('test', deps_type=str)
|
|
409
418
|
|
|
@@ -414,61 +423,57 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
414
423
|
return data
|
|
415
424
|
|
|
416
425
|
@agent.result_validator
|
|
417
|
-
async def result_validator_deps(ctx:
|
|
426
|
+
async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
|
|
418
427
|
if ctx.deps in data:
|
|
419
428
|
raise ModelRetry('wrong response')
|
|
420
429
|
return data
|
|
421
430
|
|
|
422
431
|
result = agent.run_sync('foobar', deps='spam')
|
|
423
432
|
print(result.data)
|
|
424
|
-
#> success (no
|
|
433
|
+
#> success (no tool calls)
|
|
425
434
|
```
|
|
426
435
|
"""
|
|
427
436
|
self._result_validators.append(_result.ResultValidator(func))
|
|
428
437
|
return func
|
|
429
438
|
|
|
430
439
|
@overload
|
|
431
|
-
def
|
|
432
|
-
self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
|
|
433
|
-
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
|
|
440
|
+
def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
|
|
434
441
|
|
|
435
442
|
@overload
|
|
436
|
-
def
|
|
443
|
+
def tool(
|
|
437
444
|
self, /, *, retries: int | None = None
|
|
438
|
-
) -> Callable[
|
|
439
|
-
[RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
|
|
440
|
-
]: ...
|
|
445
|
+
) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
|
|
441
446
|
|
|
442
|
-
def
|
|
447
|
+
def tool(
|
|
443
448
|
self,
|
|
444
|
-
func:
|
|
449
|
+
func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
|
|
445
450
|
/,
|
|
446
451
|
*,
|
|
447
452
|
retries: int | None = None,
|
|
448
453
|
) -> Any:
|
|
449
|
-
"""Decorator to register a
|
|
450
|
-
[`
|
|
454
|
+
"""Decorator to register a tool function which takes
|
|
455
|
+
[`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
|
|
451
456
|
|
|
452
457
|
Can decorate a sync or async functions.
|
|
453
458
|
|
|
454
459
|
The docstring is inspected to extract both the tool description and description of each parameter,
|
|
455
|
-
[learn more](../agents.md#
|
|
460
|
+
[learn more](../agents.md#function-tools-and-schema).
|
|
456
461
|
|
|
457
|
-
We can't add overloads for every possible signature of
|
|
458
|
-
so the signature of functions decorated with `@agent.
|
|
462
|
+
We can't add overloads for every possible signature of tool, since the return type is a recursive union
|
|
463
|
+
so the signature of functions decorated with `@agent.tool` is obscured.
|
|
459
464
|
|
|
460
465
|
Example:
|
|
461
466
|
```py
|
|
462
|
-
from pydantic_ai import Agent,
|
|
467
|
+
from pydantic_ai import Agent, RunContext
|
|
463
468
|
|
|
464
469
|
agent = Agent('test', deps_type=int)
|
|
465
470
|
|
|
466
|
-
@agent.
|
|
467
|
-
def foobar(ctx:
|
|
471
|
+
@agent.tool
|
|
472
|
+
def foobar(ctx: RunContext[int], x: int) -> int:
|
|
468
473
|
return ctx.deps + x
|
|
469
474
|
|
|
470
|
-
@agent.
|
|
471
|
-
async def spam(ctx:
|
|
475
|
+
@agent.tool(retries=2)
|
|
476
|
+
async def spam(ctx: RunContext[str], y: float) -> float:
|
|
472
477
|
return ctx.deps + y
|
|
473
478
|
|
|
474
479
|
result = agent.run_sync('foobar', deps=1)
|
|
@@ -477,58 +482,56 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
477
482
|
```
|
|
478
483
|
|
|
479
484
|
Args:
|
|
480
|
-
func: The
|
|
481
|
-
retries: The number of retries to allow for this
|
|
485
|
+
func: The tool function to register.
|
|
486
|
+
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
482
487
|
which defaults to 1.
|
|
483
488
|
""" # noqa: D205
|
|
484
489
|
if func is None:
|
|
485
490
|
|
|
486
|
-
def
|
|
487
|
-
func_:
|
|
488
|
-
) ->
|
|
491
|
+
def tool_decorator(
|
|
492
|
+
func_: ToolContextFunc[AgentDeps, ToolParams],
|
|
493
|
+
) -> ToolContextFunc[AgentDeps, ToolParams]:
|
|
489
494
|
# noinspection PyTypeChecker
|
|
490
|
-
self.
|
|
495
|
+
self._register_tool(_utils.Either(left=func_), retries)
|
|
491
496
|
return func_
|
|
492
497
|
|
|
493
|
-
return
|
|
498
|
+
return tool_decorator
|
|
494
499
|
else:
|
|
495
500
|
# noinspection PyTypeChecker
|
|
496
|
-
self.
|
|
501
|
+
self._register_tool(_utils.Either(left=func), retries)
|
|
497
502
|
return func
|
|
498
503
|
|
|
499
504
|
@overload
|
|
500
|
-
def
|
|
505
|
+
def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
|
|
501
506
|
|
|
502
507
|
@overload
|
|
503
|
-
def
|
|
508
|
+
def tool_plain(
|
|
504
509
|
self, /, *, retries: int | None = None
|
|
505
|
-
) -> Callable[[
|
|
510
|
+
) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
|
|
506
511
|
|
|
507
|
-
def
|
|
508
|
-
|
|
509
|
-
) -> Any:
|
|
510
|
-
"""Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
|
|
512
|
+
def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
|
|
513
|
+
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
511
514
|
|
|
512
515
|
Can decorate a sync or async functions.
|
|
513
516
|
|
|
514
517
|
The docstring is inspected to extract both the tool description and description of each parameter,
|
|
515
|
-
[learn more](../agents.md#
|
|
518
|
+
[learn more](../agents.md#function-tools-and-schema).
|
|
516
519
|
|
|
517
|
-
We can't add overloads for every possible signature of
|
|
518
|
-
so the signature of functions decorated with `@agent.
|
|
520
|
+
We can't add overloads for every possible signature of tool, since the return type is a recursive union
|
|
521
|
+
so the signature of functions decorated with `@agent.tool` is obscured.
|
|
519
522
|
|
|
520
523
|
Example:
|
|
521
524
|
```py
|
|
522
|
-
from pydantic_ai import Agent,
|
|
525
|
+
from pydantic_ai import Agent, RunContext
|
|
523
526
|
|
|
524
527
|
agent = Agent('test')
|
|
525
528
|
|
|
526
|
-
@agent.
|
|
527
|
-
def foobar(ctx:
|
|
529
|
+
@agent.tool
|
|
530
|
+
def foobar(ctx: RunContext[int]) -> int:
|
|
528
531
|
return 123
|
|
529
532
|
|
|
530
|
-
@agent.
|
|
531
|
-
async def spam(ctx:
|
|
533
|
+
@agent.tool(retries=2)
|
|
534
|
+
async def spam(ctx: RunContext[str]) -> float:
|
|
532
535
|
return 3.14
|
|
533
536
|
|
|
534
537
|
result = agent.run_sync('foobar', deps=1)
|
|
@@ -537,38 +540,36 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
537
540
|
```
|
|
538
541
|
|
|
539
542
|
Args:
|
|
540
|
-
func: The
|
|
541
|
-
retries: The number of retries to allow for this
|
|
543
|
+
func: The tool function to register.
|
|
544
|
+
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
542
545
|
which defaults to 1.
|
|
543
546
|
"""
|
|
544
547
|
if func is None:
|
|
545
548
|
|
|
546
|
-
def
|
|
547
|
-
func_:
|
|
548
|
-
) ->
|
|
549
|
+
def tool_decorator(
|
|
550
|
+
func_: ToolPlainFunc[ToolParams],
|
|
551
|
+
) -> ToolPlainFunc[ToolParams]:
|
|
549
552
|
# noinspection PyTypeChecker
|
|
550
|
-
self.
|
|
553
|
+
self._register_tool(_utils.Either(right=func_), retries)
|
|
551
554
|
return func_
|
|
552
555
|
|
|
553
|
-
return
|
|
556
|
+
return tool_decorator
|
|
554
557
|
else:
|
|
555
|
-
self.
|
|
558
|
+
self._register_tool(_utils.Either(right=func), retries)
|
|
556
559
|
return func
|
|
557
560
|
|
|
558
|
-
def
|
|
559
|
-
|
|
560
|
-
) -> None:
|
|
561
|
-
"""Private utility to register a retriever function."""
|
|
561
|
+
def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
|
|
562
|
+
"""Private utility to register a tool function."""
|
|
562
563
|
retries_ = retries if retries is not None else self._default_retries
|
|
563
|
-
|
|
564
|
+
tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
|
|
564
565
|
|
|
565
|
-
if self._result_schema and
|
|
566
|
-
raise ValueError(f'
|
|
566
|
+
if self._result_schema and tool.name in self._result_schema.tools:
|
|
567
|
+
raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
567
568
|
|
|
568
|
-
if
|
|
569
|
-
raise ValueError(f'
|
|
569
|
+
if tool.name in self._function_tools:
|
|
570
|
+
raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
570
571
|
|
|
571
|
-
self.
|
|
572
|
+
self._function_tools[tool.name] = tool
|
|
572
573
|
|
|
573
574
|
async def _get_agent_model(
|
|
574
575
|
self, model: models.Model | models.KnownModelName | None
|
|
@@ -583,11 +584,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
583
584
|
"""
|
|
584
585
|
model_: models.Model
|
|
585
586
|
if some_model := self._override_model:
|
|
586
|
-
# we don't want `
|
|
587
|
+
# we don't want `override()` to cover up errors from the model not being defined, hence this check
|
|
587
588
|
if model is None and self.model is None:
|
|
588
589
|
raise exceptions.UserError(
|
|
589
590
|
'`model` must be set either when creating the agent or when calling it. '
|
|
590
|
-
'(Even when `
|
|
591
|
+
'(Even when `override(model=...)` is customizing the model that will actually be called)'
|
|
591
592
|
)
|
|
592
593
|
model_ = some_model.value
|
|
593
594
|
custom_model = None
|
|
@@ -601,7 +602,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
601
602
|
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
|
|
602
603
|
|
|
603
604
|
result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
|
|
604
|
-
agent_model = await model_.agent_model(self.
|
|
605
|
+
agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
|
|
605
606
|
return model_, custom_model, agent_model
|
|
606
607
|
|
|
607
608
|
async def _prepare_messages(
|
|
@@ -663,12 +664,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
663
664
|
if not model_response.calls:
|
|
664
665
|
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
665
666
|
|
|
666
|
-
# otherwise we run all
|
|
667
|
+
# otherwise we run all tool functions in parallel
|
|
667
668
|
messages: list[_messages.Message] = []
|
|
668
669
|
tasks: list[asyncio.Task[_messages.Message]] = []
|
|
669
670
|
for call in model_response.calls:
|
|
670
|
-
if
|
|
671
|
-
tasks.append(asyncio.create_task(
|
|
671
|
+
if tool := self._function_tools.get(call.tool_name):
|
|
672
|
+
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
|
|
672
673
|
else:
|
|
673
674
|
messages.append(self._unknown_tool(call.tool_name))
|
|
674
675
|
|
|
@@ -719,7 +720,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
719
720
|
if self._result_schema.find_tool(structured_msg):
|
|
720
721
|
return _MarkFinalResult(model_response)
|
|
721
722
|
|
|
722
|
-
# the model is calling a
|
|
723
|
+
# the model is calling a tool function, consume the response to get the next message
|
|
723
724
|
async for _ in model_response:
|
|
724
725
|
pass
|
|
725
726
|
structured_msg = model_response.get()
|
|
@@ -727,11 +728,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
727
728
|
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
728
729
|
messages: list[_messages.Message] = [structured_msg]
|
|
729
730
|
|
|
730
|
-
# we now run all
|
|
731
|
+
# we now run all tool functions in parallel
|
|
731
732
|
tasks: list[asyncio.Task[_messages.Message]] = []
|
|
732
733
|
for call in structured_msg.calls:
|
|
733
|
-
if
|
|
734
|
-
tasks.append(asyncio.create_task(
|
|
734
|
+
if tool := self._function_tools.get(call.tool_name):
|
|
735
|
+
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
|
|
735
736
|
else:
|
|
736
737
|
messages.append(self._unknown_tool(call.tool_name))
|
|
737
738
|
|
|
@@ -763,7 +764,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
763
764
|
|
|
764
765
|
def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
|
|
765
766
|
self._incr_result_retry()
|
|
766
|
-
names = list(self.
|
|
767
|
+
names = list(self._function_tools.keys())
|
|
767
768
|
if self._result_schema:
|
|
768
769
|
names.extend(self._result_schema.tool_names())
|
|
769
770
|
if names:
|
|
@@ -775,7 +776,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
775
776
|
def _get_deps(self, deps: AgentDeps) -> AgentDeps:
|
|
776
777
|
"""Get deps for a run.
|
|
777
778
|
|
|
778
|
-
If we've overridden deps via `
|
|
779
|
+
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
779
780
|
|
|
780
781
|
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
|
|
781
782
|
"""
|
pydantic_ai/dependencies.py
CHANGED
|
@@ -13,13 +13,13 @@ else:
|
|
|
13
13
|
|
|
14
14
|
__all__ = (
|
|
15
15
|
'AgentDeps',
|
|
16
|
-
'
|
|
16
|
+
'RunContext',
|
|
17
17
|
'ResultValidatorFunc',
|
|
18
18
|
'SystemPromptFunc',
|
|
19
|
-
'
|
|
20
|
-
'
|
|
21
|
-
'
|
|
22
|
-
'
|
|
19
|
+
'ToolReturnValue',
|
|
20
|
+
'ToolContextFunc',
|
|
21
|
+
'ToolPlainFunc',
|
|
22
|
+
'ToolParams',
|
|
23
23
|
'JsonData',
|
|
24
24
|
)
|
|
25
25
|
|
|
@@ -28,7 +28,7 @@ AgentDeps = TypeVar('AgentDeps')
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@dataclass
|
|
31
|
-
class
|
|
31
|
+
class RunContext(Generic[AgentDeps]):
|
|
32
32
|
"""Information about the current call."""
|
|
33
33
|
|
|
34
34
|
deps: AgentDeps
|
|
@@ -39,23 +39,23 @@ class CallContext(Generic[AgentDeps]):
|
|
|
39
39
|
"""Name of the tool being called."""
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
ToolParams = ParamSpec('ToolParams')
|
|
43
43
|
"""Retrieval function param spec."""
|
|
44
44
|
|
|
45
45
|
SystemPromptFunc = Union[
|
|
46
|
-
Callable[[
|
|
47
|
-
Callable[[
|
|
46
|
+
Callable[[RunContext[AgentDeps]], str],
|
|
47
|
+
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
48
48
|
Callable[[], str],
|
|
49
49
|
Callable[[], Awaitable[str]],
|
|
50
50
|
]
|
|
51
|
-
"""A function that may or maybe not take `
|
|
51
|
+
"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
|
|
52
52
|
|
|
53
53
|
Usage `SystemPromptFunc[AgentDeps]`.
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
56
|
ResultValidatorFunc = Union[
|
|
57
|
-
Callable[[
|
|
58
|
-
Callable[[
|
|
57
|
+
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
58
|
+
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
59
59
|
Callable[[ResultData], ResultData],
|
|
60
60
|
Callable[[ResultData], Awaitable[ResultData]],
|
|
61
61
|
]
|
|
@@ -69,15 +69,15 @@ Usage `ResultValidator[AgentDeps, ResultData]`.
|
|
|
69
69
|
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
|
|
70
70
|
"""Type representing any JSON data."""
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
"""Return value of a
|
|
74
|
-
|
|
75
|
-
"""A
|
|
72
|
+
ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
|
|
73
|
+
"""Return value of a tool function."""
|
|
74
|
+
ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
|
|
75
|
+
"""A tool function that takes `RunContext` as the first argument.
|
|
76
76
|
|
|
77
|
-
Usage `
|
|
77
|
+
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
|
|
78
78
|
"""
|
|
79
|
-
|
|
80
|
-
"""A
|
|
79
|
+
ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
|
|
80
|
+
"""A tool function that does not take `RunContext` as the first argument.
|
|
81
81
|
|
|
82
|
-
Usage `
|
|
82
|
+
Usage `ToolPlainFunc[ToolParams]`.
|
|
83
83
|
"""
|
pydantic_ai/exceptions.py
CHANGED
|
@@ -6,7 +6,7 @@ __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class ModelRetry(Exception):
|
|
9
|
-
"""Exception raised when a
|
|
9
|
+
"""Exception raised when a tool function should be retried.
|
|
10
10
|
|
|
11
11
|
The agent will return the message to the model and ask it to try calling the function/tool again.
|
|
12
12
|
"""
|