pydantic-ai-slim 0.0.6a4__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/agent.py CHANGED
@@ -11,15 +11,15 @@ from typing_extensions import assert_never
11
11
 
12
12
  from . import (
13
13
  _result,
14
- _retriever as _r,
15
14
  _system_prompt,
15
+ _tool as _r,
16
16
  _utils,
17
17
  exceptions,
18
18
  messages as _messages,
19
19
  models,
20
20
  result,
21
21
  )
22
- from .dependencies import AgentDeps, CallContext, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
22
+ from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
23
23
  from .result import ResultData
24
24
 
25
25
  __all__ = ('Agent',)
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
58
58
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
59
59
  _allow_text_result: bool = field(repr=False)
60
60
  _system_prompts: tuple[str, ...] = field(repr=False)
61
- _retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = field(repr=False)
61
+ _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
62
62
  _default_retries: int = field(repr=False)
63
63
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
64
64
  _deps_type: type[AgentDeps] = field(repr=False)
@@ -75,8 +75,8 @@ class Agent(Generic[AgentDeps, ResultData]):
75
75
  def __init__(
76
76
  self,
77
77
  model: models.Model | models.KnownModelName | None = None,
78
- result_type: type[ResultData] = str,
79
78
  *,
79
+ result_type: type[ResultData] = str,
80
80
  system_prompt: str | Sequence[str] = (),
81
81
  deps_type: type[AgentDeps] = NoneType,
82
82
  retries: int = 1,
@@ -105,7 +105,7 @@ class Agent(Generic[AgentDeps, ResultData]):
105
105
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
106
106
  which checks for the necessary environment variables. Set this to `false`
107
107
  to defer the evaluation until the first run. Useful if you want to
108
- [override the model][pydantic_ai.Agent.override_model] for testing.
108
+ [override the model][pydantic_ai.Agent.override] for testing.
109
109
  """
110
110
  if model is None or defer_model_check:
111
111
  self.model = model
@@ -119,7 +119,7 @@ class Agent(Generic[AgentDeps, ResultData]):
119
119
  self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
120
120
 
121
121
  self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
122
- self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {}
122
+ self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
123
123
  self._deps_type = deps_type
124
124
  self._default_retries = retries
125
125
  self._system_prompt_functions = []
@@ -150,14 +150,6 @@ class Agent(Generic[AgentDeps, ResultData]):
150
150
 
151
151
  deps = self._get_deps(deps)
152
152
 
153
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
154
- self.last_run_messages = messages
155
-
156
- for retriever in self._retrievers.values():
157
- retriever.reset()
158
-
159
- cost = result.Cost()
160
-
161
153
  with _logfire.span(
162
154
  'agent run {prompt=}',
163
155
  prompt=user_prompt,
@@ -165,6 +157,14 @@ class Agent(Generic[AgentDeps, ResultData]):
165
157
  custom_model=custom_model,
166
158
  model_name=model_used.name(),
167
159
  ) as run_span:
160
+ new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
161
+ self.last_run_messages = messages
162
+
163
+ for tool in self._function_tools.values():
164
+ tool.reset()
165
+
166
+ cost = result.Cost()
167
+
168
168
  run_step = 0
169
169
  while True:
170
170
  run_step += 1
@@ -243,14 +243,6 @@ class Agent(Generic[AgentDeps, ResultData]):
243
243
 
244
244
  deps = self._get_deps(deps)
245
245
 
246
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
247
- self.last_run_messages = messages
248
-
249
- for retriever in self._retrievers.values():
250
- retriever.reset()
251
-
252
- cost = result.Cost()
253
-
254
246
  with _logfire.span(
255
247
  'agent run stream {prompt=}',
256
248
  prompt=user_prompt,
@@ -258,6 +250,14 @@ class Agent(Generic[AgentDeps, ResultData]):
258
250
  custom_model=custom_model,
259
251
  model_name=model_used.name(),
260
252
  ) as run_span:
253
+ new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
254
+ self.last_run_messages = messages
255
+
256
+ for tool in self._function_tools.values():
257
+ tool.reset()
258
+
259
+ cost = result.Cost()
260
+
261
261
  run_step = 0
262
262
  while True:
263
263
  run_step += 1
@@ -284,6 +284,7 @@ class Agent(Generic[AgentDeps, ResultData]):
284
284
  self._result_schema,
285
285
  deps,
286
286
  self._result_validators,
287
+ lambda m: run_span.set_attribute('all_messages', messages),
287
288
  )
288
289
  return
289
290
  else:
@@ -296,42 +297,51 @@ class Agent(Generic[AgentDeps, ResultData]):
296
297
  cost += model_response.cost()
297
298
 
298
299
  @contextmanager
299
- def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
300
- """Context manager to temporarily override agent dependencies, this is particularly useful when testing.
300
+ def override(
301
+ self,
302
+ *,
303
+ deps: AgentDeps | _utils.Unset = _utils.UNSET,
304
+ model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET,
305
+ ) -> Iterator[None]:
306
+ """Context manager to temporarily override agent dependencies and model.
307
+
308
+ This is particularly useful when testing.
301
309
 
302
310
  Args:
303
- overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
311
+ deps: The dependencies to use instead of the dependencies passed to the agent run.
312
+ model: The model to use instead of the model passed to the agent run.
304
313
  """
305
- override_deps_before = self._override_deps
306
- self._override_deps = _utils.Some(overriding_deps)
307
- try:
308
- yield
309
- finally:
310
- self._override_deps = override_deps_before
314
+ if _utils.is_set(deps):
315
+ override_deps_before = self._override_deps
316
+ self._override_deps = _utils.Some(deps)
317
+ else:
318
+ override_deps_before = _utils.UNSET
311
319
 
312
- @contextmanager
313
- def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]:
314
- """Context manager to temporarily override the model used by the agent.
320
+ # noinspection PyTypeChecker
321
+ if _utils.is_set(model):
322
+ override_model_before = self._override_model
323
+ # noinspection PyTypeChecker
324
+ self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType]
325
+ else:
326
+ override_model_before = _utils.UNSET
315
327
 
316
- Args:
317
- overriding_model: The model to use instead of the model passed to the agent run.
318
- """
319
- override_model_before = self._override_model
320
- self._override_model = _utils.Some(models.infer_model(overriding_model))
321
328
  try:
322
329
  yield
323
330
  finally:
324
- self._override_model = override_model_before
331
+ if _utils.is_set(override_deps_before):
332
+ self._override_deps = override_deps_before
333
+ if _utils.is_set(override_model_before):
334
+ self._override_model = override_model_before
325
335
 
326
336
  @overload
327
337
  def system_prompt(
328
- self, func: Callable[[CallContext[AgentDeps]], str], /
329
- ) -> Callable[[CallContext[AgentDeps]], str]: ...
338
+ self, func: Callable[[RunContext[AgentDeps]], str], /
339
+ ) -> Callable[[RunContext[AgentDeps]], str]: ...
330
340
 
331
341
  @overload
332
342
  def system_prompt(
333
- self, func: Callable[[CallContext[AgentDeps]], Awaitable[str]], /
334
- ) -> Callable[[CallContext[AgentDeps]], Awaitable[str]]: ...
343
+ self, func: Callable[[RunContext[AgentDeps]], Awaitable[str]], /
344
+ ) -> Callable[[RunContext[AgentDeps]], Awaitable[str]]: ...
335
345
 
336
346
  @overload
337
347
  def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@@ -344,7 +354,7 @@ class Agent(Generic[AgentDeps, ResultData]):
344
354
  ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
345
355
  """Decorator to register a system prompt function.
346
356
 
347
- Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's only argument.
357
+ Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
348
358
  Can decorate a sync or async functions.
349
359
 
350
360
  Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
@@ -352,7 +362,7 @@ class Agent(Generic[AgentDeps, ResultData]):
352
362
 
353
363
  Example:
354
364
  ```py
355
- from pydantic_ai import Agent, CallContext
365
+ from pydantic_ai import Agent, RunContext
356
366
 
357
367
  agent = Agent('test', deps_type=str)
358
368
 
@@ -361,12 +371,12 @@ class Agent(Generic[AgentDeps, ResultData]):
361
371
  return 'foobar'
362
372
 
363
373
  @agent.system_prompt
364
- async def async_system_prompt(ctx: CallContext[str]) -> str:
374
+ async def async_system_prompt(ctx: RunContext[str]) -> str:
365
375
  return f'{ctx.deps} is the best'
366
376
 
367
377
  result = agent.run_sync('foobar', deps='spam')
368
378
  print(result.data)
369
- #> success (no retriever calls)
379
+ #> success (no tool calls)
370
380
  ```
371
381
  """
372
382
  self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
@@ -374,13 +384,13 @@ class Agent(Generic[AgentDeps, ResultData]):
374
384
 
375
385
  @overload
376
386
  def result_validator(
377
- self, func: Callable[[CallContext[AgentDeps], ResultData], ResultData], /
378
- ) -> Callable[[CallContext[AgentDeps], ResultData], ResultData]: ...
387
+ self, func: Callable[[RunContext[AgentDeps], ResultData], ResultData], /
388
+ ) -> Callable[[RunContext[AgentDeps], ResultData], ResultData]: ...
379
389
 
380
390
  @overload
381
391
  def result_validator(
382
- self, func: Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], /
383
- ) -> Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
392
+ self, func: Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], /
393
+ ) -> Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]]: ...
384
394
 
385
395
  @overload
386
396
  def result_validator(self, func: Callable[[ResultData], ResultData], /) -> Callable[[ResultData], ResultData]: ...
@@ -395,7 +405,7 @@ class Agent(Generic[AgentDeps, ResultData]):
395
405
  ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
396
406
  """Decorator to register a result validator function.
397
407
 
398
- Optionally takes [`CallContext`][pydantic_ai.dependencies.CallContext] as it's first argument.
408
+ Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
399
409
  Can decorate a sync or async functions.
400
410
 
401
411
  Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
@@ -403,7 +413,7 @@ class Agent(Generic[AgentDeps, ResultData]):
403
413
 
404
414
  Example:
405
415
  ```py
406
- from pydantic_ai import Agent, CallContext, ModelRetry
416
+ from pydantic_ai import Agent, ModelRetry, RunContext
407
417
 
408
418
  agent = Agent('test', deps_type=str)
409
419
 
@@ -414,61 +424,57 @@ class Agent(Generic[AgentDeps, ResultData]):
414
424
  return data
415
425
 
416
426
  @agent.result_validator
417
- async def result_validator_deps(ctx: CallContext[str], data: str) -> str:
427
+ async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
418
428
  if ctx.deps in data:
419
429
  raise ModelRetry('wrong response')
420
430
  return data
421
431
 
422
432
  result = agent.run_sync('foobar', deps='spam')
423
433
  print(result.data)
424
- #> success (no retriever calls)
434
+ #> success (no tool calls)
425
435
  ```
426
436
  """
427
437
  self._result_validators.append(_result.ResultValidator(func))
428
438
  return func
429
439
 
430
440
  @overload
431
- def retriever(
432
- self, func: RetrieverContextFunc[AgentDeps, RetrieverParams], /
433
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]: ...
441
+ def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
434
442
 
435
443
  @overload
436
- def retriever(
444
+ def tool(
437
445
  self, /, *, retries: int | None = None
438
- ) -> Callable[
439
- [RetrieverContextFunc[AgentDeps, RetrieverParams]], RetrieverContextFunc[AgentDeps, RetrieverParams]
440
- ]: ...
446
+ ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
441
447
 
442
- def retriever(
448
+ def tool(
443
449
  self,
444
- func: RetrieverContextFunc[AgentDeps, RetrieverParams] | None = None,
450
+ func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
445
451
  /,
446
452
  *,
447
453
  retries: int | None = None,
448
454
  ) -> Any:
449
- """Decorator to register a retriever function which takes
450
- [`CallContext`][pydantic_ai.dependencies.CallContext] as its first argument.
455
+ """Decorator to register a tool function which takes
456
+ [`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
451
457
 
452
458
  Can decorate a sync or async functions.
453
459
 
454
460
  The docstring is inspected to extract both the tool description and description of each parameter,
455
- [learn more](../agents.md#retrievers-tools-and-schema).
461
+ [learn more](../agents.md#function-tools-and-schema).
456
462
 
457
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
458
- so the signature of functions decorated with `@agent.retriever` is obscured.
463
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
464
+ so the signature of functions decorated with `@agent.tool` is obscured.
459
465
 
460
466
  Example:
461
467
  ```py
462
- from pydantic_ai import Agent, CallContext
468
+ from pydantic_ai import Agent, RunContext
463
469
 
464
470
  agent = Agent('test', deps_type=int)
465
471
 
466
- @agent.retriever
467
- def foobar(ctx: CallContext[int], x: int) -> int:
472
+ @agent.tool
473
+ def foobar(ctx: RunContext[int], x: int) -> int:
468
474
  return ctx.deps + x
469
475
 
470
- @agent.retriever(retries=2)
471
- async def spam(ctx: CallContext[str], y: float) -> float:
476
+ @agent.tool(retries=2)
477
+ async def spam(ctx: RunContext[str], y: float) -> float:
472
478
  return ctx.deps + y
473
479
 
474
480
  result = agent.run_sync('foobar', deps=1)
@@ -477,58 +483,56 @@ class Agent(Generic[AgentDeps, ResultData]):
477
483
  ```
478
484
 
479
485
  Args:
480
- func: The retriever function to register.
481
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
486
+ func: The tool function to register.
487
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
482
488
  which defaults to 1.
483
489
  """ # noqa: D205
484
490
  if func is None:
485
491
 
486
- def retriever_decorator(
487
- func_: RetrieverContextFunc[AgentDeps, RetrieverParams],
488
- ) -> RetrieverContextFunc[AgentDeps, RetrieverParams]:
492
+ def tool_decorator(
493
+ func_: ToolContextFunc[AgentDeps, ToolParams],
494
+ ) -> ToolContextFunc[AgentDeps, ToolParams]:
489
495
  # noinspection PyTypeChecker
490
- self._register_retriever(_utils.Either(left=func_), retries)
496
+ self._register_tool(_utils.Either(left=func_), retries)
491
497
  return func_
492
498
 
493
- return retriever_decorator
499
+ return tool_decorator
494
500
  else:
495
501
  # noinspection PyTypeChecker
496
- self._register_retriever(_utils.Either(left=func), retries)
502
+ self._register_tool(_utils.Either(left=func), retries)
497
503
  return func
498
504
 
499
505
  @overload
500
- def retriever_plain(self, func: RetrieverPlainFunc[RetrieverParams], /) -> RetrieverPlainFunc[RetrieverParams]: ...
506
+ def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
501
507
 
502
508
  @overload
503
- def retriever_plain(
509
+ def tool_plain(
504
510
  self, /, *, retries: int | None = None
505
- ) -> Callable[[RetrieverPlainFunc[RetrieverParams]], RetrieverPlainFunc[RetrieverParams]]: ...
511
+ ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
506
512
 
507
- def retriever_plain(
508
- self, func: RetrieverPlainFunc[RetrieverParams] | None = None, /, *, retries: int | None = None
509
- ) -> Any:
510
- """Decorator to register a retriever function which DOES NOT take `CallContext` as an argument.
513
+ def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
514
+ """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
511
515
 
512
516
  Can decorate a sync or async functions.
513
517
 
514
518
  The docstring is inspected to extract both the tool description and description of each parameter,
515
- [learn more](../agents.md#retrievers-tools-and-schema).
519
+ [learn more](../agents.md#function-tools-and-schema).
516
520
 
517
- We can't add overloads for every possible signature of retriever, since the return type is a recursive union
518
- so the signature of functions decorated with `@agent.retriever` is obscured.
521
+ We can't add overloads for every possible signature of tool, since the return type is a recursive union
522
+ so the signature of functions decorated with `@agent.tool` is obscured.
519
523
 
520
524
  Example:
521
525
  ```py
522
- from pydantic_ai import Agent, CallContext
526
+ from pydantic_ai import Agent, RunContext
523
527
 
524
528
  agent = Agent('test')
525
529
 
526
- @agent.retriever
527
- def foobar(ctx: CallContext[int]) -> int:
530
+ @agent.tool
531
+ def foobar(ctx: RunContext[int]) -> int:
528
532
  return 123
529
533
 
530
- @agent.retriever(retries=2)
531
- async def spam(ctx: CallContext[str]) -> float:
534
+ @agent.tool(retries=2)
535
+ async def spam(ctx: RunContext[str]) -> float:
532
536
  return 3.14
533
537
 
534
538
  result = agent.run_sync('foobar', deps=1)
@@ -537,38 +541,36 @@ class Agent(Generic[AgentDeps, ResultData]):
537
541
  ```
538
542
 
539
543
  Args:
540
- func: The retriever function to register.
541
- retries: The number of retries to allow for this retriever, defaults to the agent's default retries,
544
+ func: The tool function to register.
545
+ retries: The number of retries to allow for this tool, defaults to the agent's default retries,
542
546
  which defaults to 1.
543
547
  """
544
548
  if func is None:
545
549
 
546
- def retriever_decorator(
547
- func_: RetrieverPlainFunc[RetrieverParams],
548
- ) -> RetrieverPlainFunc[RetrieverParams]:
550
+ def tool_decorator(
551
+ func_: ToolPlainFunc[ToolParams],
552
+ ) -> ToolPlainFunc[ToolParams]:
549
553
  # noinspection PyTypeChecker
550
- self._register_retriever(_utils.Either(right=func_), retries)
554
+ self._register_tool(_utils.Either(right=func_), retries)
551
555
  return func_
552
556
 
553
- return retriever_decorator
557
+ return tool_decorator
554
558
  else:
555
- self._register_retriever(_utils.Either(right=func), retries)
559
+ self._register_tool(_utils.Either(right=func), retries)
556
560
  return func
557
561
 
558
- def _register_retriever(
559
- self, func: _r.RetrieverEitherFunc[AgentDeps, RetrieverParams], retries: int | None
560
- ) -> None:
561
- """Private utility to register a retriever function."""
562
+ def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
563
+ """Private utility to register a tool function."""
562
564
  retries_ = retries if retries is not None else self._default_retries
563
- retriever = _r.Retriever[AgentDeps, RetrieverParams](func, retries_)
565
+ tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
564
566
 
565
- if self._result_schema and retriever.name in self._result_schema.tools:
566
- raise ValueError(f'Retriever name conflicts with result schema name: {retriever.name!r}')
567
+ if self._result_schema and tool.name in self._result_schema.tools:
568
+ raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
567
569
 
568
- if retriever.name in self._retrievers:
569
- raise ValueError(f'Retriever name conflicts with existing retriever: {retriever.name!r}')
570
+ if tool.name in self._function_tools:
571
+ raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
570
572
 
571
- self._retrievers[retriever.name] = retriever
573
+ self._function_tools[tool.name] = tool
572
574
 
573
575
  async def _get_agent_model(
574
576
  self, model: models.Model | models.KnownModelName | None
@@ -583,11 +585,11 @@ class Agent(Generic[AgentDeps, ResultData]):
583
585
  """
584
586
  model_: models.Model
585
587
  if some_model := self._override_model:
586
- # we don't want `override_model()` to cover up errors from the model not being defined, hence this check
588
+ # we don't want `override()` to cover up errors from the model not being defined, hence this check
587
589
  if model is None and self.model is None:
588
590
  raise exceptions.UserError(
589
591
  '`model` must be set either when creating the agent or when calling it. '
590
- '(Even when `override_model()` is customizing the model that will actually be called)'
592
+ '(Even when `override(model=...)` is customizing the model that will actually be called)'
591
593
  )
592
594
  model_ = some_model.value
593
595
  custom_model = None
@@ -601,7 +603,7 @@ class Agent(Generic[AgentDeps, ResultData]):
601
603
  raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
602
604
 
603
605
  result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
604
- agent_model = await model_.agent_model(self._retrievers, self._allow_text_result, result_tools)
606
+ agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
605
607
  return model_, custom_model, agent_model
606
608
 
607
609
  async def _prepare_messages(
@@ -663,12 +665,12 @@ class Agent(Generic[AgentDeps, ResultData]):
663
665
  if not model_response.calls:
664
666
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
665
667
 
666
- # otherwise we run all retriever functions in parallel
668
+ # otherwise we run all tool functions in parallel
667
669
  messages: list[_messages.Message] = []
668
670
  tasks: list[asyncio.Task[_messages.Message]] = []
669
671
  for call in model_response.calls:
670
- if retriever := self._retrievers.get(call.tool_name):
671
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
672
+ if tool := self._function_tools.get(call.tool_name):
673
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
672
674
  else:
673
675
  messages.append(self._unknown_tool(call.tool_name))
674
676
 
@@ -719,7 +721,7 @@ class Agent(Generic[AgentDeps, ResultData]):
719
721
  if self._result_schema.find_tool(structured_msg):
720
722
  return _MarkFinalResult(model_response)
721
723
 
722
- # the model is calling a retriever function, consume the response to get the next message
724
+ # the model is calling a tool function, consume the response to get the next message
723
725
  async for _ in model_response:
724
726
  pass
725
727
  structured_msg = model_response.get()
@@ -727,11 +729,11 @@ class Agent(Generic[AgentDeps, ResultData]):
727
729
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
728
730
  messages: list[_messages.Message] = [structured_msg]
729
731
 
730
- # we now run all retriever functions in parallel
732
+ # we now run all tool functions in parallel
731
733
  tasks: list[asyncio.Task[_messages.Message]] = []
732
734
  for call in structured_msg.calls:
733
- if retriever := self._retrievers.get(call.tool_name):
734
- tasks.append(asyncio.create_task(retriever.run(deps, call), name=call.tool_name))
735
+ if tool := self._function_tools.get(call.tool_name):
736
+ tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
735
737
  else:
736
738
  messages.append(self._unknown_tool(call.tool_name))
737
739
 
@@ -763,7 +765,7 @@ class Agent(Generic[AgentDeps, ResultData]):
763
765
 
764
766
  def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
765
767
  self._incr_result_retry()
766
- names = list(self._retrievers.keys())
768
+ names = list(self._function_tools.keys())
767
769
  if self._result_schema:
768
770
  names.extend(self._result_schema.tool_names())
769
771
  if names:
@@ -775,7 +777,7 @@ class Agent(Generic[AgentDeps, ResultData]):
775
777
  def _get_deps(self, deps: AgentDeps) -> AgentDeps:
776
778
  """Get deps for a run.
777
779
 
778
- If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call.
780
+ If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
779
781
 
780
782
  We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
781
783
  """
@@ -13,13 +13,13 @@ else:
13
13
 
14
14
  __all__ = (
15
15
  'AgentDeps',
16
- 'CallContext',
16
+ 'RunContext',
17
17
  'ResultValidatorFunc',
18
18
  'SystemPromptFunc',
19
- 'RetrieverReturnValue',
20
- 'RetrieverContextFunc',
21
- 'RetrieverPlainFunc',
22
- 'RetrieverParams',
19
+ 'ToolReturnValue',
20
+ 'ToolContextFunc',
21
+ 'ToolPlainFunc',
22
+ 'ToolParams',
23
23
  'JsonData',
24
24
  )
25
25
 
@@ -28,7 +28,7 @@ AgentDeps = TypeVar('AgentDeps')
28
28
 
29
29
 
30
30
  @dataclass
31
- class CallContext(Generic[AgentDeps]):
31
+ class RunContext(Generic[AgentDeps]):
32
32
  """Information about the current call."""
33
33
 
34
34
  deps: AgentDeps
@@ -39,23 +39,23 @@ class CallContext(Generic[AgentDeps]):
39
39
  """Name of the tool being called."""
40
40
 
41
41
 
42
- RetrieverParams = ParamSpec('RetrieverParams')
42
+ ToolParams = ParamSpec('ToolParams')
43
43
  """Retrieval function param spec."""
44
44
 
45
45
  SystemPromptFunc = Union[
46
- Callable[[CallContext[AgentDeps]], str],
47
- Callable[[CallContext[AgentDeps]], Awaitable[str]],
46
+ Callable[[RunContext[AgentDeps]], str],
47
+ Callable[[RunContext[AgentDeps]], Awaitable[str]],
48
48
  Callable[[], str],
49
49
  Callable[[], Awaitable[str]],
50
50
  ]
51
- """A function that may or maybe not take `CallContext` as an argument, and may or may not be async.
51
+ """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
52
52
 
53
53
  Usage `SystemPromptFunc[AgentDeps]`.
54
54
  """
55
55
 
56
56
  ResultValidatorFunc = Union[
57
- Callable[[CallContext[AgentDeps], ResultData], ResultData],
58
- Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]],
57
+ Callable[[RunContext[AgentDeps], ResultData], ResultData],
58
+ Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
59
59
  Callable[[ResultData], ResultData],
60
60
  Callable[[ResultData], Awaitable[ResultData]],
61
61
  ]
@@ -69,15 +69,15 @@ Usage `ResultValidator[AgentDeps, ResultData]`.
69
69
  JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
70
70
  """Type representing any JSON data."""
71
71
 
72
- RetrieverReturnValue = Union[JsonData, Awaitable[JsonData]]
73
- """Return value of a retriever function."""
74
- RetrieverContextFunc = Callable[Concatenate[CallContext[AgentDeps], RetrieverParams], RetrieverReturnValue]
75
- """A retriever function that takes `CallContext` as the first argument.
72
+ ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
73
+ """Return value of a tool function."""
74
+ ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
75
+ """A tool function that takes `RunContext` as the first argument.
76
76
 
77
- Usage `RetrieverContextFunc[AgentDeps, RetrieverParams]`.
77
+ Usage `ToolContextFunc[AgentDeps, ToolParams]`.
78
78
  """
79
- RetrieverPlainFunc = Callable[RetrieverParams, RetrieverReturnValue]
80
- """A retriever function that does not take `CallContext` as the first argument.
79
+ ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
80
+ """A tool function that does not take `RunContext` as the first argument.
81
81
 
82
- Usage `RetrieverPlainFunc[RetrieverParams]`.
82
+ Usage `ToolPlainFunc[ToolParams]`.
83
83
  """
pydantic_ai/exceptions.py CHANGED
@@ -6,7 +6,7 @@ __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
6
6
 
7
7
 
8
8
  class ModelRetry(Exception):
9
- """Exception raised when a retriever function should be retried.
9
+ """Exception raised when a tool function should be retried.
10
10
 
11
11
  The agent will return the message to the model and ask it to try calling the function/tool again.
12
12
  """