pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

pydantic_ai/_griffe.py CHANGED
@@ -1,16 +1,24 @@
  from __future__ import annotations as _annotations

+ import logging
  import re
+ from contextlib import contextmanager
  from inspect import Signature
- from typing import Any, Callable, Literal, cast
+ from typing import TYPE_CHECKING, Any, Callable, Literal, cast

  from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

+ if TYPE_CHECKING:
+     from .tools import DocstringFormat
+
  DocstringStyle = Literal['google', 'numpy', 'sphinx']


  def doc_descriptions(
-     func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
+     func: Callable[..., Any],
+     sig: Signature,
+     *,
+     docstring_format: DocstringFormat,
  ) -> tuple[str, dict[str, str]]:
      """Extract the function description and parameter descriptions from a function's docstring.

@@ -24,8 +32,10 @@ def doc_descriptions(
      # see https://github.com/mkdocstrings/griffe/issues/293
      parent = cast(GriffeObject, sig)

-     docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
-     sections = docstring.parse()
+     docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
+     docstring = Docstring(doc, lineno=1, parser=docstring_style, parent=parent)
+     with _disable_griffe_logging():
+         sections = docstring.parse()

      params = {}
      if parameters := next((p for p in sections if p.kind == DocstringSectionKind.parameters), None):
@@ -125,3 +135,12 @@ _docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
          'numpy',
      ),
  ]
+
+
+ @contextmanager
+ def _disable_griffe_logging():
+     # Hacky, but suggested here: https://github.com/mkdocstrings/griffe/issues/293#issuecomment-2167668117
+     old_level = logging.root.getEffectiveLevel()
+     logging.root.setLevel(logging.ERROR)
+     yield
+     logging.root.setLevel(old_level)
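
The reworked `doc_descriptions` takes an explicit `docstring_format` (with `'auto'` falling back to `_infer_docstring_style`) instead of the old optional `style` argument. A minimal sketch of how a caller might use it, with a hypothetical example function:

```
import inspect

from pydantic_ai._griffe import doc_descriptions  # private helper, imported here only for illustration


def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: The first number.
        b: The second number.
    """
    return a + b


# 'auto' infers the docstring style (google/numpy/sphinx); passing e.g. 'google' forces a parser.
description, params = doc_descriptions(add, inspect.signature(add), docstring_format='auto')
# expected to yield roughly: description == 'Add two numbers.'
# and params == {'a': 'The first number.', 'b': 'The second number.'}
```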
@@ -0,0 +1,239 @@
+ """This module provides functionality to manage and update parts of a model's streamed response.
+
+ The manager tracks which parts (in particular, text and tool calls) correspond to which
+ vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
+ and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.
+
+ The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
+ and are tightly coupled to the specific model being used and the PydanticAI Model subclass implementation.
+
+ This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate
+ event-emitting logic.
+ """
+
+ from __future__ import annotations as _annotations
+
+ from collections.abc import Hashable
+ from dataclasses import dataclass, field
+ from typing import Any, Union
+
+ from pydantic_ai.exceptions import UnexpectedModelBehavior
+ from pydantic_ai.messages import (
+     ModelResponsePart,
+     ModelResponseStreamEvent,
+     PartDeltaEvent,
+     PartStartEvent,
+     TextPart,
+     TextPartDelta,
+     ToolCallPart,
+     ToolCallPartDelta,
+ )
+
+ VendorId = Hashable
+ """
+ Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
+ """
+
+ ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
+ """
+ A union of types that are managed by the ModelResponsePartsManager.
+ Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
+ this includes ToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's.
+ """
+
+
+ @dataclass
+ class ModelResponsePartsManager:
+     """Manages a sequence of parts that make up a model's streamed response.
+
+     Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
+     """
+
+     _parts: list[ManagedPart] = field(default_factory=list, init=False)
+     """A list of parts (text or tool calls) that make up the current state of the model's response."""
+     _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
+     """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
+
+     def get_parts(self) -> list[ModelResponsePart]:
+         """Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
+
+         Returns:
+             A list of ModelResponsePart objects. ToolCallPartDelta objects are excluded.
+         """
+         return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
+
+     def handle_text_delta(
+         self,
+         *,
+         vendor_part_id: Hashable | None,
+         content: str,
+     ) -> ModelResponseStreamEvent:
+         """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
+
+         When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
+         otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
+         to that vendor ID is either created or updated.
+
+         Args:
+             vendor_part_id: The ID the vendor uses to identify this piece
+                 of text. If None, a new part will be created unless the latest part is already
+                 a TextPart.
+             content: The text content to append to the appropriate TextPart.
+
+         Returns:
+             A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
+
+         Raises:
+             UnexpectedModelBehavior: If attempting to apply text content to a part that is
+                 not a TextPart.
+         """
+         existing_text_part_and_index: tuple[TextPart, int] | None = None
+
+         if vendor_part_id is None:
+             # If the vendor_part_id is None, check if the latest part is a TextPart to update
+             if self._parts:
+                 part_index = len(self._parts) - 1
+                 latest_part = self._parts[part_index]
+                 if isinstance(latest_part, TextPart):
+                     existing_text_part_and_index = latest_part, part_index
+         else:
+             # Otherwise, attempt to look up an existing TextPart by vendor_part_id
+             part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+             if part_index is not None:
+                 existing_part = self._parts[part_index]
+                 if not isinstance(existing_part, TextPart):
+                     raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
+                 existing_text_part_and_index = existing_part, part_index
+
+         if existing_text_part_and_index is None:
+             # There is no existing text part that should be updated, so create a new one
+             new_part_index = len(self._parts)
+             part = TextPart(content=content)
+             if vendor_part_id is not None:
+                 self._vendor_id_to_part_index[vendor_part_id] = new_part_index
+             self._parts.append(part)
+             return PartStartEvent(index=new_part_index, part=part)
+         else:
+             # Update the existing TextPart with the new content delta
+             existing_text_part, part_index = existing_text_part_and_index
+             part_delta = TextPartDelta(content_delta=content)
+             self._parts[part_index] = part_delta.apply(existing_text_part)
+             return PartDeltaEvent(index=part_index, delta=part_delta)
+
+     def handle_tool_call_delta(
+         self,
+         *,
+         vendor_part_id: Hashable | None,
+         tool_name: str | None,
+         args: str | dict[str, Any] | None,
+         tool_call_id: str | None,
+     ) -> ModelResponseStreamEvent | None:
+         """Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.
+
+         Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which
+         point they are upgraded to `ToolCallPart`s.
+
+         If `vendor_part_id` is None, updates the latest matching ToolCallPart (or ToolCallPartDelta)
+         if any. Otherwise, a new part (or delta) may be created.
+
+         Args:
+             vendor_part_id: The ID the vendor uses for this tool call.
+                 If None, the latest matching tool call may be updated.
+             tool_name: The name of the tool. If None, the manager does not enforce
+                 a name match when `vendor_part_id` is None.
+             args: Arguments for the tool call, either as a string or a dictionary of key-value pairs.
+             tool_call_id: An optional string representing an identifier for this tool call.
+
+         Returns:
+             - A `PartStartEvent` if a new (fully realized) ToolCallPart is created.
+             - A `PartDeltaEvent` if an existing part is updated.
+             - `None` if no new event is emitted (e.g., the part is still incomplete).
+
+         Raises:
+             UnexpectedModelBehavior: If attempting to apply a tool call delta to a part that is not
+                 a ToolCallPart or ToolCallPartDelta.
+         """
+         existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None
+
+         if vendor_part_id is None:
+             # vendor_part_id is None, so check if the latest part is a matching tool call or delta to update
+             # When the vendor_part_id is None, if the tool_name is _not_ None, assume this should be a new part rather
+             # than a delta on an existing one. We can change this behavior in the future if necessary for some model.
+             if tool_name is None and self._parts:
+                 part_index = len(self._parts) - 1
+                 latest_part = self._parts[part_index]
+                 if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
+                     existing_matching_part_and_index = latest_part, part_index
+         else:
+             # vendor_part_id is provided, so look up the corresponding part or delta
+             part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+             if part_index is not None:
+                 existing_part = self._parts[part_index]
+                 if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
+                     raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
+                 existing_matching_part_and_index = existing_part, part_index
+
+         if existing_matching_part_and_index is None:
+             # No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed)
+             delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
+             part = delta.as_part() or delta
+             if vendor_part_id is not None:
+                 self._vendor_id_to_part_index[vendor_part_id] = len(self._parts)
+             new_part_index = len(self._parts)
+             self._parts.append(part)
+             # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart
+             if isinstance(part, ToolCallPart):
+                 return PartStartEvent(index=new_part_index, part=part)
+         else:
+             # Update the existing part or delta with the new information
+             existing_part, part_index = existing_matching_part_and_index
+             delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
+             updated_part = delta.apply(existing_part)
+             self._parts[part_index] = updated_part
+             if isinstance(updated_part, ToolCallPart):
+                 if isinstance(existing_part, ToolCallPartDelta):
+                     # We just upgraded a delta to a full part, so emit a PartStartEvent
+                     return PartStartEvent(index=part_index, part=updated_part)
+                 else:
+                     # We updated an existing part, so emit a PartDeltaEvent
+                     return PartDeltaEvent(index=part_index, delta=delta)
+
+     def handle_tool_call_part(
+         self,
+         *,
+         vendor_part_id: Hashable | None,
+         tool_name: str,
+         args: str | dict[str, Any],
+         tool_call_id: str | None = None,
+     ) -> ModelResponseStreamEvent:
+         """Immediately create or fully-overwrite a ToolCallPart with the given information.
+
+         This does not apply a delta; it directly sets the tool call part contents.
+
+         Args:
+             vendor_part_id: The vendor's ID for this tool call part. If not
+                 None and an existing part is found, that part is overwritten.
+             tool_name: The name of the tool being invoked.
+             args: The arguments for the tool call, either as a string or a dictionary.
+             tool_call_id: An optional string identifier for this tool call.
+
+         Returns:
+             ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
+                 has been added to the manager, or replaced an existing part.
+         """
+         new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
+         if vendor_part_id is None:
+             # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
+             new_part_index = len(self._parts)
+             self._parts.append(new_part)
+         else:
+             # vendor_part_id is provided, so find and overwrite or create a new ToolCallPart.
+             maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
+             if maybe_part_index is not None:
+                 new_part_index = maybe_part_index
+                 self._parts[new_part_index] = new_part
+             else:
+                 new_part_index = len(self._parts)
+                 self._parts.append(new_part)
+             self._vendor_id_to_part_index[vendor_part_id] = new_part_index
+         return PartStartEvent(index=new_part_index, part=new_part)
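
To make the intended flow of the new parts manager concrete, here is a rough usage sketch. The module path `pydantic_ai._parts_manager` is an assumption (the added file's name is not shown in this diff), and the vendor IDs are purely illustrative:

```
from pydantic_ai._parts_manager import ModelResponsePartsManager  # assumed module path

manager = ModelResponsePartsManager()

# First text chunk for vendor part 'txt-0' -> PartStartEvent wrapping a new TextPart
manager.handle_text_delta(vendor_part_id='txt-0', content='Hello')
# Another chunk for the same vendor ID -> PartDeltaEvent carrying a TextPartDelta
manager.handle_text_delta(vendor_part_id='txt-0', content=', world')

# A tool call streamed in two steps: the name arrives first; with no args yet the manager
# keeps a ToolCallPartDelta and returns None (no event is emitted)
assert manager.handle_tool_call_delta(
    vendor_part_id='call-0', tool_name='get_weather', args=None, tool_call_id=None
) is None
# When the args arrive, the delta is upgraded to a ToolCallPart and a PartStartEvent is returned
manager.handle_tool_call_delta(
    vendor_part_id='call-0', tool_name=None, args='{"city": "Paris"}', tool_call_id=None
)

# get_parts() returns only fully formed parts; any leftover ToolCallPartDelta is excluded
print(manager.get_parts())  # roughly: a TextPart for 'Hello, world' and a ToolCallPart for 'get_weather'
```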
pydantic_ai/_pydantic.py CHANGED
@@ -20,7 +20,7 @@ from ._griffe import doc_descriptions
  from ._utils import check_object_json_schema, is_model_like

  if TYPE_CHECKING:
-     from .tools import ObjectJsonSchema
+     from .tools import DocstringFormat, ObjectJsonSchema


  __all__ = ('function_schema',)
@@ -38,12 +38,19 @@ class FunctionSchema(TypedDict):
      var_positional_field: str | None


- def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema:  # noqa: C901
+ def function_schema(  # noqa: C901
+     function: Callable[..., Any],
+     takes_ctx: bool,
+     docstring_format: DocstringFormat,
+     require_parameter_descriptions: bool,
+ ) -> FunctionSchema:
      """Build a Pydantic validator and JSON schema from a tool function.

      Args:
          function: The function to build a validator and JSON schema for.
          takes_ctx: Whether the function takes a `RunContext` first argument.
+         docstring_format: The docstring format to use.
+         require_parameter_descriptions: Whether to require descriptions for all tool function parameters.

      Returns:
          A `FunctionSchema` instance.
@@ -62,7 +69,13 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
      var_positional_field: str | None = None
      errors: list[str] = []
      decorators = _decorators.DecoratorInfos()
-     description, field_descriptions = doc_descriptions(function, sig)
+
+     description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
+
+     if require_parameter_descriptions:
+         if len(field_descriptions) != len(sig.parameters):
+             missing_params = set(sig.parameters) - set(field_descriptions)
+             errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')

      for index, (name, p) in enumerate(sig.parameters.items()):
          if p.annotation is sig.empty:
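
As a standalone illustration of the new `require_parameter_descriptions` check (a sketch that simply mirrors the hunk above; the example tool and the hand-built `field_descriptions` dict are hypothetical):

```
import inspect


def my_tool(city: str, units: str) -> str:
    """Get the weather for a city.

    Args:
        city: The city to look up.
    """
    return f'{city} ({units})'


sig = inspect.signature(my_tool)
field_descriptions = {'city': 'The city to look up.'}  # what the docstring above would yield

# Mirrors the check added in function_schema: 'units' has no description, so an error is recorded
if len(field_descriptions) != len(sig.parameters):
    missing_params = set(sig.parameters) - set(field_descriptions)
    print(f'Missing parameter descriptions for {", ".join(missing_params)}')
```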
@@ -12,6 +12,7 @@ from .tools import AgentDeps, RunContext, SystemPromptFunc
  @dataclass
  class SystemPromptRunner(Generic[AgentDeps]):
      function: SystemPromptFunc[AgentDeps]
+     dynamic: bool = False
      _takes_ctx: bool = field(init=False)
      _is_async: bool = field(init=False)

pydantic_ai/_utils.py CHANGED
@@ -15,7 +15,7 @@ from pydantic.json_schema import JsonSchemaValue
  from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict

  if TYPE_CHECKING:
-     from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
+     from . import messages as _messages
      from .tools import ObjectJsonSchema

  _P = ParamSpec('_P')
@@ -136,7 +136,7 @@ class Either(Generic[Left, Right]):

  @asynccontextmanager
  async def group_by_temporal(
-     aiter: AsyncIterator[T], soft_max_interval: float | None
+     aiterable: AsyncIterable[T], soft_max_interval: float | None
  ) -> AsyncIterator[AsyncIterable[list[T]]]:
      """Group items from an async iterable into lists based on time interval between them.

@@ -154,18 +154,18 @@
      ```

      Args:
-         aiter: The async iterable to group.
+         aiterable: The async iterable to group.
          soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
              a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
              as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed

      Returns:
-         A context manager usable as an iterator async iterable of lists of items from the input async iterable.
+         A context manager usable as an async iterable of lists of items produced by the input async iterable.
      """
      if soft_max_interval is None:

          async def async_iter_groups_noop() -> AsyncIterator[list[T]]:
-             async for item in aiter:
+             async for item in aiterable:
                  yield [item]

          yield async_iter_groups_noop()
@@ -181,6 +181,7 @@
      buffer: list[T] = []
      group_start_time = time.monotonic()

+     aiterator = aiterable.__aiter__()
      while True:
          if group_start_time is None:
              # group hasn't started, we just wait for the maximum interval
@@ -193,7 +194,7 @@
              if task is None:
                  # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
                  # so far, this doesn't seem to be a problem
-                 task = asyncio.create_task(aiter.__anext__())  # pyright: ignore[reportArgumentType]
+                 task = asyncio.create_task(aiterator.__anext__())  # pyright: ignore[reportArgumentType]

              # we use asyncio.wait to avoid cancelling the coroutine if it's not done
              done, _ = await asyncio.wait((task,), timeout=wait_time)
@@ -232,16 +233,6 @@ async def group_by_temporal(
          await task


- def add_optional(a: str | None, b: str | None) -> str | None:
-     """Add two optional strings."""
-     if a is None:
-         return b
-     elif b is None:
-         return a
-     else:
-         return a + b
-
-

  def sync_anext(iterator: Iterator[T]) -> T:
      """Get the next item from a sync iterator, raising `StopAsyncIteration` if it's exhausted.
@@ -257,7 +248,79 @@ def now_utc() -> datetime:
      return datetime.now(tz=timezone.utc)


- def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
+ def guard_tool_call_id(
+     t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart, model_source: str
+ ) -> str:
      """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
      assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
      return t.tool_call_id
+
+
+ class PeekableAsyncStream(Generic[T]):
+     """Wraps an async iterable of type T and allows peeking at the *next* item without consuming it.
+
+     We only buffer one item at a time (the next item). Once that item is yielded, it is discarded.
+     This is a single-pass stream.
+     """
+
+     def __init__(self, source: AsyncIterable[T]):
+         self._source = source
+         self._source_iter: AsyncIterator[T] | None = None
+         self._buffer: T | Unset = UNSET
+         self._exhausted = False
+
+     async def peek(self) -> T | Unset:
+         """Returns the next item that would be yielded without consuming it.
+
+         Returns UNSET if the stream is exhausted.
+         """
+         if self._exhausted:
+             return UNSET
+
+         # If we already have a buffered item, just return it.
+         if not isinstance(self._buffer, Unset):
+             return self._buffer
+
+         # Otherwise, we need to fetch the next item from the underlying iterator.
+         if self._source_iter is None:
+             self._source_iter = self._source.__aiter__()
+
+         try:
+             self._buffer = await self._source_iter.__anext__()
+         except StopAsyncIteration:
+             self._exhausted = True
+             return UNSET
+
+         return self._buffer
+
+     async def is_exhausted(self) -> bool:
+         """Returns True if the stream is exhausted, False otherwise."""
+         return isinstance(await self.peek(), Unset)
+
+     def __aiter__(self) -> AsyncIterator[T]:
+         # For a single-pass iteration, we can return self as the iterator.
+         return self
+
+     async def __anext__(self) -> T:
+         """Yields the buffered item if present, otherwise fetches the next item from the underlying source.
+
+         Raises StopAsyncIteration if the stream is exhausted.
+         """
+         if self._exhausted:
+             raise StopAsyncIteration
+
+         # If we have a buffered item, yield it.
+         if not isinstance(self._buffer, Unset):
+             item = self._buffer
+             self._buffer = UNSET
+             return item
+
+         # Otherwise, fetch the next item from the source.
+         if self._source_iter is None:
+             self._source_iter = self._source.__aiter__()
+
+         try:
+             return await self._source_iter.__anext__()
+         except StopAsyncIteration:
+             self._exhausted = True
+             raise
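
Finally, a short sketch of how the new `PeekableAsyncStream` is meant to behave (the wrapped generator is hypothetical): `peek()` buffers a single upcoming item without consuming it, and the stream can then be iterated exactly once:

```
import asyncio

from pydantic_ai._utils import PeekableAsyncStream  # private helper, imported here only for illustration


async def chunks():
    for c in ('a', 'b', 'c'):
        yield c


async def main():
    stream = PeekableAsyncStream(chunks())

    print(await stream.peek())          # 'a' -- buffered, not yet consumed
    print(await stream.is_exhausted())  # False

    async for item in stream:           # single pass: 'a', 'b', 'c'
        print(item)

    print(await stream.is_exhausted())  # True; peek() now returns the UNSET sentinel


asyncio.run(main())
```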