pydantic-ai-slim 0.0.8__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.

Files changed (25)
  1. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/.gitignore +1 -0
  2. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/PKG-INFO +1 -1
  3. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/__init__.py +2 -2
  4. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/_pydantic.py +27 -11
  5. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/_result.py +1 -1
  6. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/_system_prompt.py +1 -1
  7. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/agent.py +44 -32
  8. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/messages.py +7 -16
  9. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/__init__.py +21 -11
  10. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/result.py +1 -1
  11. pydantic_ai_slim-0.0.9/pydantic_ai/tools.py +240 -0
  12. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pyproject.toml +1 -1
  13. pydantic_ai_slim-0.0.8/pydantic_ai/_tool.py +0 -112
  14. pydantic_ai_slim-0.0.8/pydantic_ai/dependencies.py +0 -83
  15. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/README.md +0 -0
  16. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/_griffe.py +0 -0
  17. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/_utils.py +0 -0
  18. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/exceptions.py +0 -0
  19. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/function.py +0 -0
  20. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/gemini.py +0 -0
  21. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/groq.py +0 -0
  22. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/openai.py +0 -0
  23. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/test.py +0 -0
  24. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/models/vertexai.py +0 -0
  25. {pydantic_ai_slim-0.0.8 → pydantic_ai_slim-0.0.9}/pydantic_ai/py.typed +0 -0
@@ -13,3 +13,4 @@ env*/
13
13
  /pydantic_ai_examples/.chat_app_messages.jsonl
14
14
  .cache/
15
15
  .docs-insiders-install
16
+ .vscode/
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License: MIT
@@ -1,8 +1,8 @@
1
1
  from importlib.metadata import version
2
2
 
3
3
  from .agent import Agent
4
- from .dependencies import RunContext
5
4
  from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
5
+ from .tools import RunContext, Tool
6
6
 
7
- __all__ = 'Agent', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
7
+ __all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
8
8
  __version__ = version('pydantic_ai_slim')
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
6
6
  from __future__ import annotations as _annotations
7
7
 
8
8
  from inspect import Parameter, signature
9
- from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
9
+ from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
10
10
 
11
11
  from pydantic import ConfigDict, TypeAdapter
12
12
  from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -20,8 +20,7 @@ from ._griffe import doc_descriptions
20
20
  from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
21
21
 
22
22
  if TYPE_CHECKING:
23
- from . import _tool
24
- from .dependencies import AgentDeps, ToolParams
23
+ pass
25
24
 
26
25
 
27
26
  __all__ = 'function_schema', 'LazyTypeAdapter'
@@ -39,17 +38,16 @@ class FunctionSchema(TypedDict):
39
38
  var_positional_field: str | None
40
39
 
41
40
 
42
- def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]) -> FunctionSchema: # noqa: C901
41
+ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901
43
42
  """Build a Pydantic validator and JSON schema from a tool function.
44
43
 
45
44
  Args:
46
- either_function: The function to build a validator and JSON schema for.
45
+ function: The function to build a validator and JSON schema for.
46
+ takes_ctx: Whether the function takes a `RunContext` first argument.
47
47
 
48
48
  Returns:
49
49
  A `FunctionSchema` instance.
50
50
  """
51
- function = either_function.whichever()
52
- takes_ctx = either_function.is_left()
53
51
  config = ConfigDict(title=function.__name__)
54
52
  config_wrapper = ConfigWrapper(config)
55
53
  gen_schema = _generate_schema.GenerateSchema(config_wrapper)
@@ -78,13 +76,13 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
78
76
 
79
77
  if index == 0 and takes_ctx:
80
78
  if not _is_call_ctx(annotation):
81
- errors.append('First argument must be a RunContext instance when using `.tool`')
79
+ errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
82
80
  continue
83
81
  elif not takes_ctx and _is_call_ctx(annotation):
84
- errors.append('RunContext instance can only be used with `.tool`')
82
+ errors.append('RunContext annotations can only be used with tools that take context')
85
83
  continue
86
84
  elif index != 0 and _is_call_ctx(annotation):
87
- errors.append('RunContext instance can only be used as the first argument')
85
+ errors.append('RunContext annotations can only be used as the first argument')
88
86
  continue
89
87
 
90
88
  field_name = p.name
@@ -159,6 +157,24 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
159
157
  )
160
158
 
161
159
 
160
+ def takes_ctx(function: Callable[..., Any]) -> bool:
161
+ """Check if a function takes a `RunContext` first argument.
162
+
163
+ Args:
164
+ function: The function to check.
165
+
166
+ Returns:
167
+ `True` if the function takes a `RunContext` as first argument, `False` otherwise.
168
+ """
169
+ sig = signature(function)
170
+ try:
171
+ _, first_param = next(iter(sig.parameters.items()))
172
+ except StopIteration:
173
+ return False
174
+ else:
175
+ return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
176
+
177
+
162
178
  def _build_schema(
163
179
  fields: dict[str, core_schema.TypedDictField],
164
180
  var_kwargs_schema: core_schema.CoreSchema | None,
@@ -191,7 +207,7 @@ def _build_schema(
191
207
 
192
208
 
193
209
  def _is_call_ctx(annotation: Any) -> bool:
194
- from .dependencies import RunContext
210
+ from .tools import RunContext
195
211
 
196
212
  return annotation is RunContext or (
197
213
  _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
@@ -11,10 +11,10 @@ from pydantic import TypeAdapter, ValidationError
11
11
  from typing_extensions import Self, TypeAliasType, TypedDict
12
12
 
13
13
  from . import _utils, messages
14
- from .dependencies import AgentDeps, ResultValidatorFunc, RunContext
15
14
  from .exceptions import ModelRetry
16
15
  from .messages import ModelStructuredResponse, ToolCall
17
16
  from .result import ResultData
17
+ from .tools import AgentDeps, ResultValidatorFunc, RunContext
18
18
 
19
19
 
20
20
  @dataclass
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
6
6
  from typing import Any, Callable, Generic, cast
7
7
 
8
8
  from . import _utils
9
- from .dependencies import AgentDeps, RunContext, SystemPromptFunc
9
+ from .tools import AgentDeps, RunContext, SystemPromptFunc
10
10
 
11
11
 
12
12
  @dataclass
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import asyncio
4
+ import dataclasses
4
5
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
5
6
  from contextlib import asynccontextmanager, contextmanager
6
7
  from dataclasses import dataclass, field
@@ -12,15 +13,14 @@ from typing_extensions import assert_never
12
13
  from . import (
13
14
  _result,
14
15
  _system_prompt,
15
- _tool as _r,
16
16
  _utils,
17
17
  exceptions,
18
18
  messages as _messages,
19
19
  models,
20
20
  result,
21
21
  )
22
- from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
23
22
  from .result import ResultData
23
+ from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
24
24
 
25
25
  __all__ = ('Agent',)
26
26
 
@@ -34,7 +34,7 @@ NoneType = type(None)
34
34
  class Agent(Generic[AgentDeps, ResultData]):
35
35
  """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
36
36
 
37
- Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps]
37
+ Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
38
38
  and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
39
39
 
40
40
  By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
58
58
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
59
59
  _allow_text_result: bool = field(repr=False)
60
60
  _system_prompts: tuple[str, ...] = field(repr=False)
61
- _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
61
+ _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
62
62
  _default_retries: int = field(repr=False)
63
63
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
64
64
  _deps_type: type[AgentDeps] = field(repr=False)
@@ -83,6 +83,7 @@ class Agent(Generic[AgentDeps, ResultData]):
83
83
  result_tool_name: str = 'final_result',
84
84
  result_tool_description: str | None = None,
85
85
  result_retries: int | None = None,
86
+ tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
86
87
  defer_model_check: bool = False,
87
88
  ):
88
89
  """Create an agent.
@@ -101,6 +102,8 @@ class Agent(Generic[AgentDeps, ResultData]):
101
102
  result_tool_name: The name of the tool to use for the final result.
102
103
  result_tool_description: The description of the final result tool.
103
104
  result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
105
+ tools: Tools to register with the agent, you can also register tools via the decorators
106
+ [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
104
107
  defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
105
108
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
106
109
  which checks for the necessary environment variables. Set this to `false`
@@ -119,9 +122,11 @@ class Agent(Generic[AgentDeps, ResultData]):
119
122
  self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
120
123
 
121
124
  self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
122
- self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
123
- self._deps_type = deps_type
125
+ self._function_tools = {}
124
126
  self._default_retries = retries
127
+ for tool in tools:
128
+ self._register_tool(Tool.infer(tool))
129
+ self._deps_type = deps_type
125
130
  self._system_prompt_functions = []
126
131
  self._max_result_retries = result_retries if result_retries is not None else retries
127
132
  self._current_result_retry = 0
@@ -206,7 +211,7 @@ class Agent(Generic[AgentDeps, ResultData]):
206
211
  ) -> result.RunResult[ResultData]:
207
212
  """Run the agent with a user prompt synchronously.
208
213
 
209
- This is a convenience method that wraps `self.run` with `asyncio.run()`.
214
+ This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
210
215
 
211
216
  Args:
212
217
  user_prompt: User input to start/continue the conversation.
@@ -217,7 +222,8 @@ class Agent(Generic[AgentDeps, ResultData]):
217
222
  Returns:
218
223
  The result of the run.
219
224
  """
220
- return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
225
+ loop = asyncio.get_event_loop()
226
+ return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
221
227
 
222
228
  @asynccontextmanager
223
229
  async def run_stream(
@@ -354,7 +360,7 @@ class Agent(Generic[AgentDeps, ResultData]):
354
360
  ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
355
361
  """Decorator to register a system prompt function.
356
362
 
357
- Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
363
+ Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's only argument.
358
364
  Can decorate a sync or async functions.
359
365
 
360
366
  Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
@@ -405,7 +411,7 @@ class Agent(Generic[AgentDeps, ResultData]):
405
411
  ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
406
412
  """Decorator to register a result validator function.
407
413
 
408
- Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
414
+ Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's first argument.
409
415
  Can decorate a sync or async functions.
410
416
 
411
417
  Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
@@ -438,22 +444,22 @@ class Agent(Generic[AgentDeps, ResultData]):
438
444
  return func
439
445
 
440
446
  @overload
441
- def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
447
+ def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
442
448
 
443
449
  @overload
444
450
  def tool(
445
451
  self, /, *, retries: int | None = None
446
- ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
452
+ ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
447
453
 
448
454
  def tool(
449
455
  self,
450
- func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
456
+ func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
451
457
  /,
452
458
  *,
453
459
  retries: int | None = None,
454
460
  ) -> Any:
455
461
  """Decorator to register a tool function which takes
456
- [`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
462
+ [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
457
463
 
458
464
  Can decorate a sync or async functions.
459
465
 
@@ -490,27 +496,27 @@ class Agent(Generic[AgentDeps, ResultData]):
490
496
  if func is None:
491
497
 
492
498
  def tool_decorator(
493
- func_: ToolContextFunc[AgentDeps, ToolParams],
494
- ) -> ToolContextFunc[AgentDeps, ToolParams]:
499
+ func_: ToolFuncContext[AgentDeps, ToolParams],
500
+ ) -> ToolFuncContext[AgentDeps, ToolParams]:
495
501
  # noinspection PyTypeChecker
496
- self._register_tool(_utils.Either(left=func_), retries)
502
+ self._register_function(func_, True, retries)
497
503
  return func_
498
504
 
499
505
  return tool_decorator
500
506
  else:
501
507
  # noinspection PyTypeChecker
502
- self._register_tool(_utils.Either(left=func), retries)
508
+ self._register_function(func, True, retries)
503
509
  return func
504
510
 
505
511
  @overload
506
- def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
512
+ def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
507
513
 
508
514
  @overload
509
515
  def tool_plain(
510
516
  self, /, *, retries: int | None = None
511
- ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
517
+ ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
512
518
 
513
- def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
519
+ def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
514
520
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
515
521
 
516
522
  Can decorate a sync or async functions.
@@ -547,28 +553,34 @@ class Agent(Generic[AgentDeps, ResultData]):
547
553
  """
548
554
  if func is None:
549
555
 
550
- def tool_decorator(
551
- func_: ToolPlainFunc[ToolParams],
552
- ) -> ToolPlainFunc[ToolParams]:
556
+ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
553
557
  # noinspection PyTypeChecker
554
- self._register_tool(_utils.Either(right=func_), retries)
558
+ self._register_function(func_, False, retries)
555
559
  return func_
556
560
 
557
561
  return tool_decorator
558
562
  else:
559
- self._register_tool(_utils.Either(right=func), retries)
563
+ self._register_function(func, False, retries)
560
564
  return func
561
565
 
562
- def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
563
- """Private utility to register a tool function."""
566
+ def _register_function(
567
+ self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
568
+ ) -> None:
569
+ """Private utility to register a function as a tool."""
564
570
  retries_ = retries if retries is not None else self._default_retries
565
- tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
571
+ tool = Tool(func, takes_ctx, max_retries=retries_)
572
+ self._register_tool(tool)
566
573
 
567
- if self._result_schema and tool.name in self._result_schema.tools:
568
- raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
574
+ def _register_tool(self, tool: Tool[AgentDeps]) -> None:
575
+ """Private utility to register a tool instance."""
576
+ if tool.max_retries is None:
577
+ tool = dataclasses.replace(tool, max_retries=self._default_retries)
569
578
 
570
579
  if tool.name in self._function_tools:
571
- raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
580
+ raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
581
+
582
+ if self._result_schema and tool.name in self._result_schema.tools:
583
+ raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
572
584
 
573
585
  self._function_tools[tool.name] = tool
574
586
 
@@ -1,15 +1,13 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  import json
4
- from collections.abc import Mapping, Sequence
5
4
  from dataclasses import dataclass, field
6
5
  from datetime import datetime
7
- from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
6
+ from typing import Annotated, Any, Literal, Union
8
7
 
9
8
  import pydantic
10
9
  import pydantic_core
11
10
  from pydantic import TypeAdapter
12
- from typing_extensions import TypeAlias, TypeAliasType
13
11
 
14
12
  from . import _pydantic
15
13
  from ._utils import now_utc as _now_utc
@@ -44,13 +42,7 @@ class UserPrompt:
44
42
  """Message type identifier, this type is available on all message as a discriminator."""
45
43
 
46
44
 
47
- JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]'
48
- if not TYPE_CHECKING:
49
- # work around for https://github.com/pydantic/pydantic/issues/10873
50
- # this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
51
- JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
52
-
53
- json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
45
+ tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
54
46
 
55
47
 
56
48
  @dataclass
@@ -59,7 +51,7 @@ class ToolReturn:
59
51
 
60
52
  tool_name: str
61
53
  """The name of the "tool" was called."""
62
- content: JsonData
54
+ content: Any
63
55
  """The return value."""
64
56
  tool_id: str | None = None
65
57
  """Optional tool identifier, this is used by some models including OpenAI."""
@@ -72,15 +64,14 @@ class ToolReturn:
72
64
  if isinstance(self.content, str):
73
65
  return self.content
74
66
  else:
75
- content = json_ta.validate_python(self.content)
76
- return json_ta.dump_json(content).decode()
67
+ return tool_return_ta.dump_json(self.content).decode()
77
68
 
78
- def model_response_object(self) -> dict[str, JsonData]:
69
+ def model_response_object(self) -> dict[str, Any]:
79
70
  # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
80
71
  if isinstance(self.content, dict):
81
- return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType]
72
+ return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
82
73
  else:
83
- return {'return_value': json_ta.validate_python(self.content)}
74
+ return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
84
75
 
85
76
 
86
77
  @dataclass
@@ -259,20 +259,30 @@ def infer_model(model: Model | KnownModelName) -> Model:
259
259
  class AbstractToolDefinition(Protocol):
260
260
  """Abstract definition of a function/tool.
261
261
 
262
- This is used for both tools and result tools.
262
+ This is used for both function tools and result tools.
263
263
  """
264
264
 
265
- name: str
266
- """The name of the tool."""
267
- description: str
268
- """The description of the tool."""
269
- json_schema: ObjectJsonSchema
270
- """The JSON schema for the tool's arguments."""
271
- outer_typed_dict_key: str | None
272
- """The key in the outer [TypedDict] that wraps a result tool.
265
+ @property
266
+ def name(self) -> str:
267
+ """The name of the tool."""
268
+ ...
273
269
 
274
- This will only be set for result tools which don't have an `object` JSON schema.
275
- """
270
+ @property
271
+ def description(self) -> str:
272
+ """The description of the tool."""
273
+ ...
274
+
275
+ @property
276
+ def json_schema(self) -> ObjectJsonSchema:
277
+ """The JSON schema for the tool's arguments."""
278
+ ...
279
+
280
+ @property
281
+ def outer_typed_dict_key(self) -> str | None:
282
+ """The key in the outer [TypedDict] that wraps a result tool.
283
+
284
+ This will only be set for result tools which don't have an `object` JSON schema.
285
+ """
276
286
 
277
287
 
278
288
  @cache
@@ -9,7 +9,7 @@ from typing import Generic, TypeVar, cast
9
9
  import logfire_api
10
10
 
11
11
  from . import _result, _utils, exceptions, messages, models
12
- from .dependencies import AgentDeps
12
+ from .tools import AgentDeps
13
13
 
14
14
  __all__ = (
15
15
  'ResultData',
@@ -0,0 +1,240 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import inspect
4
+ from collections.abc import Awaitable
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
7
+
8
+ from pydantic import ValidationError
9
+ from pydantic_core import SchemaValidator
10
+ from typing_extensions import Concatenate, ParamSpec, final
11
+
12
+ from . import _pydantic, _utils, messages
13
+ from .exceptions import ModelRetry, UnexpectedModelBehavior
14
+
15
+ if TYPE_CHECKING:
16
+ from .result import ResultData
17
+ else:
18
+ ResultData = Any
19
+
20
+
21
+ __all__ = (
22
+ 'AgentDeps',
23
+ 'RunContext',
24
+ 'ResultValidatorFunc',
25
+ 'SystemPromptFunc',
26
+ 'ToolFuncContext',
27
+ 'ToolFuncPlain',
28
+ 'ToolFuncEither',
29
+ 'ToolParams',
30
+ 'Tool',
31
+ )
32
+
33
+ AgentDeps = TypeVar('AgentDeps')
34
+ """Type variable for agent dependencies."""
35
+
36
+
37
+ @dataclass
38
+ class RunContext(Generic[AgentDeps]):
39
+ """Information about the current call."""
40
+
41
+ deps: AgentDeps
42
+ """Dependencies for the agent."""
43
+ retry: int
44
+ """Number of retries so far."""
45
+ tool_name: str | None
46
+ """Name of the tool being called."""
47
+
48
+
49
+ ToolParams = ParamSpec('ToolParams')
50
+ """Retrieval function param spec."""
51
+
52
+ SystemPromptFunc = Union[
53
+ Callable[[RunContext[AgentDeps]], str],
54
+ Callable[[RunContext[AgentDeps]], Awaitable[str]],
55
+ Callable[[], str],
56
+ Callable[[], Awaitable[str]],
57
+ ]
58
+ """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
59
+
60
+ Usage `SystemPromptFunc[AgentDeps]`.
61
+ """
62
+
63
+ ResultValidatorFunc = Union[
64
+ Callable[[RunContext[AgentDeps], ResultData], ResultData],
65
+ Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
66
+ Callable[[ResultData], ResultData],
67
+ Callable[[ResultData], Awaitable[ResultData]],
68
+ ]
69
+ """
70
+ A function that always takes `ResultData` and returns `ResultData`,
71
+ but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
72
+
73
+ Usage `ResultValidator[AgentDeps, ResultData]`.
74
+ """
75
+
76
+ ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
77
+ """A tool function that takes `RunContext` as the first argument.
78
+
79
+ Usage `ToolContextFunc[AgentDeps, ToolParams]`.
80
+ """
81
+ ToolFuncPlain = Callable[ToolParams, Any]
82
+ """A tool function that does not take `RunContext` as the first argument.
83
+
84
+ Usage `ToolPlainFunc[ToolParams]`.
85
+ """
86
+ ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
87
+ """Either kind of tool function.
88
+
89
+ This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
90
+ [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
91
+
92
+ Usage `ToolFuncEither[AgentDeps, ToolParams]`.
93
+ """
94
+
95
+ A = TypeVar('A')
96
+
97
+
98
+ @final
99
+ @dataclass(init=False)
100
+ class Tool(Generic[AgentDeps]):
101
+ """A tool function for an agent."""
102
+
103
+ function: ToolFuncEither[AgentDeps, ...]
104
+ takes_ctx: bool
105
+ max_retries: int | None
106
+ name: str
107
+ description: str
108
+ _is_async: bool = field(init=False)
109
+ _single_arg_name: str | None = field(init=False)
110
+ _positional_fields: list[str] = field(init=False)
111
+ _var_positional_field: str | None = field(init=False)
112
+ _validator: SchemaValidator = field(init=False, repr=False)
113
+ _json_schema: _utils.ObjectJsonSchema = field(init=False)
114
+ _current_retry: int = field(default=0, init=False)
115
+
116
+ def __init__(
117
+ self,
118
+ function: ToolFuncEither[AgentDeps, ...],
119
+ takes_ctx: bool,
120
+ *,
121
+ max_retries: int | None = None,
122
+ name: str | None = None,
123
+ description: str | None = None,
124
+ ):
125
+ """Create a new tool instance.
126
+
127
+ Example usage:
128
+
129
+ ```py
130
+ from pydantic_ai import Agent, RunContext, Tool
131
+
132
+ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
133
+ return f'{ctx.deps} {x} {y}'
134
+
135
+ agent = Agent('test', tools=[Tool(my_tool, True)])
136
+ ```
137
+
138
+ Args:
139
+ function: The Python function to call as the tool.
140
+ takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
141
+ max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
142
+ name: Name of the tool, inferred from the function if `None`.
143
+ description: Description of the tool, inferred from the function if `None`.
144
+ """
145
+ f = _pydantic.function_schema(function, takes_ctx)
146
+ self.function = function
147
+ self.takes_ctx = takes_ctx
148
+ self.max_retries = max_retries
149
+ self.name = name or function.__name__
150
+ self.description = description or f['description']
151
+ self._is_async = inspect.iscoroutinefunction(self.function)
152
+ self._single_arg_name = f['single_arg_name']
153
+ self._positional_fields = f['positional_fields']
154
+ self._var_positional_field = f['var_positional_field']
155
+ self._validator = f['validator']
156
+ self._json_schema = f['json_schema']
157
+
158
+ @staticmethod
159
+ def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
160
+ """Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.
161
+
162
+ Args:
163
+ function: The tool function to wrap; or for convenience, a `Tool` instance.
164
+
165
+ Returns:
166
+ A new `Tool` instance.
167
+ """
168
+ if isinstance(function, Tool):
169
+ return function
170
+ else:
171
+ return Tool(function, takes_ctx=_pydantic.takes_ctx(function))
172
+
173
+ def reset(self) -> None:
174
+ """Reset the current retry count."""
175
+ self._current_retry = 0
176
+
177
+ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
178
+ """Run the tool function asynchronously."""
179
+ try:
180
+ if isinstance(message.args, messages.ArgsJson):
181
+ args_dict = self._validator.validate_json(message.args.args_json)
182
+ else:
183
+ args_dict = self._validator.validate_python(message.args.args_dict)
184
+ except ValidationError as e:
185
+ return self._on_error(e, message)
186
+
187
+ args, kwargs = self._call_args(deps, args_dict, message)
188
+ try:
189
+ if self._is_async:
190
+ function = cast(Callable[[Any], Awaitable[str]], self.function)
191
+ response_content = await function(*args, **kwargs)
192
+ else:
193
+ function = cast(Callable[[Any], str], self.function)
194
+ response_content = await _utils.run_in_executor(function, *args, **kwargs)
195
+ except ModelRetry as e:
196
+ return self._on_error(e, message)
197
+
198
+ self._current_retry = 0
199
+ return messages.ToolReturn(
200
+ tool_name=message.tool_name,
201
+ content=response_content,
202
+ tool_id=message.tool_id,
203
+ )
204
+
205
+ @property
206
+ def json_schema(self) -> _utils.ObjectJsonSchema:
207
+ return self._json_schema
208
+
209
+ @property
210
+ def outer_typed_dict_key(self) -> str | None:
211
+ return None
212
+
213
+ def _call_args(
214
+ self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
215
+ ) -> tuple[list[Any], dict[str, Any]]:
216
+ if self._single_arg_name:
217
+ args_dict = {self._single_arg_name: args_dict}
218
+
219
+ args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else []
220
+ for positional_field in self._positional_fields:
221
+ args.append(args_dict.pop(positional_field))
222
+ if self._var_positional_field:
223
+ args.extend(args_dict.pop(self._var_positional_field))
224
+
225
+ return args, args_dict
226
+
227
+ def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
228
+ self._current_retry += 1
229
+ if self.max_retries is None or self._current_retry > self.max_retries:
230
+ raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
231
+ else:
232
+ if isinstance(exc, ValidationError):
233
+ content = exc.errors(include_url=False)
234
+ else:
235
+ content = exc.message
236
+ return messages.RetryPrompt(
237
+ tool_name=call_message.tool_name,
238
+ content=content,
239
+ tool_id=call_message.tool_id,
240
+ )
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "pydantic-ai-slim"
7
- version = "0.0.8"
7
+ version = "0.0.9"
8
8
  description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
9
9
  authors = [
10
10
  { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -1,112 +0,0 @@
1
- from __future__ import annotations as _annotations
2
-
3
- import inspect
4
- from collections.abc import Awaitable
5
- from dataclasses import dataclass, field
6
- from typing import Any, Callable, Generic, cast
7
-
8
- from pydantic import ValidationError
9
- from pydantic_core import SchemaValidator
10
-
11
- from . import _pydantic, _utils, messages
12
- from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
13
- from .exceptions import ModelRetry, UnexpectedModelBehavior
14
-
15
- # Usage `ToolEitherFunc[AgentDependencies, P]`
16
- ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]]
17
-
18
-
19
- @dataclass(init=False)
20
- class Tool(Generic[AgentDeps, ToolParams]):
21
- """A tool function for an agent."""
22
-
23
- name: str
24
- description: str
25
- function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False)
26
- is_async: bool
27
- single_arg_name: str | None
28
- positional_fields: list[str]
29
- var_positional_field: str | None
30
- validator: SchemaValidator = field(repr=False)
31
- json_schema: _utils.ObjectJsonSchema
32
- max_retries: int
33
- _current_retry: int = 0
34
- outer_typed_dict_key: str | None = None
35
-
36
- def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int):
37
- """Build a Tool dataclass from a function."""
38
- self.function = function
39
- # noinspection PyTypeChecker
40
- f = _pydantic.function_schema(function)
41
- raw_function = function.whichever()
42
- self.name = raw_function.__name__
43
- self.description = f['description']
44
- self.is_async = inspect.iscoroutinefunction(raw_function)
45
- self.single_arg_name = f['single_arg_name']
46
- self.positional_fields = f['positional_fields']
47
- self.var_positional_field = f['var_positional_field']
48
- self.validator = f['validator']
49
- self.json_schema = f['json_schema']
50
- self.max_retries = retries
51
-
52
- def reset(self) -> None:
53
- """Reset the current retry count."""
54
- self._current_retry = 0
55
-
56
- async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
57
- """Run the tool function asynchronously."""
58
- try:
59
- if isinstance(message.args, messages.ArgsJson):
60
- args_dict = self.validator.validate_json(message.args.args_json)
61
- else:
62
- args_dict = self.validator.validate_python(message.args.args_dict)
63
- except ValidationError as e:
64
- return self._on_error(e, message)
65
-
66
- args, kwargs = self._call_args(deps, args_dict, message)
67
- try:
68
- if self.is_async:
69
- function = cast(Callable[[Any], Awaitable[str]], self.function.whichever())
70
- response_content = await function(*args, **kwargs)
71
- else:
72
- function = cast(Callable[[Any], str], self.function.whichever())
73
- response_content = await _utils.run_in_executor(function, *args, **kwargs)
74
- except ModelRetry as e:
75
- return self._on_error(e, message)
76
-
77
- self._current_retry = 0
78
- return messages.ToolReturn(
79
- tool_name=message.tool_name,
80
- content=response_content,
81
- tool_id=message.tool_id,
82
- )
83
-
84
- def _call_args(
85
- self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
86
- ) -> tuple[list[Any], dict[str, Any]]:
87
- if self.single_arg_name:
88
- args_dict = {self.single_arg_name: args_dict}
89
-
90
- args = [RunContext(deps, self._current_retry, message.tool_name)] if self.function.is_left() else []
91
- for positional_field in self.positional_fields:
92
- args.append(args_dict.pop(positional_field))
93
- if self.var_positional_field:
94
- args.extend(args_dict.pop(self.var_positional_field))
95
-
96
- return args, args_dict
97
-
98
- def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
99
- self._current_retry += 1
100
- if self._current_retry > self.max_retries:
101
- # TODO custom error with details of the tool
102
- raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
103
- else:
104
- if isinstance(exc, ValidationError):
105
- content = exc.errors(include_url=False)
106
- else:
107
- content = exc.message
108
- return messages.RetryPrompt(
109
- tool_name=call_message.tool_name,
110
- content=content,
111
- tool_id=call_message.tool_id,
112
- )
@@ -1,83 +0,0 @@
1
- from __future__ import annotations as _annotations
2
-
3
- from collections.abc import Awaitable, Mapping, Sequence
4
- from dataclasses import dataclass
5
- from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
6
-
7
- from typing_extensions import Concatenate, ParamSpec, TypeAlias
8
-
9
- if TYPE_CHECKING:
10
- from .result import ResultData
11
- else:
12
- ResultData = Any
13
-
14
- __all__ = (
15
- 'AgentDeps',
16
- 'RunContext',
17
- 'ResultValidatorFunc',
18
- 'SystemPromptFunc',
19
- 'ToolReturnValue',
20
- 'ToolContextFunc',
21
- 'ToolPlainFunc',
22
- 'ToolParams',
23
- 'JsonData',
24
- )
25
-
26
- AgentDeps = TypeVar('AgentDeps')
27
- """Type variable for agent dependencies."""
28
-
29
-
30
- @dataclass
31
- class RunContext(Generic[AgentDeps]):
32
- """Information about the current call."""
33
-
34
- deps: AgentDeps
35
- """Dependencies for the agent."""
36
- retry: int
37
- """Number of retries so far."""
38
- tool_name: str | None
39
- """Name of the tool being called."""
40
-
41
-
42
- ToolParams = ParamSpec('ToolParams')
43
- """Retrieval function param spec."""
44
-
45
- SystemPromptFunc = Union[
46
- Callable[[RunContext[AgentDeps]], str],
47
- Callable[[RunContext[AgentDeps]], Awaitable[str]],
48
- Callable[[], str],
49
- Callable[[], Awaitable[str]],
50
- ]
51
- """A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
52
-
53
- Usage `SystemPromptFunc[AgentDeps]`.
54
- """
55
-
56
- ResultValidatorFunc = Union[
57
- Callable[[RunContext[AgentDeps], ResultData], ResultData],
58
- Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
59
- Callable[[ResultData], ResultData],
60
- Callable[[ResultData], Awaitable[ResultData]],
61
- ]
62
- """
63
- A function that always takes `ResultData` and returns `ResultData`,
64
- but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
65
-
66
- Usage `ResultValidator[AgentDeps, ResultData]`.
67
- """
68
-
69
- JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
70
- """Type representing any JSON data."""
71
-
72
- ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
73
- """Return value of a tool function."""
74
- ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
75
- """A tool function that takes `RunContext` as the first argument.
76
-
77
- Usage `ToolContextFunc[AgentDeps, ToolParams]`.
78
- """
79
- ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
80
- """A tool function that does not take `RunContext` as the first argument.
81
-
82
- Usage `ToolPlainFunc[ToolParams]`.
83
- """