pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. (The original page linked to further details; that link is not preserved in this copy.)

pydantic_ai/result.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from collections.abc import AsyncIterator, Awaitable, Callable
4
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
5
5
  from copy import deepcopy
6
6
  from dataclasses import dataclass, field
7
7
  from datetime import datetime
@@ -11,35 +11,49 @@ import logfire_api
11
11
  from typing_extensions import TypeVar
12
12
 
13
13
  from . import _result, _utils, exceptions, messages as _messages, models
14
- from .tools import AgentDeps, RunContext
14
+ from .tools import AgentDepsT, RunContext
15
15
  from .usage import Usage, UsageLimits
16
16
 
17
- __all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
17
+ __all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'
18
18
 
19
19
 
20
- ResultData = TypeVar('ResultData', default=str)
21
- """Type variable for the result data of a run."""
20
+ T = TypeVar('T')
21
+ """An invariant TypeVar."""
22
+ ResultDataT_inv = TypeVar('ResultDataT_inv', default=str)
23
+ """
24
+ An invariant type variable for the result data of a model.
25
+
26
+ We need to use an invariant typevar for `ResultValidator` and `ResultValidatorFunc` because the result data type is used
27
+ in both the input and output of a `ResultValidatorFunc`. This can theoretically lead to some issues assuming that types
28
+ possessing ResultValidator's are covariant in the result data type, but in practice this is rarely an issue, and
29
+ changing it would have negative consequences for the ergonomics of the library.
30
+
31
+ At some point, it may make sense to change the input to ResultValidatorFunc to be `Any` or `object` as doing that would
32
+ resolve these potential variance issues.
33
+ """
34
+ ResultDataT = TypeVar('ResultDataT', default=str, covariant=True)
35
+ """Covariant type variable for the result data type of a run."""
22
36
 
23
37
  ResultValidatorFunc = Union[
24
- Callable[[RunContext[AgentDeps], ResultData], ResultData],
25
- Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
26
- Callable[[ResultData], ResultData],
27
- Callable[[ResultData], Awaitable[ResultData]],
38
+ Callable[[RunContext[AgentDepsT], ResultDataT_inv], ResultDataT_inv],
39
+ Callable[[RunContext[AgentDepsT], ResultDataT_inv], Awaitable[ResultDataT_inv]],
40
+ Callable[[ResultDataT_inv], ResultDataT_inv],
41
+ Callable[[ResultDataT_inv], Awaitable[ResultDataT_inv]],
28
42
  ]
29
43
  """
30
- A function that always takes `ResultData` and returns `ResultData` and:
44
+ A function that always takes and returns the same type of data (which is the result type of an agent run), and:
31
45
 
32
46
  * may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
33
47
  * may or may not be async
34
48
 
35
- Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
49
+ Usage `ResultValidatorFunc[AgentDeps, T]`.
36
50
  """
37
51
 
38
52
  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
39
53
 
40
54
 
41
55
  @dataclass
42
- class _BaseRunResult(ABC, Generic[ResultData]):
56
+ class _BaseRunResult(ABC, Generic[ResultDataT]):
43
57
  """Base type for results.
44
58
 
45
59
  You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
@@ -119,10 +133,10 @@ class _BaseRunResult(ABC, Generic[ResultData]):
119
133
 
120
134
 
121
135
  @dataclass
122
- class RunResult(_BaseRunResult[ResultData]):
136
+ class RunResult(_BaseRunResult[ResultDataT]):
123
137
  """Result of a non-streamed run."""
124
138
 
125
- data: ResultData
139
+ data: ResultDataT
126
140
  """Data from the final response in the run."""
127
141
  _result_tool_name: str | None
128
142
  _usage: Usage
@@ -165,14 +179,14 @@ class RunResult(_BaseRunResult[ResultData]):
165
179
 
166
180
 
167
181
  @dataclass
168
- class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
182
+ class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]):
169
183
  """Result of a streamed run that returns structured data via a tool call."""
170
184
 
171
185
  _usage_limits: UsageLimits | None
172
- _stream_response: models.EitherStreamedResponse
173
- _result_schema: _result.ResultSchema[ResultData] | None
174
- _run_ctx: RunContext[AgentDeps]
175
- _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
186
+ _stream_response: models.StreamedResponse
187
+ _result_schema: _result.ResultSchema[ResultDataT] | None
188
+ _run_ctx: RunContext[AgentDepsT]
189
+ _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
176
190
  _result_tool_name: str | None
177
191
  _on_complete: Callable[[], Awaitable[None]]
178
192
  is_complete: bool = field(default=False, init=False)
@@ -185,7 +199,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
185
199
  [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes.
186
200
  """
187
201
 
188
- async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultData]:
202
+ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
189
203
  """Stream the response as an async iterable.
190
204
 
191
205
  The pydantic validator for structured data will be called in
@@ -200,20 +214,13 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
200
214
  Returns:
201
215
  An async iterable of the response data.
202
216
  """
203
- if isinstance(self._stream_response, models.StreamTextResponse):
204
- async for text in self.stream_text(debounce_by=debounce_by):
205
- yield cast(ResultData, text)
206
- else:
207
- async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
208
- yield await self.validate_structured_result(structured_message, allow_partial=not is_last)
217
+ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
218
+ result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
219
+ yield result
209
220
 
210
221
  async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
211
222
  """Stream the text result as an async iterable.
212
223
 
213
- !!! note
214
- This method will fail if the response is structured,
215
- e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `True`.
216
-
217
224
  !!! note
218
225
  Result validators will NOT be called on the text result if `delta=True`.
219
226
 
@@ -224,54 +231,70 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
224
231
  Debouncing is particularly important for long structured responses to reduce the overhead of
225
232
  performing validation as each token is received.
226
233
  """
234
+ if self._result_schema and not self._result_schema.allow_text_result:
235
+ raise exceptions.UserError('stream_text() can only be used with text responses')
236
+
227
237
  usage_checking_stream = _get_usage_checking_stream_response(
228
238
  self._stream_response, self._usage_limits, self.usage
229
239
  )
230
240
 
241
+ # Define a "merged" version of the iterator that will yield items that have already been retrieved
242
+ # and items that we receive while streaming. We define a dedicated async iterator for this so we can
243
+ # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
244
+ async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
245
+ # if the response currently has any parts with content, yield those before streaming
246
+ msg = self._stream_response.get()
247
+ for i, part in enumerate(msg.parts):
248
+ if isinstance(part, _messages.TextPart) and part.content:
249
+ yield part.content, i
250
+
251
+ async for event in usage_checking_stream:
252
+ if (
253
+ isinstance(event, _messages.PartStartEvent)
254
+ and isinstance(event.part, _messages.TextPart)
255
+ and event.part.content
256
+ ):
257
+ yield event.part.content, event.index
258
+ elif (
259
+ isinstance(event, _messages.PartDeltaEvent)
260
+ and isinstance(event.delta, _messages.TextPartDelta)
261
+ and event.delta.content_delta
262
+ ):
263
+ yield event.delta.content_delta, event.index
264
+
265
+ async def _stream_text_deltas() -> AsyncIterator[str]:
266
+ async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
267
+ async for items in group_iter:
268
+ yield ''.join([content for content, _ in items])
269
+
231
270
  with _logfire.span('response stream text') as lf_span:
232
- if isinstance(self._stream_response, models.StreamStructuredResponse):
233
- raise exceptions.UserError('stream_text() can only be used with text responses')
234
271
  if delta:
235
- async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
236
- async for _ in group_iter:
237
- yield ''.join(self._stream_response.get())
238
- final_delta = ''.join(self._stream_response.get(final=True))
239
- if final_delta:
240
- yield final_delta
272
+ async for text in _stream_text_deltas():
273
+ yield text
241
274
  else:
242
275
  # a quick benchmark shows it's faster to build up a string with concat when we're
243
276
  # yielding at each step
244
- chunks: list[str] = []
245
- combined = ''
246
- async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
247
- async for _ in group_iter:
248
- new = False
249
- for chunk in self._stream_response.get():
250
- chunks.append(chunk)
251
- new = True
252
- if new:
253
- combined = await self._validate_text_result(''.join(chunks))
254
- yield combined
255
-
256
- new = False
257
- for chunk in self._stream_response.get(final=True):
258
- chunks.append(chunk)
259
- new = True
260
- if new:
261
- combined = await self._validate_text_result(''.join(chunks))
262
- yield combined
263
- lf_span.set_attribute('combined_text', combined)
264
- await self._marked_completed(_messages.ModelResponse.from_text(combined))
277
+ deltas: list[str] = []
278
+ combined_validated_text = ''
279
+ async for text in _stream_text_deltas():
280
+ deltas.append(text)
281
+ combined_text = ''.join(deltas)
282
+ combined_validated_text = await self._validate_text_result(combined_text)
283
+ yield combined_validated_text
284
+
285
+ lf_span.set_attribute('combined_text', combined_validated_text)
286
+ await self._marked_completed(
287
+ _messages.ModelResponse(
288
+ parts=[_messages.TextPart(combined_validated_text)],
289
+ model_name=self._stream_response.model_name(),
290
+ )
291
+ )
265
292
 
266
293
  async def stream_structured(
267
294
  self, *, debounce_by: float | None = 0.1
268
295
  ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]:
269
296
  """Stream the response as an async iterable of Structured LLM Messages.
270
297
 
271
- !!! note
272
- This method will fail if the response is text,
273
- e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `False`.
274
-
275
298
  Args:
276
299
  debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
277
300
  Debouncing is particularly important for long structured responses to reduce the overhead of
@@ -285,28 +308,24 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
285
308
  )
286
309
 
287
310
  with _logfire.span('response stream structured') as lf_span:
288
- if isinstance(self._stream_response, models.StreamTextResponse):
289
- raise exceptions.UserError('stream_structured() can only be used with structured responses')
290
- else:
291
- # we should already have a message at this point, yield that first if it has any content
311
+ # if the message currently has any parts with content, yield before streaming
312
+ msg = self._stream_response.get()
313
+ for part in msg.parts:
314
+ if part.has_content():
315
+ yield msg, False
316
+ break
317
+
318
+ async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
319
+ async for _events in group_iter:
320
+ msg = self._stream_response.get()
321
+ yield msg, False
292
322
  msg = self._stream_response.get()
293
- for item in msg.parts:
294
- if isinstance(item, _messages.ToolCallPart) and item.has_content():
295
- yield msg, False
296
- break
297
- async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter:
298
- async for _ in group_iter:
299
- msg = self._stream_response.get()
300
- for item in msg.parts:
301
- if isinstance(item, _messages.ToolCallPart) and item.has_content():
302
- yield msg, False
303
- break
304
- msg = self._stream_response.get(final=True)
305
323
  yield msg, True
324
+ # TODO: Should this now be `final_response` instead of `structured_response`?
306
325
  lf_span.set_attribute('structured_response', msg)
307
326
  await self._marked_completed(msg)
308
327
 
309
- async def get_data(self) -> ResultData:
328
+ async def get_data(self) -> ResultDataT:
310
329
  """Stream the whole response, validate and return it."""
311
330
  usage_checking_stream = _get_usage_checking_stream_response(
312
331
  self._stream_response, self._usage_limits, self.usage
@@ -314,21 +333,9 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
314
333
 
315
334
  async for _ in usage_checking_stream:
316
335
  pass
317
-
318
- if isinstance(self._stream_response, models.StreamTextResponse):
319
- text = ''.join(self._stream_response.get(final=True))
320
- text = await self._validate_text_result(text)
321
- await self._marked_completed(_messages.ModelResponse.from_text(text))
322
- return cast(ResultData, text)
323
- else:
324
- message = self._stream_response.get(final=True)
325
- await self._marked_completed(message)
326
- return await self.validate_structured_result(message)
327
-
328
- @property
329
- def is_structured(self) -> bool:
330
- """Return whether the stream response contains structured data (as opposed to text)."""
331
- return isinstance(self._stream_response, models.StreamStructuredResponse)
336
+ message = self._stream_response.get()
337
+ await self._marked_completed(message)
338
+ return await self.validate_structured_result(message)
332
339
 
333
340
  def usage(self) -> Usage:
334
341
  """Return the usage of the whole run.
@@ -344,27 +351,36 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
344
351
 
345
352
  async def validate_structured_result(
346
353
  self, message: _messages.ModelResponse, *, allow_partial: bool = False
347
- ) -> ResultData:
354
+ ) -> ResultDataT:
348
355
  """Validate a structured result message."""
349
- assert self._result_schema is not None, 'Expected _result_schema to not be None'
350
- assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
351
- match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
352
- if match is None:
353
- raise exceptions.UnexpectedModelBehavior(
354
- f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
355
- )
356
-
357
- call, result_tool = match
358
- result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
359
-
360
- for validator in self._result_validators:
361
- result_data = await validator.validate(result_data, call, self._run_ctx)
362
- return result_data
356
+ if self._result_schema is not None and self._result_tool_name is not None:
357
+ match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
358
+ if match is None:
359
+ raise exceptions.UnexpectedModelBehavior(
360
+ f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
361
+ )
362
+
363
+ call, result_tool = match
364
+ result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
365
+
366
+ for validator in self._result_validators:
367
+ result_data = await validator.validate(result_data, call, self._run_ctx)
368
+ return result_data
369
+ else:
370
+ text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
371
+ for validator in self._result_validators:
372
+ text = await validator.validate(
373
+ text,
374
+ None,
375
+ self._run_ctx,
376
+ )
377
+ # Since there is no result tool, we can assume that str is compatible with ResultDataT
378
+ return cast(ResultDataT, text)
363
379
 
364
380
  async def _validate_text_result(self, text: str) -> str:
365
381
  for validator in self._result_validators:
366
- text = await validator.validate( # pyright: ignore[reportAssignmentType]
367
- text, # pyright: ignore[reportArgumentType]
382
+ text = await validator.validate(
383
+ text,
368
384
  None,
369
385
  self._run_ctx,
370
386
  )
@@ -377,8 +393,10 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
377
393
 
378
394
 
379
395
  def _get_usage_checking_stream_response(
380
- stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage]
381
- ) -> AsyncIterator[ResultData]:
396
+ stream_response: AsyncIterable[_messages.ModelResponseStreamEvent],
397
+ limits: UsageLimits | None,
398
+ get_usage: Callable[[], Usage],
399
+ ) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
382
400
  if limits is not None and limits.has_token_limits():
383
401
 
384
402
  async def _usage_checking_iterator():
pydantic_ai/settings.py CHANGED
@@ -12,7 +12,8 @@ if TYPE_CHECKING:
12
12
  class ModelSettings(TypedDict, total=False):
13
13
  """Settings to configure an LLM.
14
14
 
15
- Here we include only settings which apply to multiple models / model providers.
15
+ Here we include only settings which apply to multiple models / model providers,
16
+ though not all of these settings are supported by all models.
16
17
  """
17
18
 
18
19
  max_tokens: int
@@ -24,6 +25,8 @@ class ModelSettings(TypedDict, total=False):
24
25
  * Anthropic
25
26
  * OpenAI
26
27
  * Groq
28
+ * Cohere
29
+ * Mistral
27
30
  """
28
31
 
29
32
  temperature: float
@@ -40,6 +43,8 @@ class ModelSettings(TypedDict, total=False):
40
43
  * Anthropic
41
44
  * OpenAI
42
45
  * Groq
46
+ * Cohere
47
+ * Mistral
43
48
  """
44
49
 
45
50
  top_p: float
@@ -55,6 +60,8 @@ class ModelSettings(TypedDict, total=False):
55
60
  * Anthropic
56
61
  * OpenAI
57
62
  * Groq
63
+ * Cohere
64
+ * Mistral
58
65
  """
59
66
 
60
67
  timeout: float | Timeout
@@ -66,6 +73,16 @@ class ModelSettings(TypedDict, total=False):
66
73
  * Anthropic
67
74
  * OpenAI
68
75
  * Groq
76
+ * Mistral
77
+ """
78
+
79
+ parallel_tool_calls: bool
80
+ """Whether to allow parallel tool calls.
81
+
82
+ Supported by:
83
+ * OpenAI
84
+ * Groq
85
+ * Anthropic
69
86
  """
70
87
 
71
88
 
pydantic_ai/tools.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
4
4
  import inspect
5
5
  from collections.abc import Awaitable
6
6
  from dataclasses import dataclass, field
7
- from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
7
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
8
8
 
9
9
  from pydantic import ValidationError
10
10
  from pydantic_core import SchemaValidator
@@ -17,7 +17,8 @@ if TYPE_CHECKING:
17
17
  from .result import Usage
18
18
 
19
19
  __all__ = (
20
- 'AgentDeps',
20
+ 'AgentDepsT',
21
+ 'DocstringFormat',
21
22
  'RunContext',
22
23
  'SystemPromptFunc',
23
24
  'ToolFuncContext',
@@ -30,15 +31,15 @@ __all__ = (
30
31
  'ToolDefinition',
31
32
  )
32
33
 
33
- AgentDeps = TypeVar('AgentDeps', default=None)
34
+ AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
34
35
  """Type variable for agent dependencies."""
35
36
 
36
37
 
37
38
  @dataclasses.dataclass
38
- class RunContext(Generic[AgentDeps]):
39
+ class RunContext(Generic[AgentDepsT]):
39
40
  """Information about the current call."""
40
41
 
41
- deps: AgentDeps
42
+ deps: AgentDepsT
42
43
  """Dependencies for the agent."""
43
44
  model: models.Model
44
45
  """The model used in this run."""
@@ -57,7 +58,7 @@ class RunContext(Generic[AgentDeps]):
57
58
 
58
59
  def replace_with(
59
60
  self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
60
- ) -> RunContext[AgentDeps]:
61
+ ) -> RunContext[AgentDepsT]:
61
62
  # Create a new `RunContext` a new `retry` value and `tool_name`.
62
63
  kwargs = {}
63
64
  if retry is not None:
@@ -71,8 +72,8 @@ ToolParams = ParamSpec('ToolParams', default=...)
71
72
  """Retrieval function param spec."""
72
73
 
73
74
  SystemPromptFunc = Union[
74
- Callable[[RunContext[AgentDeps]], str],
75
- Callable[[RunContext[AgentDeps]], Awaitable[str]],
75
+ Callable[[RunContext[AgentDepsT]], str],
76
+ Callable[[RunContext[AgentDepsT]], Awaitable[str]],
76
77
  Callable[[], str],
77
78
  Callable[[], Awaitable[str]],
78
79
  ]
@@ -81,7 +82,7 @@ SystemPromptFunc = Union[
81
82
  Usage `SystemPromptFunc[AgentDeps]`.
82
83
  """
83
84
 
84
- ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
85
+ ToolFuncContext = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any]
85
86
  """A tool function that takes `RunContext` as the first argument.
86
87
 
87
88
  Usage `ToolContextFunc[AgentDeps, ToolParams]`.
@@ -91,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
91
92
 
92
93
  Usage `ToolPlainFunc[ToolParams]`.
93
94
  """
94
- ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
95
+ ToolFuncEither = Union[ToolFuncContext[AgentDepsT, ToolParams], ToolFuncPlain[ToolParams]]
95
96
  """Either kind of tool function.
96
97
 
97
98
  This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
@@ -99,14 +100,14 @@ This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] a
99
100
 
100
101
  Usage `ToolFuncEither[AgentDeps, ToolParams]`.
101
102
  """
102
- ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDeps], ToolDefinition], Awaitable[ToolDefinition | None]]'
103
+ ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDepsT], ToolDefinition], Awaitable[ToolDefinition | None]]'
103
104
  """Definition of a function that can prepare a tool definition at call time.
104
105
 
105
106
  See [tool docs](../tools.md#tool-prepare) for more information.
106
107
 
107
108
  Example — here `only_if_42` is valid as a `ToolPrepareFunc`:
108
109
 
109
- ```python {lint="not-imports"}
110
+ ```python {noqa="I001"}
110
111
  from typing import Union
111
112
 
112
113
  from pydantic_ai import RunContext, Tool
@@ -127,19 +128,30 @@ hitchhiker = Tool(hitchhiker, prepare=only_if_42)
127
128
  Usage `ToolPrepareFunc[AgentDeps]`.
128
129
  """
129
130
 
131
+ DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
132
+ """Supported docstring formats.
133
+
134
+ * `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
135
+ * `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
136
+ * `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
137
+ * `'auto'` — Automatically infer the format based on the structure of the docstring.
138
+ """
139
+
130
140
  A = TypeVar('A')
131
141
 
132
142
 
133
143
  @dataclass(init=False)
134
- class Tool(Generic[AgentDeps]):
144
+ class Tool(Generic[AgentDepsT]):
135
145
  """A tool function for an agent."""
136
146
 
137
- function: ToolFuncEither[AgentDeps]
147
+ function: ToolFuncEither[AgentDepsT]
138
148
  takes_ctx: bool
139
149
  max_retries: int | None
140
150
  name: str
141
151
  description: str
142
- prepare: ToolPrepareFunc[AgentDeps] | None
152
+ prepare: ToolPrepareFunc[AgentDepsT] | None
153
+ docstring_format: DocstringFormat
154
+ require_parameter_descriptions: bool
143
155
  _is_async: bool = field(init=False)
144
156
  _single_arg_name: str | None = field(init=False)
145
157
  _positional_fields: list[str] = field(init=False)
@@ -150,19 +162,21 @@ class Tool(Generic[AgentDeps]):
150
162
 
151
163
  def __init__(
152
164
  self,
153
- function: ToolFuncEither[AgentDeps],
165
+ function: ToolFuncEither[AgentDepsT],
154
166
  *,
155
167
  takes_ctx: bool | None = None,
156
168
  max_retries: int | None = None,
157
169
  name: str | None = None,
158
170
  description: str | None = None,
159
- prepare: ToolPrepareFunc[AgentDeps] | None = None,
171
+ prepare: ToolPrepareFunc[AgentDepsT] | None = None,
172
+ docstring_format: DocstringFormat = 'auto',
173
+ require_parameter_descriptions: bool = False,
160
174
  ):
161
175
  """Create a new tool instance.
162
176
 
163
177
  Example usage:
164
178
 
165
- ```python {lint="not-imports"}
179
+ ```python {noqa="I001"}
166
180
  from pydantic_ai import Agent, RunContext, Tool
167
181
 
168
182
  async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
@@ -173,7 +187,7 @@ class Tool(Generic[AgentDeps]):
173
187
 
174
188
  or with a custom prepare method:
175
189
 
176
- ```python {lint="not-imports"}
190
+ ```python {noqa="I001"}
177
191
  from typing import Union
178
192
 
179
193
  from pydantic_ai import Agent, RunContext, Tool
@@ -203,17 +217,22 @@ class Tool(Generic[AgentDeps]):
203
217
  prepare: custom method to prepare the tool definition for each step, return `None` to omit this
204
218
  tool from a given step. This is useful if you want to customise a tool at call time,
205
219
  or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
220
+ docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
221
+ Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
222
+ require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
206
223
  """
207
224
  if takes_ctx is None:
208
225
  takes_ctx = _pydantic.takes_ctx(function)
209
226
 
210
- f = _pydantic.function_schema(function, takes_ctx)
227
+ f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
211
228
  self.function = function
212
229
  self.takes_ctx = takes_ctx
213
230
  self.max_retries = max_retries
214
231
  self.name = name or function.__name__
215
232
  self.description = description or f['description']
216
233
  self.prepare = prepare
234
+ self.docstring_format = docstring_format
235
+ self.require_parameter_descriptions = require_parameter_descriptions
217
236
  self._is_async = inspect.iscoroutinefunction(self.function)
218
237
  self._single_arg_name = f['single_arg_name']
219
238
  self._positional_fields = f['positional_fields']
@@ -221,7 +240,7 @@ class Tool(Generic[AgentDeps]):
221
240
  self._validator = f['validator']
222
241
  self._parameters_json_schema = f['json_schema']
223
242
 
224
- async def prepare_tool_def(self, ctx: RunContext[AgentDeps]) -> ToolDefinition | None:
243
+ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
225
244
  """Get the tool definition.
226
245
 
227
246
  By default, this method creates a tool definition, then either returns it, or calls `self.prepare`
@@ -241,7 +260,7 @@ class Tool(Generic[AgentDeps]):
241
260
  return tool_def
242
261
 
243
262
  async def run(
244
- self, message: _messages.ToolCallPart, run_context: RunContext[AgentDeps]
263
+ self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
245
264
  ) -> _messages.ModelRequestPart:
246
265
  """Run the tool function asynchronously."""
247
266
  try:
@@ -274,7 +293,7 @@ class Tool(Generic[AgentDeps]):
274
293
  self,
275
294
  args_dict: dict[str, Any],
276
295
  message: _messages.ToolCallPart,
277
- run_context: RunContext[AgentDeps],
296
+ run_context: RunContext[AgentDepsT],
278
297
  ) -> tuple[list[Any], dict[str, Any]]:
279
298
  if self._single_arg_name:
280
299
  args_dict = {self._single_arg_name: args_dict}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.18
3
+ Version: 0.0.20
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -26,11 +26,15 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
26
26
  Requires-Python: >=3.9
27
27
  Requires-Dist: eval-type-backport>=0.2.0
28
28
  Requires-Dist: griffe>=1.3.2
29
- Requires-Dist: httpx>=0.27.2
29
+ Requires-Dist: httpx>=0.27
30
30
  Requires-Dist: logfire-api>=1.2.0
31
31
  Requires-Dist: pydantic>=2.10
32
32
  Provides-Extra: anthropic
33
33
  Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
34
+ Provides-Extra: cohere
35
+ Requires-Dist: cohere>=5.13.11; extra == 'cohere'
36
+ Provides-Extra: graph
37
+ Requires-Dist: pydantic-graph==0.0.20; extra == 'graph'
34
38
  Provides-Extra: groq
35
39
  Requires-Dist: groq>=0.12.0; extra == 'groq'
36
40
  Provides-Extra: logfire
@@ -38,7 +42,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
38
42
  Provides-Extra: mistral
39
43
  Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
40
44
  Provides-Extra: openai
41
- Requires-Dist: openai>=1.54.3; extra == 'openai'
45
+ Requires-Dist: openai>=1.59.0; extra == 'openai'
42
46
  Provides-Extra: vertexai
43
47
  Requires-Dist: google-auth>=2.36.0; extra == 'vertexai'
44
48
  Requires-Dist: requests>=2.32.3; extra == 'vertexai'