pydantic-ai-slim 0.0.6a4__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/__init__.py +2 -2
- pydantic_ai/_pydantic.py +10 -10
- pydantic_ai/_result.py +4 -4
- pydantic_ai/_system_prompt.py +2 -2
- pydantic_ai/{_retriever.py → _tool.py} +13 -15
- pydantic_ai/_utils.py +9 -5
- pydantic_ai/agent.py +130 -128
- pydantic_ai/dependencies.py +20 -20
- pydantic_ai/exceptions.py +1 -1
- pydantic_ai/messages.py +16 -12
- pydantic_ai/models/__init__.py +3 -3
- pydantic_ai/models/function.py +10 -14
- pydantic_ai/models/gemini.py +12 -30
- pydantic_ai/models/groq.py +2 -2
- pydantic_ai/models/openai.py +2 -2
- pydantic_ai/models/test.py +34 -36
- pydantic_ai/models/vertexai.py +4 -59
- pydantic_ai/result.py +9 -7
- {pydantic_ai_slim-0.0.6a4.dist-info → pydantic_ai_slim-0.0.8.dist-info}/METADATA +7 -3
- pydantic_ai_slim-0.0.8.dist-info/RECORD +23 -0
- pydantic_ai_slim-0.0.6a4.dist-info/RECORD +0 -23
- {pydantic_ai_slim-0.0.6a4.dist-info → pydantic_ai_slim-0.0.8.dist-info}/WHEEL +0 -0
pydantic_ai/agent.py
CHANGED
|
@@ -11,15 +11,15 @@ from typing_extensions import assert_never
|
|
|
11
11
|
|
|
12
12
|
from . import (
|
|
13
13
|
_result,
|
|
14
|
-
_retriever as _r,
|
|
15
14
|
_system_prompt,
|
|
15
|
+
_tool as _r,
|
|
16
16
|
_utils,
|
|
17
17
|
exceptions,
|
|
18
18
|
messages as _messages,
|
|
19
19
|
models,
|
|
20
20
|
result,
|
|
21
21
|
)
|
|
22
|
-
from .dependencies import AgentDeps,
|
|
22
|
+
from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
|
|
23
23
|
from .result import ResultData
|
|
24
24
|
|
|
25
25
|
__all__ = ('Agent',)
|
|
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
58
58
|
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
|
|
59
59
|
_allow_text_result: bool = field(repr=False)
|
|
60
60
|
_system_prompts: tuple[str, ...] = field(repr=False)
|
|
61
|
-
|
|
61
|
+
_function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
|
|
62
62
|
_default_retries: int = field(repr=False)
|
|
63
63
|
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
|
|
64
64
|
_deps_type: type[AgentDeps] = field(repr=False)
|
|
@@ -75,8 +75,8 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
75
75
|
def __init__(
|
|
76
76
|
self,
|
|
77
77
|
model: models.Model | models.KnownModelName | None = None,
|
|
78
|
-
result_type: type[ResultData] = str,
|
|
79
78
|
*,
|
|
79
|
+
result_type: type[ResultData] = str,
|
|
80
80
|
system_prompt: str | Sequence[str] = (),
|
|
81
81
|
deps_type: type[AgentDeps] = NoneType,
|
|
82
82
|
retries: int = 1,
|
|
@@ -105,7 +105,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
105
105
|
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
|
|
106
106
|
which checks for the necessary environment variables. Set this to `false`
|
|
107
107
|
to defer the evaluation until the first run. Useful if you want to
|
|
108
|
-
[override the model][pydantic_ai.Agent.
|
|
108
|
+
[override the model][pydantic_ai.Agent.override] for testing.
|
|
109
109
|
"""
|
|
110
110
|
if model is None or defer_model_check:
|
|
111
111
|
self.model = model
|
|
@@ -119,7 +119,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
119
119
|
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
|
|
120
120
|
|
|
121
121
|
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
|
|
122
|
-
self.
|
|
122
|
+
self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
|
|
123
123
|
self._deps_type = deps_type
|
|
124
124
|
self._default_retries = retries
|
|
125
125
|
self._system_prompt_functions = []
|
|
@@ -150,14 +150,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
150
150
|
|
|
151
151
|
deps = self._get_deps(deps)
|
|
152
152
|
|
|
153
|
-
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
154
|
-
self.last_run_messages = messages
|
|
155
|
-
|
|
156
|
-
for retriever in self._retrievers.values():
|
|
157
|
-
retriever.reset()
|
|
158
|
-
|
|
159
|
-
cost = result.Cost()
|
|
160
|
-
|
|
161
153
|
with _logfire.span(
|
|
162
154
|
'agent run {prompt=}',
|
|
163
155
|
prompt=user_prompt,
|
|
@@ -165,6 +157,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
165
157
|
custom_model=custom_model,
|
|
166
158
|
model_name=model_used.name(),
|
|
167
159
|
) as run_span:
|
|
160
|
+
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
161
|
+
self.last_run_messages = messages
|
|
162
|
+
|
|
163
|
+
for tool in self._function_tools.values():
|
|
164
|
+
tool.reset()
|
|
165
|
+
|
|
166
|
+
cost = result.Cost()
|
|
167
|
+
|
|
168
168
|
run_step = 0
|
|
169
169
|
while True:
|
|
170
170
|
run_step += 1
|
|
@@ -243,14 +243,6 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
243
243
|
|
|
244
244
|
deps = self._get_deps(deps)
|
|
245
245
|
|
|
246
|
-
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
247
|
-
self.last_run_messages = messages
|
|
248
|
-
|
|
249
|
-
for retriever in self._retrievers.values():
|
|
250
|
-
retriever.reset()
|
|
251
|
-
|
|
252
|
-
cost = result.Cost()
|
|
253
|
-
|
|
254
246
|
with _logfire.span(
|
|
255
247
|
'agent run stream {prompt=}',
|
|
256
248
|
prompt=user_prompt,
|
|
@@ -258,6 +250,14 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
258
250
|
custom_model=custom_model,
|
|
259
251
|
model_name=model_used.name(),
|
|
260
252
|
) as run_span:
|
|
253
|
+
new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
|
|
254
|
+
self.last_run_messages = messages
|
|
255
|
+
|
|
256
|
+
for tool in self._function_tools.values():
|
|
257
|
+
tool.reset()
|
|
258
|
+
|
|
259
|
+
cost = result.Cost()
|
|
260
|
+
|
|
261
261
|
run_step = 0
|
|
262
262
|
while True:
|
|
263
263
|
run_step += 1
|
|
@@ -284,6 +284,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
284
284
|
self._result_schema,
|
|
285
285
|
deps,
|
|
286
286
|
self._result_validators,
|
|
287
|
+
lambda m: run_span.set_attribute('all_messages', messages),
|
|
287
288
|
)
|
|
288
289
|
return
|
|
289
290
|
else:
|
|
@@ -296,42 +297,51 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
296
297
|
cost += model_response.cost()
|
|
297
298
|
|
|
298
299
|
@contextmanager
|
|
299
|
-
def
|
|
300
|
-
|
|
300
|
+
def override(
|
|
301
|
+
self,
|
|
302
|
+
*,
|
|
303
|
+
deps: AgentDeps | _utils.Unset = _utils.UNSET,
|
|
304
|
+
model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
|
|
305
|
+
) -> Iterator[None]:
|
|
306
|
+
"""Context manager to temporarily override agent dependencies and model.
|
|
307
|
+
|
|
308
|
+
This is particularly useful when testing.
|
|
301
309
|
|
|
302
310
|
Args:
|
|
303
|
-
|
|
311
|
+
deps: The dependencies to use instead of the dependencies passed to the agent run.
|
|
312
|
+
model: The model to use instead of the model passed to the agent run.
|
|
304
313
|
"""
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
self._override_deps = override_deps_before
|
|
314
|
+
if _utils.is_set(deps):
|
|
315
|
+
override_deps_before = self._override_deps
|
|
316
|
+
self._override_deps = _utils.Some(deps)
|
|
317
|
+
else:
|
|
318
|
+
override_deps_before = _utils.UNSET
|
|
311
319
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
320
|
+
# noinspection PyTypeChecker
|
|
321
|
+
if _utils.is_set(model):
|
|
322
|
+
override_model_before = self._override_model
|
|
323
|
+
# noinspection PyTypeChecker
|
|
324
|
+
self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
|
|
325
|
+
else:
|
|
326
|
+
override_model_before = _utils.UNSET
|
|
315
327
|
|
|
316
|
-
Args:
|
|
317
|
-
overriding_model: The model to use instead of the model passed to the agent run.
|
|
318
|
-
"""
|
|
319
|
-
override_model_before = self._override_model
|
|
320
|
-
self._override_model = _utils.Some(models.infer_model(overriding_model))
|
|
321
328
|
try:
|
|
322
329
|
yield
|
|
323
330
|
finally:
|
|
324
|
-
|
|
331
|
+
if _utils.is_set(override_deps_before):
|
|
332
|
+
self._override_deps = override_deps_before
|
|
333
|
+
if _utils.is_set(override_model_before):
|
|
334
|
+
self._override_model = override_model_before
|
|
325
335
|
|
|
326
336
|
@overload
|
|
327
337
|
def system_prompt(
|
|
328
|
-
self, func: Callable[[
|
|
329
|
-
) -> Callable[[
|
|
338
|
+
self, func: Callable[[RunContext[AgentDeps]], str], /
|
|
339
|
+
) -> Callable[[RunContext[AgentDeps]], str]: ...
|
|
330
340
|
|
|
331
341
|
@overload
|
|
332
342
|
def system_prompt(
|
|
333
|
-
self, func: Callable[[
|
|
334
|
-
) -> Callable[[
|
|
343
|
+
self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
|
|
344
|
+
) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
|
|
335
345
|
|
|
336
346
|
@overload
|
|
337
347
|
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
|
|
@@ -344,7 +354,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
344
354
|
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
|
|
345
355
|
"""Decorator to register a system prompt function.
|
|
346
356
|
|
|
347
|
-
Optionally takes [`
|
|
357
|
+
Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
|
|
348
358
|
Can decorate a sync or async functions.
|
|
349
359
|
|
|
350
360
|
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
|
|
@@ -352,7 +362,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
352
362
|
|
|
353
363
|
Example:
|
|
354
364
|
```py
|
|
355
|
-
from pydantic_ai import Agent,
|
|
365
|
+
from pydantic_ai import Agent, RunContext
|
|
356
366
|
|
|
357
367
|
agent = Agent('test', deps_type=str)
|
|
358
368
|
|
|
@@ -361,12 +371,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
361
371
|
return 'foobar'
|
|
362
372
|
|
|
363
373
|
@agent.system_prompt
|
|
364
|
-
async def async_system_prompt(ctx:
|
|
374
|
+
async def async_system_prompt(ctx: RunContext[str]) -> str:
|
|
365
375
|
return f'{ctx.deps} is the best'
|
|
366
376
|
|
|
367
377
|
result = agent.run_sync('foobar', deps='spam')
|
|
368
378
|
print(result.data)
|
|
369
|
-
#> success (no
|
|
379
|
+
#> success (no tool calls)
|
|
370
380
|
```
|
|
371
381
|
"""
|
|
372
382
|
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
|
|
@@ -374,13 +384,13 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
374
384
|
|
|
375
385
|
@overload
|
|
376
386
|
def result_validator(
|
|
377
|
-
self, func: Callable[[
|
|
378
|
-
) -> Callable[[
|
|
387
|
+
self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
|
|
388
|
+
) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
|
|
379
389
|
|
|
380
390
|
@overload
|
|
381
391
|
def result_validator(
|
|
382
|
-
self, func: Callable[[
|
|
383
|
-
) -> Callable[[
|
|
392
|
+
self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
|
|
393
|
+
) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
|
|
384
394
|
|
|
385
395
|
@overload
|
|
386
396
|
def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
|
|
@@ -395,7 +405,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
395
405
|
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
|
|
396
406
|
"""Decorator to register a result validator function.
|
|
397
407
|
|
|
398
|
-
Optionally takes [`
|
|
408
|
+
Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
|
|
399
409
|
Can decorate a sync or async functions.
|
|
400
410
|
|
|
401
411
|
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
|
|
@@ -403,7 +413,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
403
413
|
|
|
404
414
|
Example:
|
|
405
415
|
```py
|
|
406
|
-
from pydantic_ai import Agent,
|
|
416
|
+
from pydantic_ai import Agent, ModelRetry, RunContext
|
|
407
417
|
|
|
408
418
|
agent = Agent('test', deps_type=str)
|
|
409
419
|
|
|
@@ -414,61 +424,57 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
414
424
|
return data
|
|
415
425
|
|
|
416
426
|
@agent.result_validator
|
|
417
|
-
async def result_validator_deps(ctx:
|
|
427
|
+
async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
|
|
418
428
|
if ctx.deps in data:
|
|
419
429
|
raise ModelRetry('wrong response')
|
|
420
430
|
return data
|
|
421
431
|
|
|
422
432
|
result = agent.run_sync('foobar', deps='spam')
|
|
423
433
|
print(result.data)
|
|
424
|
-
#> success (no
|
|
434
|
+
#> success (no tool calls)
|
|
425
435
|
```
|
|
426
436
|
"""
|
|
427
437
|
self._result_validators.append(_result.ResultValidator(func))
|
|
428
438
|
return func
|
|
429
439
|
|
|
430
440
|
@overload
|
|
431
|
-
def
|
|
432
|
-
self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
|
|
433
|
-
) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
|
|
441
|
+
def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
|
|
434
442
|
|
|
435
443
|
@overload
|
|
436
|
-
def
|
|
444
|
+
def tool(
|
|
437
445
|
self, /, *, retries: int | None = None
|
|
438
|
-
) -> Callable[
|
|
439
|
-
[RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
|
|
440
|
-
]: ...
|
|
446
|
+
) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
|
|
441
447
|
|
|
442
|
-
def
|
|
448
|
+
def tool(
|
|
443
449
|
self,
|
|
444
|
-
func:
|
|
450
|
+
func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
|
|
445
451
|
/,
|
|
446
452
|
*,
|
|
447
453
|
retries: int | None = None,
|
|
448
454
|
) -> Any:
|
|
449
|
-
"""Decorator to register a
|
|
450
|
-
[`
|
|
455
|
+
"""Decorator to register a tool function which takes
|
|
456
|
+
[`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
|
|
451
457
|
|
|
452
458
|
Can decorate a sync or async functions.
|
|
453
459
|
|
|
454
460
|
The docstring is inspected to extract both the tool description and description of each parameter,
|
|
455
|
-
[learn more](../agents.md#
|
|
461
|
+
[learn more](../agents.md#function-tools-and-schema).
|
|
456
462
|
|
|
457
|
-
We can't add overloads for every possible signature of
|
|
458
|
-
so the signature of functions decorated with `@agent.
|
|
463
|
+
We can't add overloads for every possible signature of tool, since the return type is a recursive union
|
|
464
|
+
so the signature of functions decorated with `@agent.tool` is obscured.
|
|
459
465
|
|
|
460
466
|
Example:
|
|
461
467
|
```py
|
|
462
|
-
from pydantic_ai import Agent,
|
|
468
|
+
from pydantic_ai import Agent, RunContext
|
|
463
469
|
|
|
464
470
|
agent = Agent('test', deps_type=int)
|
|
465
471
|
|
|
466
|
-
@agent.
|
|
467
|
-
def foobar(ctx:
|
|
472
|
+
@agent.tool
|
|
473
|
+
def foobar(ctx: RunContext[int], x: int) -> int:
|
|
468
474
|
return ctx.deps + x
|
|
469
475
|
|
|
470
|
-
@agent.
|
|
471
|
-
async def spam(ctx:
|
|
476
|
+
@agent.tool(retries=2)
|
|
477
|
+
async def spam(ctx: RunContext[str], y: float) -> float:
|
|
472
478
|
return ctx.deps + y
|
|
473
479
|
|
|
474
480
|
result = agent.run_sync('foobar', deps=1)
|
|
@@ -477,58 +483,56 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
477
483
|
```
|
|
478
484
|
|
|
479
485
|
Args:
|
|
480
|
-
func: The
|
|
481
|
-
retries: The number of retries to allow for this
|
|
486
|
+
func: The tool function to register.
|
|
487
|
+
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
482
488
|
which defaults to 1.
|
|
483
489
|
""" # noqa: D205
|
|
484
490
|
if func is None:
|
|
485
491
|
|
|
486
|
-
def
|
|
487
|
-
func_:
|
|
488
|
-
) ->
|
|
492
|
+
def tool_decorator(
|
|
493
|
+
func_: ToolContextFunc[AgentDeps, ToolParams],
|
|
494
|
+
) -> ToolContextFunc[AgentDeps, ToolParams]:
|
|
489
495
|
# noinspection PyTypeChecker
|
|
490
|
-
self.
|
|
496
|
+
self._register_tool(_utils.Either(left=func_), retries)
|
|
491
497
|
return func_
|
|
492
498
|
|
|
493
|
-
return
|
|
499
|
+
return tool_decorator
|
|
494
500
|
else:
|
|
495
501
|
# noinspection PyTypeChecker
|
|
496
|
-
self.
|
|
502
|
+
self._register_tool(_utils.Either(left=func), retries)
|
|
497
503
|
return func
|
|
498
504
|
|
|
499
505
|
@overload
|
|
500
|
-
def
|
|
506
|
+
def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
|
|
501
507
|
|
|
502
508
|
@overload
|
|
503
|
-
def
|
|
509
|
+
def tool_plain(
|
|
504
510
|
self, /, *, retries: int | None = None
|
|
505
|
-
) -> Callable[[
|
|
511
|
+
) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
|
|
506
512
|
|
|
507
|
-
def
|
|
508
|
-
|
|
509
|
-
) -> Any:
|
|
510
|
-
"""Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
|
|
513
|
+
def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
|
|
514
|
+
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
511
515
|
|
|
512
516
|
Can decorate a sync or async functions.
|
|
513
517
|
|
|
514
518
|
The docstring is inspected to extract both the tool description and description of each parameter,
|
|
515
|
-
[learn more](../agents.md#
|
|
519
|
+
[learn more](../agents.md#function-tools-and-schema).
|
|
516
520
|
|
|
517
|
-
We can't add overloads for every possible signature of
|
|
518
|
-
so the signature of functions decorated with `@agent.
|
|
521
|
+
We can't add overloads for every possible signature of tool, since the return type is a recursive union
|
|
522
|
+
so the signature of functions decorated with `@agent.tool` is obscured.
|
|
519
523
|
|
|
520
524
|
Example:
|
|
521
525
|
```py
|
|
522
|
-
from pydantic_ai import Agent,
|
|
526
|
+
from pydantic_ai import Agent, RunContext
|
|
523
527
|
|
|
524
528
|
agent = Agent('test')
|
|
525
529
|
|
|
526
|
-
@agent.
|
|
527
|
-
def foobar(ctx:
|
|
530
|
+
@agent.tool
|
|
531
|
+
def foobar(ctx: RunContext[int]) -> int:
|
|
528
532
|
return 123
|
|
529
533
|
|
|
530
|
-
@agent.
|
|
531
|
-
async def spam(ctx:
|
|
534
|
+
@agent.tool(retries=2)
|
|
535
|
+
async def spam(ctx: RunContext[str]) -> float:
|
|
532
536
|
return 3.14
|
|
533
537
|
|
|
534
538
|
result = agent.run_sync('foobar', deps=1)
|
|
@@ -537,38 +541,36 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
537
541
|
```
|
|
538
542
|
|
|
539
543
|
Args:
|
|
540
|
-
func: The
|
|
541
|
-
retries: The number of retries to allow for this
|
|
544
|
+
func: The tool function to register.
|
|
545
|
+
retries: The number of retries to allow for this tool, defaults to the agent's default retries,
|
|
542
546
|
which defaults to 1.
|
|
543
547
|
"""
|
|
544
548
|
if func is None:
|
|
545
549
|
|
|
546
|
-
def
|
|
547
|
-
func_:
|
|
548
|
-
) ->
|
|
550
|
+
def tool_decorator(
|
|
551
|
+
func_: ToolPlainFunc[ToolParams],
|
|
552
|
+
) -> ToolPlainFunc[ToolParams]:
|
|
549
553
|
# noinspection PyTypeChecker
|
|
550
|
-
self.
|
|
554
|
+
self._register_tool(_utils.Either(right=func_), retries)
|
|
551
555
|
return func_
|
|
552
556
|
|
|
553
|
-
return
|
|
557
|
+
return tool_decorator
|
|
554
558
|
else:
|
|
555
|
-
self.
|
|
559
|
+
self._register_tool(_utils.Either(right=func), retries)
|
|
556
560
|
return func
|
|
557
561
|
|
|
558
|
-
def
|
|
559
|
-
|
|
560
|
-
) -> None:
|
|
561
|
-
"""Private utility to register a retriever function."""
|
|
562
|
+
def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
|
|
563
|
+
"""Private utility to register a tool function."""
|
|
562
564
|
retries_ = retries if retries is not None else self._default_retries
|
|
563
|
-
|
|
565
|
+
tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
|
|
564
566
|
|
|
565
|
-
if self._result_schema and
|
|
566
|
-
raise ValueError(f'
|
|
567
|
+
if self._result_schema and tool.name in self._result_schema.tools:
|
|
568
|
+
raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
|
|
567
569
|
|
|
568
|
-
if
|
|
569
|
-
raise ValueError(f'
|
|
570
|
+
if tool.name in self._function_tools:
|
|
571
|
+
raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
|
|
570
572
|
|
|
571
|
-
self.
|
|
573
|
+
self._function_tools[tool.name] = tool
|
|
572
574
|
|
|
573
575
|
async def _get_agent_model(
|
|
574
576
|
self, model: models.Model | models.KnownModelName | None
|
|
@@ -583,11 +585,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
583
585
|
"""
|
|
584
586
|
model_: models.Model
|
|
585
587
|
if some_model := self._override_model:
|
|
586
|
-
# we don't want `
|
|
588
|
+
# we don't want `override()` to cover up errors from the model not being defined, hence this check
|
|
587
589
|
if model is None and self.model is None:
|
|
588
590
|
raise exceptions.UserError(
|
|
589
591
|
'`model` must be set either when creating the agent or when calling it. '
|
|
590
|
-
'(Even when `
|
|
592
|
+
'(Even when `override(model=...)` is customizing the model that will actually be called)'
|
|
591
593
|
)
|
|
592
594
|
model_ = some_model.value
|
|
593
595
|
custom_model = None
|
|
@@ -601,7 +603,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
601
603
|
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
|
|
602
604
|
|
|
603
605
|
result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
|
|
604
|
-
agent_model = await model_.agent_model(self.
|
|
606
|
+
agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
|
|
605
607
|
return model_, custom_model, agent_model
|
|
606
608
|
|
|
607
609
|
async def _prepare_messages(
|
|
@@ -663,12 +665,12 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
663
665
|
if not model_response.calls:
|
|
664
666
|
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
665
667
|
|
|
666
|
-
# otherwise we run all
|
|
668
|
+
# otherwise we run all tool functions in parallel
|
|
667
669
|
messages: list[_messages.Message] = []
|
|
668
670
|
tasks: list[asyncio.Task[_messages.Message]] = []
|
|
669
671
|
for call in model_response.calls:
|
|
670
|
-
if
|
|
671
|
-
tasks.append(asyncio.create_task(
|
|
672
|
+
if tool := self._function_tools.get(call.tool_name):
|
|
673
|
+
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
|
|
672
674
|
else:
|
|
673
675
|
messages.append(self._unknown_tool(call.tool_name))
|
|
674
676
|
|
|
@@ -719,7 +721,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
719
721
|
if self._result_schema.find_tool(structured_msg):
|
|
720
722
|
return _MarkFinalResult(model_response)
|
|
721
723
|
|
|
722
|
-
# the model is calling a
|
|
724
|
+
# the model is calling a tool function, consume the response to get the next message
|
|
723
725
|
async for _ in model_response:
|
|
724
726
|
pass
|
|
725
727
|
structured_msg = model_response.get()
|
|
@@ -727,11 +729,11 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
727
729
|
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
|
|
728
730
|
messages: list[_messages.Message] = [structured_msg]
|
|
729
731
|
|
|
730
|
-
# we now run all
|
|
732
|
+
# we now run all tool functions in parallel
|
|
731
733
|
tasks: list[asyncio.Task[_messages.Message]] = []
|
|
732
734
|
for call in structured_msg.calls:
|
|
733
|
-
if
|
|
734
|
-
tasks.append(asyncio.create_task(
|
|
735
|
+
if tool := self._function_tools.get(call.tool_name):
|
|
736
|
+
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
|
|
735
737
|
else:
|
|
736
738
|
messages.append(self._unknown_tool(call.tool_name))
|
|
737
739
|
|
|
@@ -763,7 +765,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
763
765
|
|
|
764
766
|
def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
|
|
765
767
|
self._incr_result_retry()
|
|
766
|
-
names = list(self.
|
|
768
|
+
names = list(self._function_tools.keys())
|
|
767
769
|
if self._result_schema:
|
|
768
770
|
names.extend(self._result_schema.tool_names())
|
|
769
771
|
if names:
|
|
@@ -775,7 +777,7 @@ class Agent(Generic[AgentDeps, ResultData]):
|
|
|
775
777
|
def _get_deps(self, deps: AgentDeps) -> AgentDeps:
|
|
776
778
|
"""Get deps for a run.
|
|
777
779
|
|
|
778
|
-
If we've overridden deps via `
|
|
780
|
+
If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
|
|
779
781
|
|
|
780
782
|
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
|
|
781
783
|
"""
|
pydantic_ai/dependencies.py
CHANGED
|
@@ -13,13 +13,13 @@ else:
|
|
|
13
13
|
|
|
14
14
|
__all__ = (
|
|
15
15
|
'AgentDeps',
|
|
16
|
-
'
|
|
16
|
+
'RunContext',
|
|
17
17
|
'ResultValidatorFunc',
|
|
18
18
|
'SystemPromptFunc',
|
|
19
|
-
'
|
|
20
|
-
'
|
|
21
|
-
'
|
|
22
|
-
'
|
|
19
|
+
'ToolReturnValue',
|
|
20
|
+
'ToolContextFunc',
|
|
21
|
+
'ToolPlainFunc',
|
|
22
|
+
'ToolParams',
|
|
23
23
|
'JsonData',
|
|
24
24
|
)
|
|
25
25
|
|
|
@@ -28,7 +28,7 @@ AgentDeps = TypeVar('AgentDeps')
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@dataclass
|
|
31
|
-
class
|
|
31
|
+
class RunContext(Generic[AgentDeps]):
|
|
32
32
|
"""Information about the current call."""
|
|
33
33
|
|
|
34
34
|
deps: AgentDeps
|
|
@@ -39,23 +39,23 @@ class CallContext(Generic[AgentDeps]):
|
|
|
39
39
|
"""Name of the tool being called."""
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
ToolParams = ParamSpec('ToolParams')
|
|
43
43
|
"""Retrieval function param spec."""
|
|
44
44
|
|
|
45
45
|
SystemPromptFunc = Union[
|
|
46
|
-
Callable[[
|
|
47
|
-
Callable[[
|
|
46
|
+
Callable[[RunContext[AgentDeps]], str],
|
|
47
|
+
Callable[[RunContext[AgentDeps]], Awaitable[str]],
|
|
48
48
|
Callable[[], str],
|
|
49
49
|
Callable[[], Awaitable[str]],
|
|
50
50
|
]
|
|
51
|
-
"""A function that may or maybe not take `
|
|
51
|
+
"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
|
|
52
52
|
|
|
53
53
|
Usage `SystemPromptFunc[AgentDeps]`.
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
56
|
ResultValidatorFunc = Union[
|
|
57
|
-
Callable[[
|
|
58
|
-
Callable[[
|
|
57
|
+
Callable[[RunContext[AgentDeps], ResultData], ResultData],
|
|
58
|
+
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
|
|
59
59
|
Callable[[ResultData], ResultData],
|
|
60
60
|
Callable[[ResultData], Awaitable[ResultData]],
|
|
61
61
|
]
|
|
@@ -69,15 +69,15 @@ Usage `ResultValidator[AgentDeps, ResultData]`.
|
|
|
69
69
|
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
|
|
70
70
|
"""Type representing any JSON data."""
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
"""Return value of a
|
|
74
|
-
|
|
75
|
-
"""A
|
|
72
|
+
ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
|
|
73
|
+
"""Return value of a tool function."""
|
|
74
|
+
ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
|
|
75
|
+
"""A tool function that takes `RunContext` as the first argument.
|
|
76
76
|
|
|
77
|
-
Usage `
|
|
77
|
+
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
|
|
78
78
|
"""
|
|
79
|
-
|
|
80
|
-
"""A
|
|
79
|
+
ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
|
|
80
|
+
"""A tool function that does not take `RunContext` as the first argument.
|
|
81
81
|
|
|
82
|
-
Usage `
|
|
82
|
+
Usage `ToolPlainFunc[ToolParams]`.
|
|
83
83
|
"""
|
pydantic_ai/exceptions.py
CHANGED
|
@@ -6,7 +6,7 @@ __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class ModelRetry(Exception):
|
|
9
|
-
"""Exception raised when a
|
|
9
|
+
"""Exception raised when a tool function should be retried.
|
|
10
10
|
|
|
11
11
|
The agent will return the message to the model and ask it to try calling the function/tool again.
|
|
12
12
|
"""
|