pydantic-ai-slim 0.0.6a4__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/agent.py CHANGED
@@ -11,15 +11,15 @@ from typing_extensions import assert_never
11
11
 
12
12
  from . import (
13
13
  _result,
14
- _retriever as _r,
15
14
  _system_prompt,
15
+ _tool as _r,
16
16
  _utils,
17
17
  exceptions,
18
18
  messages as _messages,
19
19
  models,
20
20
  result,
21
21
  )
22
- from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
22
+ from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
23
23
  from .result import ResultData
24
24
 
25
25
  __all__ = ('Agent',)
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
58
58
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
59
59
  _allow_text_result: bool = field(repr=False)
60
60
  _system_prompts: tuple[str, ...] = field(repr=False)
61
- _retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False)
61
+ _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
62
62
  _default_retries: int = field(repr=False)
63
63
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
64
64
  _deps_type: type[AgentDeps] = field(repr=False)
@@ -75,8 +75,8 @@ class Agent(Generic[AgentDeps, ResultData]):
75
75
  def __init__(
76
76
  self,
77
77
  model: models.Model | models.KnownModelName | None = None,
78
- result_type: type[ResultData] = str,
79
78
  *,
79
+ result_type: type[ResultData] = str,
80
80
  system_prompt: str | Sequence[str] = (),
81
81
  deps_type: type[AgentDeps] = NoneType,
82
82
  retries: int = 1,
@@ -105,7 +105,7 @@ class Agent(Generic[AgentDeps, ResultData]):
105
105
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
106
106
  which checks for the necessary environment variables. Set this to `false`
107
107
  to defer the evaluation until the first run. Useful if you want to
108
- [override the model][pydantic_ai.Agent.override_model] for testing.
108
+ [override the model][pydantic_ai.Agent.override] for testing.
109
109
  """
110
110
  if model is None or defer_model_check:
111
111
  self.model = model
@@ -119,7 +119,7 @@ class Agent(Generic[AgentDeps, ResultData]):
119
119
  self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
120
120
 
121
121
  self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
122
- self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
122
+ self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
123
123
  self._deps_type = deps_type
124
124
  self._default_retries = retries
125
125
  self._system_prompt_functions = []
@@ -150,14 +150,6 @@ class Agent(Generic[AgentDeps, ResultData]):
150
150
 
151
151
  deps = self._get_deps(deps)
152
152
 
153
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
154
- self.last_run_messages = messages
155
-
156
- for retriever in self._retrievers.values():
157
- retriever.reset()
158
-
159
- cost = result.Cost()
160
-
161
153
  with _logfire.span(
162
154
  'agent run {prompt=}',
163
155
  prompt=user_prompt,
@@ -165,6 +157,14 @@ class Agent(Generic[AgentDeps, ResultData]):
165
157
  custom_model=custom_model,
166
158
  model_name=model_used.name(),
167
159
  ) as run_span:
160
+ new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
161
+ self.last_run_messages = messages
162
+
163
+ for tool in self._function_tools.values():
164
+ tool.reset()
165
+
166
+ cost = result.Cost()
167
+
168
168
  run_step = 0
169
169
  while True:
170
170
  run_step += 1
@@ -243,14 +243,6 @@ class Agent(Generic[AgentDeps, ResultData]):
243
243
 
244
244
  deps = self._get_deps(deps)
245
245
 
246
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
247
- self.last_run_messages = messages
248
-
249
- for retriever in self._retrievers.values():
250
- retriever.reset()
251
-
252
- cost = result.Cost()
253
-
254
246
  with _logfire.span(
255
247
  'agent run stream {prompt=}',
256
248
  prompt=user_prompt,
@@ -258,6 +250,14 @@ class Agent(Generic[AgentDeps, ResultData]):
258
250
  custom_model=custom_model,
259
251
  model_name=model_used.name(),
260
252
  ) as run_span:
253
+ new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
254
+ self.last_run_messages = messages
255
+
256
+ for tool in self._function_tools.values():
257
+ tool.reset()
258
+
259
+ cost = result.Cost()
260
+
261
261
  run_step = 0
262
262
  while True:
263
263
  run_step += 1
@@ -296,42 +296,51 @@ class Agent(Generic[AgentDeps, ResultData]):
296
296
  cost += model_response.cost()
297
297
 
298
298
  @contextmanager
299
- def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
300
- """Context manager to temporarily override agent dependencies, this is particularly useful when testing.
299
+ def override(
300
+ self,
301
+ *,
302
+ deps: AgentDeps | _utils.Unset = _utils.UNSET,
303
+ model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
304
+ ) -> Iterator[None]:
305
+ """Context manager to temporarily override agent dependencies and model.
306
+
307
+ This is particularly useful when testing.
301
308
 
302
309
  Args:
303
- overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
310
+ deps: The dependencies to use instead of the dependencies passed to the agent run.
311
+ model: The model to use instead of the model passed to the agent run.
304
312
  """
305
- override_deps_before = self._override_deps
306
- self._override_deps = _utils.Some(overriding_deps)
307
- try:
308
- yield
309
- finally:
310
- self._override_deps = override_deps_before
313
+ if _utils.is_set(deps):
314
+ override_deps_before = self._override_deps
315
+ self._override_deps = _utils.Some(deps)
316
+ else:
317
+ override_deps_before = _utils.UNSET
311
318
 
312
- @contextmanager
313
- def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]:
314
- """Context manager to temporarily override the model used by the agent.
319
+ # noinspection PyTypeChecker
320
+ if _utils.is_set(model):
321
+ override_model_before = self._override_model
322
+ # noinspection PyTypeChecker
323
+ self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
324
+ else:
325
+ override_model_before = _utils.UNSET
315
326
 
316
- Args:
317
- overriding_model: The model to use instead of the model passed to the agent run.
318
- """
319
- override_model_before = self._override_model
320
- self._override_model = _utils.Some(models.infer_model(overriding_model))
321
327
  try:
322
328
  yield
323
329
  finally:
324
- self._override_model = override_model_before
330
+ if _utils.is_set(override_deps_before):
331
+ self._override_deps = override_deps_before
332
+ if _utils.is_set(override_model_before):
333
+ self._override_model = override_model_before
325
334
 
326
335
  @overload
327
336
  def system_prompt(
328
- self, func: Callable[[CallContext[AgentDeps]], str], /
329
- ) -> Callable[[CallContext[AgentDeps]], str]: ...
337
+ self, func: Callable[[RunContext[AgentDeps]], str], /
338
+ ) -> Callable[[RunContext[AgentDeps]], str]: ...
330
339
 
331
340
  @overload
332
341
  def system_prompt(
333
- self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
334
- ) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...
342
+ self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
343
+ ) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
335
344
 
336
345
  @overload
337
346
  def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@@ -344,7 +353,7 @@ class Agent(Generic[AgentDeps, ResultData]):
344
353
  ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
345
354
  """Decorator to register a system prompt function.
346
355
 
347
- Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's only argument.
356
+ Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
348
357
  Can decorate a sync or async functions.
349
358
 
350
359
  Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
@@ -352,7 +361,7 @@ class Agent(Generic[AgentDeps, ResultData]):
352
361
 
353
362
  Example:
354
363
  ```py
355
- from pydantic_ai import Agent, CallContext
364
+ from pydantic_ai import Agent, RunContext
356
365
 
357
366
  agent = Agent('test', deps_type=str)
358
367
 
@@ -361,12 +370,12 @@ class Agent(Generic[AgentDeps, ResultData]):
361
370
  return 'foobar'
362
371
 
363
372
  @agent.system_prompt
364
- async def async_system_prompt(ctx: CallContext[str]) -> str:
373
+ async def async_system_prompt(ctx: RunContext[str]) -> str:
365
374
  return f'{ctx.deps} is the best'
366
375
 
367
376
  result = agent.run_sync('foobar', deps='spam')
368
377
  print(result.data)
369
- #> success (no retriever calls)
378
+ #> success (no tool calls)
370
379
  ```
371
380
  """
372
381
  self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
@@ -374,13 +383,13 @@ class Agent(Generic[AgentDeps, ResultData]):
374
383
 
375
384
  @overload
376
385
  def result_validator(
377
- self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
378
- ) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...
386
+ self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
387
+ ) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
379
388
 
380
389
  @overload
381
390
  def result_validator(
382
- self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
383
- ) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
391
+ self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
392
+ ) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
384
393
 
385
394
  @overload
386
395
  def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
@@ -395,7 +404,7 @@ class Agent(Generic[AgentDeps, ResultData]):
395
404
  ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
396
405
  """Decorator to register a result validator function.
397
406
 
398
- Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's first argument.
407
+ Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
399
408
  Can decorate a sync or async functions.
400
409
 
401
410
  Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
@@ -403,7 +412,7 @@ class Agent(Generic[AgentDeps, ResultData]):
403
412
 
404
413
  Example:
405
414
  ```py
406
- from pydantic_ai import Agent, CallContext, ModelRetry
415
+ from pydantic_ai import Agent, ModelRetry, RunContext
407
416
 
408
417
  agent = Agent('test', deps_type=str)
409
418
 
@@ -414,61 +423,57 @@ class Agent(Generic[AgentDeps, ResultData]):
414
423
  return data
415
424
 
416
425
  @agent.result_validator
417
- async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
426
+ async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
418
427
  if ctx.deps in data:
419
428
  raise ModelRetry('wrong response')
420
429
  return data
421
430
 
422
431
  result = agent.run_sync('foobar', deps='spam')
423
432
  print(result.data)
424
- #> success (no retriever calls)
433
+ #> success (no tool calls)
425
434
  ```
426
435
  """
427
436
  self._result_validators.append(_result.ResultValidator(func))
428
437
  return func
429
438
 
430
439
  @overload
431
- def retriever(
432
- self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
433
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
440
+ def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
434
441
 
435
442
  @overload
436
- def retriever(
443
+ def tool(
437
444
  self, /, *, retries: int | None = None
438
- ) -> Callable[
439
- [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
440
- ]: ...
445
+ ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
441
446
 
442
- def retriever(
447
+ def tool(
443
448
  self,
444
- func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
449
+ func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
445
450
  /,
446
451
  *,
447
452
  retries: int | None = None,
448
453
  ) -> Any:
449
- """Decorator to register a retriever function which takes
450
- [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
454
+ """Decorator to register a tool function which takes
455
+ [`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
451
456
 
452
457
  Can decorate a sync or async functions.
453
458
 
454
459
  The docstring is inspected to extract both the tool description and description of each parameter,
455
- [learn more](../agents.md#retrievers-tools-and-schema).
460
+ [learn more](../agents.md#function-tools-and-schema).
456
461
 
457
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
458
- so the signature of functions decorated with `@agent.retriever` is obscured.
462
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
463
+ so the signature of functions decorated with `@agent.tool` is obscured.
459
464
 
460
465
  Example:
461
466
  ```py
462
- from pydantic_ai import Agent, CallContext
467
+ from pydantic_ai import Agent, RunContext
463
468
 
464
469
  agent = Agent('test', deps_type=int)
465
470
 
466
- @agent.retriever
467
- def foobar(ctx: CallContext[int], x: int) -> int:
471
+ @agent.tool
472
+ def foobar(ctx: RunContext[int], x: int) -> int:
468
473
  return ctx.deps + x
469
474
 
470
- @agent.retriever(retries=2)
471
- async def spam(ctx: CallContext[str], y: float) -> float:
475
+ @agent.tool(retries=2)
476
+ async def spam(ctx: RunContext[str], y: float) -> float:
472
477
  return ctx.deps + y
473
478
 
474
479
  result = agent.run_sync('foobar', deps=1)
@@ -477,58 +482,56 @@ class Agent(Generic[AgentDeps, ResultData]):
477
482
  ```
478
483
 
479
484
  Args:
480
- func: The retriever function to register.
481
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
485
+ func: The tool function to register.
486
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
482
487
  which defaults to 1.
483
488
  """ # noqa: D205
484
489
  if func is None:
485
490
 
486
- def retriever_decorator(
487
- func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
488
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
491
+ def tool_decorator(
492
+ func_: ToolContextFunc[AgentDeps, ToolParams],
493
+ ) -> ToolContextFunc[AgentDeps, ToolParams]:
489
494
  # noinspection PyTypeChecker
490
- self._register_retriever(_utils.Either(left=func_), retries)
495
+ self._register_tool(_utils.Either(left=func_), retries)
491
496
  return func_
492
497
 
493
- return retriever_decorator
498
+ return tool_decorator
494
499
  else:
495
500
  # noinspection PyTypeChecker
496
- self._register_retriever(_utils.Either(left=func), retries)
501
+ self._register_tool(_utils.Either(left=func), retries)
497
502
  return func
498
503
 
499
504
  @overload
500
- def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
505
+ def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
501
506
 
502
507
  @overload
503
- def retriever_plain(
508
+ def tool_plain(
504
509
  self, /, *, retries: int | None = None
505
- ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
510
+ ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
506
511
 
507
- def retriever_plain(
508
- self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
509
- ) -> Any:
510
- """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
512
+ def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
513
+ """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
511
514
 
512
515
  Can decorate a sync or async functions.
513
516
 
514
517
  The docstring is inspected to extract both the tool description and description of each parameter,
515
- [learn more](../agents.md#retrievers-tools-and-schema).
518
+ [learn more](../agents.md#function-tools-and-schema).
516
519
 
517
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
518
- so the signature of functions decorated with `@agent.retriever` is obscured.
520
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
521
+ so the signature of functions decorated with `@agent.tool` is obscured.
519
522
 
520
523
  Example:
521
524
  ```py
522
- from pydantic_ai import Agent, CallContext
525
+ from pydantic_ai import Agent, RunContext
523
526
 
524
527
  agent = Agent('test')
525
528
 
526
- @agent.retriever
527
- def foobar(ctx: CallContext[int]) -> int:
529
+ @agent.tool
530
+ def foobar(ctx: RunContext[int]) -> int:
528
531
  return 123
529
532
 
530
- @agent.retriever(retries=2)
531
- async def spam(ctx: CallContext[str]) -> float:
533
+ @agent.tool(retries=2)
534
+ async def spam(ctx: RunContext[str]) -> float:
532
535
  return 3.14
533
536
 
534
537
  result = agent.run_sync('foobar', deps=1)
@@ -537,38 +540,36 @@ class Agent(Generic[AgentDeps, ResultData]):
537
540
  ```
538
541
 
539
542
  Args:
540
- func: The retriever function to register.
541
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
543
+ func: The tool function to register.
544
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
542
545
  which defaults to 1.
543
546
  """
544
547
  if func is None:
545
548
 
546
- def retriever_decorator(
547
- func_: RetrieverPlainFunc[RetrieverParams],
548
- ) -> RetrieverPlainFunc[RetrieverParams]:
549
+ def tool_decorator(
550
+ func_: ToolPlainFunc[ToolParams],
551
+ ) -> ToolPlainFunc[ToolParams]:
549
552
  # noinspection PyTypeChecker
550
- self._register_retriever(_utils.Either(right=func_), retries)
553
+ self._register_tool(_utils.Either(right=func_), retries)
551
554
  return func_
552
555
 
553
- return retriever_decorator
556
+ return tool_decorator
554
557
  else:
555
- self._register_retriever(_utils.Either(right=func), retries)
558
+ self._register_tool(_utils.Either(right=func), retries)
556
559
  return func
557
560
 
558
- def _register_retriever(
559
- self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
560
- ) -> None:
561
- """Private utility to register a retriever function."""
561
+ def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
562
+ """Private utility to register a tool function."""
562
563
  retries_ = retries if retries is not None else self._default_retries
563
- retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
564
+ tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
564
565
 
565
- if self._result_schema and retriever.name in self._result_schema.tools:
566
- raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}')
566
+ if self._result_schema and tool.name in self._result_schema.tools:
567
+ raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
567
568
 
568
- if retriever.name in self._retrievers:
569
- raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')
569
+ if tool.name in self._function_tools:
570
+ raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
570
571
 
571
- self._retrievers[retriever.name] = retriever
572
+ self._function_tools[tool.name] = tool
572
573
 
573
574
  async def _get_agent_model(
574
575
  self, model: models.Model | models.KnownModelName | None
@@ -583,11 +584,11 @@ class Agent(Generic[AgentDeps, ResultData]):
583
584
  """
584
585
  model_: models.Model
585
586
  if some_model := self._override_model:
586
- # we don't want `override_model()` to cover up errors from the model not being defined, hence this check
587
+ # we don't want `override()` to cover up errors from the model not being defined, hence this check
587
588
  if model is None and self.model is None:
588
589
  raise exceptions.UserError(
589
590
  '`model` must be set either when creating the agent or when calling it. '
590
- '(Even when `override_model()` is customizing the model that will actually be called)'
591
+ '(Even when `override(model=...)` is customizing the model that will actually be called)'
591
592
  )
592
593
  model_ = some_model.value
593
594
  custom_model = None
@@ -601,7 +602,7 @@ class Agent(Generic[AgentDeps, ResultData]):
601
602
  raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
602
603
 
603
604
  result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
604
- agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
605
+ agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
605
606
  return model_, custom_model, agent_model
606
607
 
607
608
  async def _prepare_messages(
@@ -663,12 +664,12 @@ class Agent(Generic[AgentDeps, ResultData]):
663
664
  if not model_response.calls:
664
665
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
665
666
 
666
- # otherwise we run all retriever functions in parallel
667
+ # otherwise we run all tool functions in parallel
667
668
  messages: list[_messages.Message] = []
668
669
  tasks: list[asyncio.Task[_messages.Message]] = []
669
670
  for call in model_response.calls:
670
- if retriever := self._retrievers.get(call.tool_name):
671
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
671
+ if tool := self._function_tools.get(call.tool_name):
672
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
672
673
  else:
673
674
  messages.append(self._unknown_tool(call.tool_name))
674
675
 
@@ -719,7 +720,7 @@ class Agent(Generic[AgentDeps, ResultData]):
719
720
  if self._result_schema.find_tool(structured_msg):
720
721
  return _MarkFinalResult(model_response)
721
722
 
722
- # the model is calling a retriever function, consume the response to get the next message
723
+ # the model is calling a tool function, consume the response to get the next message
723
724
  async for _ in model_response:
724
725
  pass
725
726
  structured_msg = model_response.get()
@@ -727,11 +728,11 @@ class Agent(Generic[AgentDeps, ResultData]):
727
728
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
728
729
  messages: list[_messages.Message] = [structured_msg]
729
730
 
730
- # we now run all retriever functions in parallel
731
+ # we now run all tool functions in parallel
731
732
  tasks: list[asyncio.Task[_messages.Message]] = []
732
733
  for call in structured_msg.calls:
733
- if retriever := self._retrievers.get(call.tool_name):
734
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
734
+ if tool := self._function_tools.get(call.tool_name):
735
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
735
736
  else:
736
737
  messages.append(self._unknown_tool(call.tool_name))
737
738
 
@@ -763,7 +764,7 @@ class Agent(Generic[AgentDeps, ResultData]):
763
764
 
764
765
  def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
765
766
  self._incr_result_retry()
766
- names = list(self._retrievers.keys())
767
+ names = list(self._function_tools.keys())
767
768
  if self._result_schema:
768
769
  names.extend(self._result_schema.tool_names())
769
770
  if names:
@@ -775,7 +776,7 @@ class Agent(Generic[AgentDeps, ResultData]):
775
776
  def _get_deps(self, deps: AgentDeps) -> AgentDeps:
776
777
  """Get deps for a run.
777
778
 
778
- If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call.
779
+ If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
779
780
 
780
781
  We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
781
782
  """
@@ -13,13 +13,13 @@ else:
13
13
 
14
14
  __all__ = (
15
15
  'AgentDeps',
16
- 'CallContext',
16
+ 'RunContext',
17
17
  'ResultValidatorFunc',
18
18
  'SystemPromptFunc',
19
- 'RetrieverReturnValue',
20
- 'RetrieverContextFunc',
21
- 'RetrieverPlainFunc',
22
- 'RetrieverParams',
19
+ 'ToolReturnValue',
20
+ 'ToolContextFunc',
21
+ 'ToolPlainFunc',
22
+ 'ToolParams',
23
23
  'JsonData',
24
24
  )
25
25
 
@@ -28,7 +28,7 @@ AgentDeps = TypeVar('AgentDeps')
28
28
 
29
29
 
30
30
  @dataclass
31
- class CallContext(Generic[AgentDeps]):
31
+ class RunContext(Generic[AgentDeps]):
32
32
  """Information about the current call."""
33
33
 
34
34
  deps: AgentDeps
@@ -39,23 +39,23 @@ class CallContext(Generic[AgentDeps]):
39
39
  """Name of the tool being called."""
40
40
 
41
41
 
42
- RetrieverParams = ParamSpec('RetrieverParams')
42
+ ToolParams = ParamSpec('ToolParams')
43
43
  """Retrieval function param spec."""
44
44
 
45
45
  SystemPromptFunc = Union[
46
- Callable[[CallContext[AgentDeps]], str],
47
- Callable[[CallContext[AgentDeps]], Awaitable[str]],
46
+ Callable[[RunContext[AgentDeps]], str],
47
+ Callable[[RunContext[AgentDeps]], Awaitable[str]],
48
48
  Callable[[], str],
49
49
  Callable[[], Awaitable[str]],
50
50
  ]
51
- """A function that may or maybe not take `CallContext` as an argument, and may or may not be async.
51
+ """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
52
52
 
53
53
  Usage `SystemPromptFunc[AgentDeps]`.
54
54
  """
55
55
 
56
56
  ResultValidatorFunc = Union[
57
- Callable[[CallContext[AgentDeps], ResultData], ResultData],
58
- Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]],
57
+ Callable[[RunContext[AgentDeps], ResultData], ResultData],
58
+ Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
59
59
  Callable[[ResultData], ResultData],
60
60
  Callable[[ResultData], Awaitable[ResultData]],
61
61
  ]
@@ -69,15 +69,15 @@ Usage `ResultValidator[AgentDeps, ResultData]`.
69
69
  JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
70
70
  """Type representing any JSON data."""
71
71
 
72
- RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]]
73
- """Return value of a retriever function."""
74
- RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue]
75
- """A retriever function that takes `CallContext` as the first argument.
72
+ ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
73
+ """Return value of a tool function."""
74
+ ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
75
+ """A tool function that takes `RunContext` as the first argument.
76
76
 
77
- Usage `RetrieverContextFunc[AgentDeps, RetrieverParams]`.
77
+ Usage `ToolContextFunc[AgentDeps, ToolParams]`.
78
78
  """
79
- RetrieverPlainFunc = Callable[RetrieverParams, RetrieverReturnValue]
80
- """A retriever function that does not take `CallContext` as the first argument.
79
+ ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
80
+ """A tool function that does not take `RunContext` as the first argument.
81
81
 
82
- Usage `RetrieverPlainFunc[RetrieverParams]`.
82
+ Usage `ToolPlainFunc[ToolParams]`.
83
83
  """
pydantic_ai/exceptions.py CHANGED
@@ -6,7 +6,7 @@ __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
6
6
 
7
7
 
8
8
  class ModelRetry(Exception):
9
- """Exception raised when a retriever function should be retried.
9
+ """Exception raised when a tool function should be retried.
10
10
 
11
11
  The agent will return the message to the model and ask it to try calling the function/tool again.
12
12
  """