pydantic-ai-slim 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/_pydantic.py CHANGED
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
8
8
  from inspect import Parameter, signature
9
9
  from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
10
10
 
11
- from pydantic import ConfigDict, TypeAdapter
11
+ from pydantic import ConfigDict
12
12
  from pydantic._internal import _decorators, _generate_schema, _typing_extra
13
13
  from pydantic._internal._config import ConfigWrapper
14
14
  from pydantic.fields import FieldInfo
@@ -17,13 +17,13 @@ from pydantic.plugin._schema_validator import create_schema_validator
17
17
  from pydantic_core import SchemaValidator, core_schema
18
18
 
19
19
  from ._griffe import doc_descriptions
20
- from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
20
+ from ._utils import check_object_json_schema, is_model_like
21
21
 
22
22
  if TYPE_CHECKING:
23
- pass
23
+ from .tools import ObjectJsonSchema
24
24
 
25
25
 
26
- __all__ = 'function_schema', 'LazyTypeAdapter'
26
+ __all__ = ('function_schema',)
27
27
 
28
28
 
29
29
  class FunctionSchema(TypedDict):
@@ -138,14 +138,14 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
138
138
  json_schema = GenerateJsonSchema().generate(schema)
139
139
 
140
140
  # workaround for https://github.com/pydantic/pydantic/issues/10785
141
- # if we build a custom TypeDict schema (matches when `single_arg_name` is None), we manually set
141
+ # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
142
142
  # `additionalProperties` in the JSON Schema
143
143
  if single_arg_name is None:
144
144
  json_schema['additionalProperties'] = bool(var_kwargs_schema)
145
-
146
- # instead of passing `description` through in core_schema, we just add it here
147
- if description:
148
- json_schema = {'description': description} | json_schema
145
+ elif not description:
146
+ # if the tool description is not set, and we have a single parameter, take the description from that
147
+ # and set it on the tool
148
+ description = json_schema.pop('description', None)
149
149
 
150
150
  return FunctionSchema(
151
151
  description=description,
@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
168
168
  """
169
169
  sig = signature(function)
170
170
  try:
171
- _, first_param = next(iter(sig.parameters.items()))
171
+ first_param_name = next(iter(sig.parameters.keys()))
172
172
  except StopIteration:
173
173
  return False
174
174
  else:
175
- return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
175
+ type_hints = _typing_extra.get_function_type_hints(function)
176
+ annotation = type_hints[first_param_name]
177
+ return annotation is not sig.empty and _is_call_ctx(annotation)
176
178
 
177
179
 
178
180
  def _build_schema(
@@ -212,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
212
214
  return annotation is RunContext or (
213
215
  _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
214
216
  )
215
-
216
-
217
- if TYPE_CHECKING:
218
- LazyTypeAdapter = TypeAdapter
219
- else:
220
-
221
- class LazyTypeAdapter:
222
- __slots__ = '_args', '_kwargs', '_type_adapter'
223
-
224
- def __init__(self, *args, **kwargs):
225
- self._args = args
226
- self._kwargs = kwargs
227
- self._type_adapter = None
228
-
229
- def __getattr__(self, item):
230
- if self._type_adapter is None:
231
- self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
232
- return getattr(self._type_adapter, item)
pydantic_ai/_result.py CHANGED
@@ -3,18 +3,17 @@ from __future__ import annotations as _annotations
3
3
  import inspect
4
4
  import sys
5
5
  import types
6
- from collections.abc import Awaitable
6
+ from collections.abc import Awaitable, Iterable
7
7
  from dataclasses import dataclass, field
8
8
  from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
9
9
 
10
10
  from pydantic import TypeAdapter, ValidationError
11
11
  from typing_extensions import Self, TypeAliasType, TypedDict
12
12
 
13
- from . import _utils, messages
13
+ from . import _utils, messages as _messages
14
14
  from .exceptions import ModelRetry
15
- from .messages import ModelStructuredResponse, ToolCall
16
15
  from .result import ResultData
17
- from .tools import AgentDeps, ResultValidatorFunc, RunContext
16
+ from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
18
17
 
19
18
 
20
19
  @dataclass
@@ -28,7 +27,12 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
28
27
  self._is_async = inspect.iscoroutinefunction(self.function)
29
28
 
30
29
  async def validate(
31
- self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | None
30
+ self,
31
+ result: ResultData,
32
+ deps: AgentDeps,
33
+ retry: int,
34
+ tool_call: _messages.ToolCallPart | None,
35
+ messages: list[_messages.ModelMessage],
32
36
  ) -> ResultData:
33
37
  """Validate a result by calling the function.
34
38
 
@@ -37,12 +41,13 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
37
41
  deps: The agent dependencies.
38
42
  retry: The current retry number.
39
43
  tool_call: The original tool call message, `None` if there was no tool call.
44
+ messages: The messages exchanged so far in the conversation.
40
45
 
41
46
  Returns:
42
47
  Result of either the validated result data (ok) or a retry message (Err).
43
48
  """
44
49
  if self._takes_ctx:
45
- args = RunContext(deps, retry, tool_call.tool_name if tool_call else None), result
50
+ args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
46
51
  else:
47
52
  args = (result,)
48
53
 
@@ -54,10 +59,10 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
54
59
  function = cast(Callable[[Any], ResultData], self.function)
55
60
  result_data = await _utils.run_in_executor(function, *args)
56
61
  except ModelRetry as r:
57
- m = messages.RetryPrompt(content=r.message)
62
+ m = _messages.RetryPromptPart(content=r.message)
58
63
  if tool_call is not None:
59
64
  m.tool_name = tool_call.tool_name
60
- m.tool_id = tool_call.tool_id
65
+ m.tool_call_id = tool_call.tool_call_id
61
66
  raise ToolRetryError(m) from r
62
67
  else:
63
68
  return result_data
@@ -66,7 +71,7 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
66
71
  class ToolRetryError(Exception):
67
72
  """Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
68
73
 
69
- def __init__(self, tool_retry: messages.RetryPrompt):
74
+ def __init__(self, tool_retry: _messages.RetryPromptPart):
70
75
  self.tool_retry = tool_retry
71
76
  super().__init__()
72
77
 
@@ -94,10 +99,7 @@ class ResultSchema(Generic[ResultData]):
94
99
  allow_text_result = False
95
100
 
96
101
  def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
97
- return cast(
98
- ResultTool[ResultData],
99
- ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
100
- )
102
+ return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
101
103
 
102
104
  tools: dict[str, ResultTool[ResultData]] = {}
103
105
  if args := get_union_args(response_type):
@@ -111,48 +113,61 @@ class ResultSchema(Generic[ResultData]):
111
113
 
112
114
  return cls(tools=tools, allow_text_result=allow_text_result)
113
115
 
114
- def find_tool(self, message: ModelStructuredResponse) -> tuple[ToolCall, ResultTool[ResultData]] | None:
116
+ def find_named_tool(
117
+ self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
118
+ ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
119
+ """Find a tool that matches one of the calls, with a specific name."""
120
+ for part in parts:
121
+ if isinstance(part, _messages.ToolCallPart):
122
+ if part.tool_name == tool_name:
123
+ return part, self.tools[tool_name]
124
+
125
+ def find_tool(
126
+ self,
127
+ parts: Iterable[_messages.ModelResponsePart],
128
+ ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
115
129
  """Find a tool that matches one of the calls."""
116
- for call in message.calls:
117
- if result := self.tools.get(call.tool_name):
118
- return call, result
130
+ for part in parts:
131
+ if isinstance(part, _messages.ToolCallPart):
132
+ if result := self.tools.get(part.tool_name):
133
+ return part, result
119
134
 
120
135
  def tool_names(self) -> list[str]:
121
136
  """Return the names of the tools."""
122
137
  return list(self.tools.keys())
123
138
 
139
+ def tool_defs(self) -> list[ToolDefinition]:
140
+ """Get tool definitions to register with the model."""
141
+ return [t.tool_def for t in self.tools.values()]
142
+
124
143
 
125
144
  DEFAULT_DESCRIPTION = 'The final response which ends this conversation'
126
145
 
127
146
 
128
- @dataclass
147
+ @dataclass(init=False)
129
148
  class ResultTool(Generic[ResultData]):
130
- name: str
131
- description: str
149
+ tool_def: ToolDefinition
132
150
  type_adapter: TypeAdapter[Any]
133
- json_schema: _utils.ObjectJsonSchema
134
- outer_typed_dict_key: str | None
135
151
 
136
- @classmethod
137
- def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
152
+ def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
138
153
  """Build a ResultTool dataclass from a response type."""
139
154
  assert response_type is not str, 'ResultTool does not support str as a response type'
140
155
 
141
156
  if _utils.is_model_like(response_type):
142
- type_adapter = TypeAdapter(response_type)
157
+ self.type_adapter = TypeAdapter(response_type)
143
158
  outer_typed_dict_key: str | None = None
144
159
  # noinspection PyArgumentList
145
- json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
160
+ parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
146
161
  else:
147
162
  response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
148
- type_adapter = TypeAdapter(response_data_typed_dict)
163
+ self.type_adapter = TypeAdapter(response_data_typed_dict)
149
164
  outer_typed_dict_key = 'response'
150
165
  # noinspection PyArgumentList
151
- json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
166
+ parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
152
167
  # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
153
- json_schema.pop('title')
168
+ parameters_json_schema.pop('title')
154
169
 
155
- if json_schema_description := json_schema.pop('description', None):
170
+ if json_schema_description := parameters_json_schema.pop('description', None):
156
171
  if description is None:
157
172
  tool_description = json_schema_description
158
173
  else:
@@ -162,16 +177,15 @@ class ResultTool(Generic[ResultData]):
162
177
  if multiple:
163
178
  tool_description = f'{union_arg_name(response_type)}: {tool_description}'
164
179
 
165
- return cls(
180
+ self.tool_def = ToolDefinition(
166
181
  name=name,
167
182
  description=tool_description,
168
- type_adapter=type_adapter,
169
- json_schema=json_schema,
183
+ parameters_json_schema=parameters_json_schema,
170
184
  outer_typed_dict_key=outer_typed_dict_key,
171
185
  )
172
186
 
173
187
  def validate(
174
- self, tool_call: messages.ToolCall, allow_partial: bool = False, wrap_validation_errors: bool = True
188
+ self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
175
189
  ) -> ResultData:
176
190
  """Validate a result message.
177
191
 
@@ -185,7 +199,7 @@ class ResultTool(Generic[ResultData]):
185
199
  """
186
200
  try:
187
201
  pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
188
- if isinstance(tool_call.args, messages.ArgsJson):
202
+ if isinstance(tool_call.args, _messages.ArgsJson):
189
203
  result = self.type_adapter.validate_json(
190
204
  tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
191
205
  )
@@ -195,16 +209,16 @@ class ResultTool(Generic[ResultData]):
195
209
  )
196
210
  except ValidationError as e:
197
211
  if wrap_validation_errors:
198
- m = messages.RetryPrompt(
212
+ m = _messages.RetryPromptPart(
199
213
  tool_name=tool_call.tool_name,
200
214
  content=e.errors(include_url=False),
201
- tool_id=tool_call.tool_id,
215
+ tool_call_id=tool_call.tool_call_id,
202
216
  )
203
217
  raise ToolRetryError(m) from e
204
218
  else:
205
219
  raise
206
220
  else:
207
- if k := self.outer_typed_dict_key:
221
+ if k := self.tool_def.outer_typed_dict_key:
208
222
  result = result[k]
209
223
  return result
210
224
 
@@ -21,7 +21,7 @@ class SystemPromptRunner(Generic[AgentDeps]):
21
21
 
22
22
  async def run(self, deps: AgentDeps) -> str:
23
23
  if self._takes_ctx:
24
- args = (RunContext(deps, 0, None),)
24
+ args = (RunContext(deps, 0, [], None),)
25
25
  else:
26
26
  args = ()
27
27
 
pydantic_ai/_utils.py CHANGED
@@ -8,12 +8,16 @@ from dataclasses import dataclass, is_dataclass
8
8
  from datetime import datetime, timezone
9
9
  from functools import partial
10
10
  from types import GenericAlias
11
- from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
11
+ from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload
12
12
 
13
13
  from pydantic import BaseModel
14
14
  from pydantic.json_schema import JsonSchemaValue
15
15
  from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
16
16
 
17
+ if TYPE_CHECKING:
18
+ from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
19
+ from .tools import ObjectJsonSchema
20
+
17
21
  _P = ParamSpec('_P')
18
22
  _R = TypeVar('_R')
19
23
 
@@ -39,10 +43,6 @@ def is_model_like(type_: Any) -> bool:
39
43
  )
40
44
 
41
45
 
42
- # With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
43
- ObjectJsonSchema: TypeAlias = dict[str, Any]
44
-
45
-
46
46
  def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
47
47
  from .exceptions import UserError
48
48
 
@@ -88,7 +88,7 @@ class Either(Generic[Left, Right]):
88
88
 
89
89
  Usage:
90
90
 
91
- ```py
91
+ ```python
92
92
  if left_thing := either.left:
93
93
  use_left(left_thing.value)
94
94
  else:
@@ -127,6 +127,12 @@ class Either(Generic[Left, Right]):
127
127
  def whichever(self) -> Left | Right:
128
128
  return self._left.value if self._left is not None else self.right
129
129
 
130
+ def __repr__(self):
131
+ if left := self._left:
132
+ return f'Either(left={left.value!r})'
133
+ else:
134
+ return f'Either(right={self.right!r})'
135
+
130
136
 
131
137
  @asynccontextmanager
132
138
  async def group_by_temporal(
@@ -141,7 +147,7 @@ async def group_by_temporal(
141
147
 
142
148
  Usage:
143
149
 
144
- ```py
150
+ ```python
145
151
  async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
146
152
  async for groups in groups_iter:
147
153
  print(groups)
@@ -218,7 +224,7 @@ async def group_by_temporal(
218
224
 
219
225
  try:
220
226
  yield async_iter_groups()
221
- finally:
227
+ finally: # pragma: no cover
222
228
  # after iteration, if a task still exists, cancel it; this will only happen if an error occurred
223
229
  if task:
224
230
  task.cancel('Cancelling due to error in iterator')
@@ -249,3 +255,9 @@ def sync_anext(iterator: Iterator[T]) -> T:
249
255
 
250
256
  def now_utc() -> datetime:
251
257
  return datetime.now(tz=timezone.utc)
258
+
259
+
260
+ def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
261
+ """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
262
+ assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
263
+ return t.tool_call_id