pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/result.py CHANGED
@@ -1,19 +1,20 @@
  from __future__ import annotations as _annotations

  from abc import ABC, abstractmethod
- from collections.abc import AsyncIterator, Callable
+ from collections.abc import AsyncIterator, Awaitable, Callable
  from dataclasses import dataclass, field
  from datetime import datetime
  from typing import Generic, TypeVar, cast

  import logfire_api

- from . import _result, _utils, exceptions, messages, models
- from .tools import AgentDeps
+ from . import _result, _utils, exceptions, messages as _messages, models
+ from .settings import UsageLimits
+ from .tools import AgentDeps, RunContext

  __all__ = (
      'ResultData',
-     'Cost',
+     'Usage',
      'RunResult',
      'StreamedRunResult',
  )
@@ -26,30 +27,32 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


  @dataclass
- class Cost:
-     """Cost of a request or run.
+ class Usage:
+     """LLM usage associated to a request or run.

-     Responsibility for calculating costs is on the model used, PydanticAI simply sums the cost of requests.
+     Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

-     You'll need to look up the documentation of the model you're using to convent "token count" costs to monetary costs.
+     You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
      """

+     requests: int = 0
+     """Number of requests made."""
      request_tokens: int | None = None
-     """Tokens used in processing the request."""
+     """Tokens used in processing requests."""
      response_tokens: int | None = None
-     """Tokens used in generating the response."""
+     """Tokens used in generating responses."""
      total_tokens: int | None = None
      """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
      details: dict[str, int] | None = None
      """Any extra details returned by the model."""

-     def __add__(self, other: Cost) -> Cost:
-         """Add two costs together.
+     def __add__(self, other: Usage) -> Usage:
+         """Add two Usages together.

-         This is provided so it's trivial to sum costs from multiple requests and runs.
+         This is provided so it's trivial to sum usage information from multiple requests and runs.
          """
          counts: dict[str, int] = {}
-         for f in 'request_tokens', 'response_tokens', 'total_tokens':
+         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
              self_value = getattr(self, f)
              other_value = getattr(other, f)
              if self_value is not None or other_value is not None:
@@ -61,7 +64,7 @@ class Cost:
              for key, value in other.details.items():
                  details[key] = details.get(key, 0) + value

-         return Cost(**counts, details=details or None)
+         return Usage(**counts, details=details or None)


  @dataclass
@@ -71,19 +74,19 @@ class _BaseRunResult(ABC, Generic[ResultData]):
      You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
      """

-     _all_messages: list[messages.Message]
+     _all_messages: list[_messages.ModelMessage]
      _new_message_index: int

-     def all_messages(self) -> list[messages.Message]:
-         """Return the history of messages."""
+     def all_messages(self) -> list[_messages.ModelMessage]:
+         """Return the history of _messages."""
          # this is a method to be consistent with the other methods
          return self._all_messages

      def all_messages_json(self) -> bytes:
-         """Return all messages from [`all_messages`][..all_messages] as JSON bytes."""
-         return messages.MessagesTypeAdapter.dump_json(self.all_messages())
+         """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
+         return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())

-     def new_messages(self) -> list[messages.Message]:
+     def new_messages(self) -> list[_messages.ModelMessage]:
          """Return new messages associated with this run.

          System prompts and any messages from older runs are excluded.
@@ -91,11 +94,11 @@ class _BaseRunResult(ABC, Generic[ResultData]):
          return self.all_messages()[self._new_message_index :]

      def new_messages_json(self) -> bytes:
-         """Return new messages from [`new_messages`][..new_messages] as JSON bytes."""
-         return messages.MessagesTypeAdapter.dump_json(self.new_messages())
+         """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
+         return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())

      @abstractmethod
-     def cost(self) -> Cost:
+     def usage(self) -> Usage:
          raise NotImplementedError()


@@ -105,24 +108,26 @@ class RunResult(_BaseRunResult[ResultData]):

      data: ResultData
      """Data from the final response in the run."""
-     _cost: Cost
+     _usage: Usage

-     def cost(self) -> Cost:
-         """Return the cost of the whole run."""
-         return self._cost
+     def usage(self) -> Usage:
+         """Return the usage of the whole run."""
+         return self._usage


  @dataclass
  class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
      """Result of a streamed run that returns structured data via a tool call."""

-     cost_so_far: Cost
-     """Cost of the run up until the last request."""
+     usage_so_far: Usage
+     """Usage of the run up until the last request."""
+     _usage_limits: UsageLimits | None
      _stream_response: models.EitherStreamedResponse
      _result_schema: _result.ResultSchema[ResultData] | None
-     _deps: AgentDeps
+     _run_ctx: RunContext[AgentDeps]
      _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
-     _on_complete: Callable[[list[messages.Message]], None]
+     _result_tool_name: str | None
+     _on_complete: Callable[[], Awaitable[None]]
      is_complete: bool = field(default=False, init=False)
      """Whether the stream has all been received.

@@ -172,11 +177,15 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
          Debouncing is particularly important for long structured responses to reduce the overhead of
          performing validation as each token is received.
          """
+         usage_checking_stream = _get_usage_checking_stream_response(
+             self._stream_response, self._usage_limits, self.usage
+         )
+
          with _logfire.span('response stream text') as lf_span:
              if isinstance(self._stream_response, models.StreamStructuredResponse):
                  raise exceptions.UserError('stream_text() can only be used with text responses')
              if delta:
-                 async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                 async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                      async for _ in group_iter:
                          yield ''.join(self._stream_response.get())
                      final_delta = ''.join(self._stream_response.get(final=True))
@@ -187,7 +196,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
                  # yielding at each step
                  chunks: list[str] = []
                  combined = ''
-                 async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                 async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                      async for _ in group_iter:
                          new = False
                          for chunk in self._stream_response.get():
@@ -205,11 +214,11 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
                              combined = await self._validate_text_result(''.join(chunks))
                              yield combined
              lf_span.set_attribute('combined_text', combined)
-             self._marked_completed(text=combined)
+             await self._marked_completed(_messages.ModelResponse.from_text(combined))

      async def stream_structured(
          self, *, debounce_by: float | None = 0.1
-     ) -> AsyncIterator[tuple[messages.ModelStructuredResponse, bool]]:
+     ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
          """Stream the response as an async iterable of Structured LLM Messages.

          !!! note
@@ -224,61 +233,75 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
          Returns:
              An async iterable of the structured response message and whether that is the last message.
          """
+         usage_checking_stream = _get_usage_checking_stream_response(
+             self._stream_response, self._usage_limits, self.usage
+         )
+
          with _logfire.span('response stream structured') as lf_span:
              if isinstance(self._stream_response, models.StreamTextResponse):
                  raise exceptions.UserError('stream_structured() can only be used with structured responses')
              else:
                  # we should already have a message at this point, yield that first if it has any content
                  msg = self._stream_response.get()
-                 if any(call.has_content() for call in msg.calls):
-                     yield msg, False
-                 async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+                 for item in msg.parts:
+                     if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                         yield msg, False
+                         break
+                 async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
                      async for _ in group_iter:
                          msg = self._stream_response.get()
-                         if any(call.has_content() for call in msg.calls):
-                             yield msg, False
+                         for item in msg.parts:
+                             if isinstance(item, _messages.ToolCallPart) and item.has_content():
+                                 yield msg, False
+                                 break
                  msg = self._stream_response.get(final=True)
                  yield msg, True
                  lf_span.set_attribute('structured_response', msg)
-                 self._marked_completed(structured_message=msg)
+                 await self._marked_completed(msg)

      async def get_data(self) -> ResultData:
          """Stream the whole response, validate and return it."""
-         async for _ in self._stream_response:
+         usage_checking_stream = _get_usage_checking_stream_response(
+             self._stream_response, self._usage_limits, self.usage
+         )
+
+         async for _ in usage_checking_stream:
              pass
+
          if isinstance(self._stream_response, models.StreamTextResponse):
              text = ''.join(self._stream_response.get(final=True))
              text = await self._validate_text_result(text)
-             self._marked_completed(text=text)
+             await self._marked_completed(_messages.ModelResponse.from_text(text))
              return cast(ResultData, text)
          else:
-             structured_message = self._stream_response.get(final=True)
-             self._marked_completed(structured_message=structured_message)
-             return await self.validate_structured_result(structured_message)
+             message = self._stream_response.get(final=True)
+             await self._marked_completed(message)
+             return await self.validate_structured_result(message)

      @property
      def is_structured(self) -> bool:
          """Return whether the stream response contains structured data (as opposed to text)."""
          return isinstance(self._stream_response, models.StreamStructuredResponse)

-     def cost(self) -> Cost:
-         """Return the cost of the whole run.
+     def usage(self) -> Usage:
+         """Return the usage of the whole run.

          !!! note
-             This won't return the full cost until the stream is finished.
+             This won't return the full usage until the stream is finished.
          """
-         return self.cost_so_far + self._stream_response.cost()
+         return self.usage_so_far + self._stream_response.usage()

      def timestamp(self) -> datetime:
          """Get the timestamp of the response."""
          return self._stream_response.timestamp()

      async def validate_structured_result(
-         self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False
+         self, message: _messages.ModelResponse, *, allow_partial: bool = False
      ) -> ResultData:
          """Validate a structured result message."""
          assert self._result_schema is not None, 'Expected _result_schema to not be None'
-         match = self._result_schema.find_tool(message)
+         assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
+         match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
          if match is None:
              raise exceptions.UnexpectedModelBehavior(
                  f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
@@ -288,29 +311,34 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
          result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

          for validator in self._result_validators:
-             result_data = await validator.validate(result_data, self._deps, 0, call)
+             result_data = await validator.validate(result_data, call, self._run_ctx)
          return result_data

      async def _validate_text_result(self, text: str) -> str:
          for validator in self._result_validators:
              text = await validator.validate(  # pyright: ignore[reportAssignmentType]
                  text,  # pyright: ignore[reportArgumentType]
-                 self._deps,
-                 0,
                  None,
+                 self._run_ctx,
              )
          return text

-     def _marked_completed(
-         self, *, text: str | None = None, structured_message: messages.ModelStructuredResponse | None = None
-     ) -> None:
+     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
          self.is_complete = True
-         if text is not None:
-             assert structured_message is None, 'Either text or structured_message should provided, not both'
-             self._all_messages.append(
-                 messages.ModelTextResponse(content=text, timestamp=self._stream_response.timestamp())
-             )
-         else:
-             assert structured_message is not None, 'Either text or structured_message should provided, not both'
-             self._all_messages.append(structured_message)
-         self._on_complete(self._all_messages)
+         self._all_messages.append(message)
+         await self._on_complete()
+
+
+ def _get_usage_checking_stream_response(
+     stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
+ ) -> AsyncIterator[ResultData]:
+     if limits is not None and limits.has_token_limits():
+
+         async def _usage_checking_iterator():
+             async for item in stream_response:
+                 limits.check_tokens(get_usage())
+                 yield item
+
+         return _usage_checking_iterator()
+     else:
+         return stream_response
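
The headline change in result.py is the rename of `Cost` to `Usage`, which now also counts requests, plus the `usage()` methods that replace `cost()` on `RunResult` and `StreamedRunResult`. A minimal sketch of the new class, assuming pydantic-ai-slim 0.0.14 is installed (the token values are illustrative):

```python
from pydantic_ai.result import Usage

# Usage sums like Cost did, but now also tracks a request count.
# __add__ adds the numeric fields and merges the `details` dicts,
# so per-request usage can be accumulated over a whole run.
first = Usage(requests=1, request_tokens=50, response_tokens=20, total_tokens=70)
second = Usage(requests=1, request_tokens=30, response_tokens=10, total_tokens=40)

combined = first + second
print(combined.requests)      # 2
print(combined.total_tokens)  # 110
```

Code that previously called `RunResult.cost()` or `StreamedRunResult.cost()` needs to move to the equivalent `usage()` methods.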
pydantic_ai/settings.py ADDED
@@ -0,0 +1,137 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING
+
+ from httpx import Timeout
+ from typing_extensions import TypedDict
+
+ from .exceptions import UsageLimitExceeded
+
+ if TYPE_CHECKING:
+     from .result import Usage
+
+
+ class ModelSettings(TypedDict, total=False):
+     """Settings to configure an LLM.
+
+     Here we include only settings which apply to multiple models / model providers.
+     """
+
+     max_tokens: int
+     """The maximum number of tokens to generate before stopping.
+
+     Supported by:
+     * Gemini
+     * Anthropic
+     * OpenAI
+     * Groq
+     """
+
+     temperature: float
+     """Amount of randomness injected into the response.
+
+     Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
+     maximum `temperature` for creative and generative tasks.
+
+     Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
+
+     Supported by:
+     * Gemini
+     * Anthropic
+     * OpenAI
+     * Groq
+     """
+
+     top_p: float
+     """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
+
+     So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+
+     You should either alter `temperature` or `top_p`, but not both.
+
+     Supported by:
+     * Gemini
+     * Anthropic
+     * OpenAI
+     * Groq
+     """
+
+     timeout: float | Timeout
+     """Override the client-level default timeout for a request, in seconds.
+
+     Supported by:
+     * Gemini
+     * Anthropic
+     * OpenAI
+     * Groq
+     """
+
+
+ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
+     """Merge two sets of model settings, preferring the overrides.
+
+     A common use case is: merge_model_settings(<agent settings>, <run settings>)
+     """
+     # Note: we may want merge recursively if/when we add non-primitive values
+     if base and overrides:
+         return base | overrides
+     else:
+         return base or overrides
+
+
+ @dataclass
+ class UsageLimits:
+     """Limits on model usage.
+
+     The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model.
+     Token counts are provided in responses from the model, and the token limits are checked after each response.
+
+     Each of the limits can be set to `None` to disable that limit.
+     """
+
+     request_limit: int | None = 50
+     """The maximum number of requests allowed to the model."""
+     request_tokens_limit: int | None = None
+     """The maximum number of tokens allowed in requests to the model."""
+     response_tokens_limit: int | None = None
+     """The maximum number of tokens allowed in responses from the model."""
+     total_tokens_limit: int | None = None
+     """The maximum number of tokens allowed in requests and responses combined."""
+
+     def has_token_limits(self) -> bool:
+         """Returns `True` if this instance places any limits on token counts.
+
+         If this returns `False`, the `check_tokens` method will never raise an error.
+
+         This is useful because if we have token limits, we need to check them after receiving each streamed message.
+         If there are no limits, we can skip that processing in the streaming response iterator.
+         """
+         return any(
+             limit is not None
+             for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit)
+         )
+
+     def check_before_request(self, usage: Usage) -> None:
+         """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit."""
+         request_limit = self.request_limit
+         if request_limit is not None and usage.requests >= request_limit:
+             raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}')
+
+     def check_tokens(self, usage: Usage) -> None:
+         """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits."""
+         request_tokens = usage.request_tokens or 0
+         if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit:
+             raise UsageLimitExceeded(
+                 f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})'
+             )
+
+         response_tokens = usage.response_tokens or 0
+         if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit:
+             raise UsageLimitExceeded(
+                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
+             )
+
+         total_tokens = request_tokens + response_tokens
+         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
+             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
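
This new module (imported as `pydantic_ai.settings` by result.py above) introduces run-wide `ModelSettings` and `UsageLimits`. A small sketch of how these helpers behave, with illustrative values:

```python
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits, merge_model_settings

# Run-level settings override agent-level settings key by key.
merged = merge_model_settings({'temperature': 0.0, 'max_tokens': 500}, {'temperature': 0.7})
print(merged)  # {'temperature': 0.7, 'max_tokens': 500}

# Token limits are checked after each response (and, per the iterator wrapper
# in result.py, after each streamed chunk); the request limit is checked
# before each request.
limits = UsageLimits(response_tokens_limit=100)
print(limits.has_token_limits())  # True

try:
    limits.check_tokens(Usage(requests=1, request_tokens=40, response_tokens=150, total_tokens=190))
except UsageLimitExceeded as exc:
    print(exc)  # Exceeded the response_tokens_limit of 100 (response_tokens=150)
```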
pydantic_ai/tools.py CHANGED
@@ -1,23 +1,18 @@
  from __future__ import annotations as _annotations

+ import dataclasses
  import inspect
  from collections.abc import Awaitable
  from dataclasses import dataclass, field
- from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
+ from typing import Any, Callable, Generic, TypeVar, Union, cast

  from pydantic import ValidationError
  from pydantic_core import SchemaValidator
  from typing_extensions import Concatenate, ParamSpec, TypeAlias

- from . import _pydantic, _utils, messages
+ from . import _pydantic, _utils, messages as _messages, models
  from .exceptions import ModelRetry, UnexpectedModelBehavior

- if TYPE_CHECKING:
-     from .result import ResultData
- else:
-     ResultData = Any
-
-
  __all__ = (
      'AgentDeps',
      'RunContext',
@@ -37,7 +32,7 @@ AgentDeps = TypeVar('AgentDeps')
  """Type variable for agent dependencies."""


- @dataclass
+ @dataclasses.dataclass
  class RunContext(Generic[AgentDeps]):
      """Information about the current call."""

@@ -45,8 +40,23 @@ class RunContext(Generic[AgentDeps]):
      """Dependencies for the agent."""
      retry: int
      """Number of retries so far."""
-     tool_name: str | None = None
+     messages: list[_messages.ModelMessage]
+     """Messages exchanged in the conversation so far."""
+     tool_name: str | None
      """Name of the tool being called."""
+     model: models.Model
+     """The model used in this run."""
+
+     def replace_with(
+         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
+     ) -> RunContext[AgentDeps]:
+         # Create a new `RunContext` a new `retry` value and `tool_name`.
+         kwargs = {}
+         if retry is not None:
+             kwargs['retry'] = retry
+         if tool_name is not _utils.UNSET:
+             kwargs['tool_name'] = tool_name
+         return dataclasses.replace(self, **kwargs)


  ToolParams = ParamSpec('ToolParams')
@@ -63,6 +73,8 @@ SystemPromptFunc = Union[
  Usage `SystemPromptFunc[AgentDeps]`.
  """

+ ResultData = TypeVar('ResultData')
+
  ResultValidatorFunc = Union[
      Callable[[RunContext[AgentDeps], ResultData], ResultData],
      Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
@@ -87,7 +99,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
  Usage `ToolPlainFunc[ToolParams]`.
  """
  ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
- """Either kind of tool function.
+ """Either part_kind of tool function.

  This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
  [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -97,11 +109,11 @@ Usage `ToolFuncEither[AgentDeps, ToolParams]`.
  ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDeps], ToolDefinition], Awaitable[ToolDefinition | None]]'
  """Definition of a function that can prepare a tool definition at call time.

- See [tool docs](../agents.md#tool-prepare) for more information.
+ See [tool docs](../tools.md#tool-prepare) for more information.

  Example — here `only_if_42` is valid as a `ToolPrepareFunc`:

- ```py
+ ```python {lint="not-imports"}
  from typing import Union

  from pydantic_ai import RunContext, Tool
@@ -157,7 +169,7 @@ class Tool(Generic[AgentDeps]):

      Example usage:

-     ```py
+     ```python {lint="not-imports"}
      from pydantic_ai import Agent, RunContext, Tool

      async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
@@ -168,7 +180,7 @@ class Tool(Generic[AgentDeps]):

      or with a custom prepare method:

-     ```py
+     ```python {lint="not-imports"}
      from typing import Union

      from pydantic_ai import Agent, RunContext, Tool
@@ -235,17 +247,19 @@ class Tool(Generic[AgentDeps]):
          else:
              return tool_def

-     async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
+     async def run(
+         self, message: _messages.ToolCallPart, run_context: RunContext[AgentDeps]
+     ) -> _messages.ModelRequestPart:
          """Run the tool function asynchronously."""
          try:
-             if isinstance(message.args, messages.ArgsJson):
+             if isinstance(message.args, _messages.ArgsJson):
                  args_dict = self._validator.validate_json(message.args.args_json)
              else:
                  args_dict = self._validator.validate_python(message.args.args_dict)
          except ValidationError as e:
              return self._on_error(e, message)

-         args, kwargs = self._call_args(deps, args_dict, message)
+         args, kwargs = self._call_args(args_dict, message, run_context)
          try:
              if self._is_async:
                  function = cast(Callable[[Any], Awaitable[str]], self.function)
@@ -257,19 +271,23 @@ class Tool(Generic[AgentDeps]):
              return self._on_error(e, message)

          self.current_retry = 0
-         return messages.ToolReturn(
+         return _messages.ToolReturnPart(
              tool_name=message.tool_name,
              content=response_content,
-             tool_id=message.tool_id,
+             tool_call_id=message.tool_call_id,
          )

      def _call_args(
-         self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
+         self,
+         args_dict: dict[str, Any],
+         message: _messages.ToolCallPart,
+         run_context: RunContext[AgentDeps],
      ) -> tuple[list[Any], dict[str, Any]]:
          if self._single_arg_name:
              args_dict = {self._single_arg_name: args_dict}

-         args = [RunContext(deps, self.current_retry, message.tool_name)] if self.takes_ctx else []
+         ctx = dataclasses.replace(run_context, retry=self.current_retry, tool_name=message.tool_name)
+         args = [ctx] if self.takes_ctx else []
          for positional_field in self._positional_fields:
              args.append(args_dict.pop(positional_field))
          if self._var_positional_field:
@@ -277,7 +295,9 @@ class Tool(Generic[AgentDeps]):

          return args, args_dict

-     def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
+     def _on_error(
+         self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
+     ) -> _messages.RetryPromptPart:
          self.current_retry += 1
          if self.max_retries is None or self.current_retry > self.max_retries:
              raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
@@ -286,10 +306,10 @@ class Tool(Generic[AgentDeps]):
              content = exc.errors(include_url=False)
          else:
              content = exc.message
-         return messages.RetryPrompt(
+         return _messages.RetryPromptPart(
              tool_name=call_message.tool_name,
              content=content,
-             tool_id=call_message.tool_id,
+             tool_call_id=call_message.tool_call_id,
          )


@@ -298,7 +318,7 @@ ObjectJsonSchema: TypeAlias = dict[str, Any]

  This type is used to define tools parameters (aka arguments) in [ToolDefinition][pydantic_ai.tools.ToolDefinition].

- With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
+ With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any`
  """

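
In tools.py, `RunContext` now carries the conversation messages and the model alongside `deps` and `retry`, and tool execution works with the new message-part types (`ToolCallPart`, `ToolReturnPart`, `RetryPromptPart`). A hedged sketch of a tool reading the extended context; it assumes the `'test'` model shorthand and the `Agent` API shown in the docstring examples above behave as in the PydanticAI docs:

```python
from pydantic_ai import Agent, RunContext

# 'test' is assumed to resolve to the built-in test model, as in the package docs.
agent = Agent('test', deps_type=int)


@agent.tool
async def greet(ctx: RunContext[int], name: str) -> str:
    # ctx.messages and ctx.model are new RunContext fields in this release;
    # ctx.deps and ctx.retry were already available in 0.0.12.
    return f'hello {name} (deps={ctx.deps}, messages so far: {len(ctx.messages)})'


result = agent.run_sync('please greet someone', deps=42)
print(result.usage())  # per-run Usage, replacing the old result.cost()
```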