pydantic-ai-slim 0.0.18__tar.gz → 0.0.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/.gitignore +1 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/PKG-INFO +3 -1
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/_griffe.py +10 -3
- pydantic_ai_slim-0.0.19/pydantic_ai/_parts_manager.py +239 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/_pydantic.py +16 -3
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/_utils.py +80 -17
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/agent.py +82 -74
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/format_as_xml.py +2 -1
- pydantic_ai_slim-0.0.19/pydantic_ai/messages.py +479 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/__init__.py +31 -72
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/anthropic.py +21 -21
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/function.py +47 -79
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/gemini.py +76 -122
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/groq.py +53 -125
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/mistral.py +75 -137
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/ollama.py +1 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/openai.py +50 -125
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/test.py +40 -73
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/result.py +91 -92
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/tools.py +24 -5
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pyproject.toml +9 -4
- pydantic_ai_slim-0.0.18/pydantic_ai/messages.py +0 -270
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/README.md +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/models/vertexai.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.18 → pydantic_ai_slim-0.0.19}/pydantic_ai/usage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.18
|
|
3
|
+
Version: 0.0.19
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -31,6 +31,8 @@ Requires-Dist: logfire-api>=1.2.0
|
|
|
31
31
|
Requires-Dist: pydantic>=2.10
|
|
32
32
|
Provides-Extra: anthropic
|
|
33
33
|
Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
|
|
34
|
+
Provides-Extra: graph
|
|
35
|
+
Requires-Dist: pydantic-graph==0.0.19; extra == 'graph'
|
|
34
36
|
Provides-Extra: groq
|
|
35
37
|
Requires-Dist: groq>=0.12.0; extra == 'groq'
|
|
36
38
|
Provides-Extra: logfire
|
|
@@ -4,15 +4,21 @@ import logging
|
|
|
4
4
|
import re
|
|
5
5
|
from contextlib import contextmanager
|
|
6
6
|
from inspect import Signature
|
|
7
|
-
from typing import Any, Callable, Literal, cast
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
|
|
8
8
|
|
|
9
9
|
from griffe import Docstring, DocstringSectionKind, Object as GriffeObject
|
|
10
10
|
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .tools import DocstringFormat
|
|
13
|
+
|
|
11
14
|
DocstringStyle = Literal['google', 'numpy', 'sphinx']
|
|
12
15
|
|
|
13
16
|
|
|
14
17
|
def doc_descriptions(
|
|
15
|
-
func: Callable[..., Any],
|
|
18
|
+
func: Callable[..., Any],
|
|
19
|
+
sig: Signature,
|
|
20
|
+
*,
|
|
21
|
+
docstring_format: DocstringFormat,
|
|
16
22
|
) -> tuple[str, dict[str, str]]:
|
|
17
23
|
"""Extract the function description and parameter descriptions from a function's docstring.
|
|
18
24
|
|
|
@@ -26,7 +32,8 @@ def doc_descriptions(
|
|
|
26
32
|
# see https://github.com/mkdocstrings/griffe/issues/293
|
|
27
33
|
parent = cast(GriffeObject, sig)
|
|
28
34
|
|
|
29
|
-
|
|
35
|
+
docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
|
|
36
|
+
docstring = Docstring(doc, lineno=1, parser=docstring_style, parent=parent)
|
|
30
37
|
with _disable_griffe_logging():
|
|
31
38
|
sections = docstring.parse()
|
|
32
39
|
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""This module provides functionality to manage and update parts of a model's streamed response.
|
|
2
|
+
|
|
3
|
+
The manager tracks which parts (in particular, text and tool calls) correspond to which
|
|
4
|
+
vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
|
|
5
|
+
and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.
|
|
6
|
+
|
|
7
|
+
The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
|
|
8
|
+
and are tightly coupled to the specific model being used, and the PydanticAI Model subclass implementation.
|
|
9
|
+
|
|
10
|
+
This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate
|
|
11
|
+
event-emitting logic.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations as _annotations
|
|
15
|
+
|
|
16
|
+
from collections.abc import Hashable
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from typing import Any, Union
|
|
19
|
+
|
|
20
|
+
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
|
21
|
+
from pydantic_ai.messages import (
|
|
22
|
+
ModelResponsePart,
|
|
23
|
+
ModelResponseStreamEvent,
|
|
24
|
+
PartDeltaEvent,
|
|
25
|
+
PartStartEvent,
|
|
26
|
+
TextPart,
|
|
27
|
+
TextPartDelta,
|
|
28
|
+
ToolCallPart,
|
|
29
|
+
ToolCallPartDelta,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
VendorId = Hashable
|
|
33
|
+
"""
|
|
34
|
+
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
|
|
38
|
+
"""
|
|
39
|
+
A union of types that are managed by the ModelResponsePartsManager.
|
|
40
|
+
Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
|
|
41
|
+
this includes ToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
|
|
46
|
+
class ModelResponsePartsManager:
|
|
47
|
+
"""Manages a sequence of parts that make up a model's streamed response.
|
|
48
|
+
|
|
49
|
+
Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
_parts: list[ManagedPart] = field(default_factory=list, init=False)
|
|
53
|
+
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
|
|
54
|
+
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
|
|
55
|
+
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
|
|
56
|
+
|
|
57
|
+
def get_parts(self) -> list[ModelResponsePart]:
|
|
58
|
+
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
A list of ModelResponsePart objects. ToolCallPartDelta objects are excluded.
|
|
62
|
+
"""
|
|
63
|
+
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
|
|
64
|
+
|
|
65
|
+
def handle_text_delta(
|
|
66
|
+
self,
|
|
67
|
+
*,
|
|
68
|
+
vendor_part_id: Hashable | None,
|
|
69
|
+
content: str,
|
|
70
|
+
) -> ModelResponseStreamEvent:
|
|
71
|
+
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
|
|
72
|
+
|
|
73
|
+
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
|
|
74
|
+
otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
|
|
75
|
+
to that vendor ID is either created or updated.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
vendor_part_id: The ID the vendor uses to identify this piece
|
|
79
|
+
of text. If None, a new part will be created unless the latest part is already
|
|
80
|
+
a TextPart.
|
|
81
|
+
content: The text content to append to the appropriate TextPart.
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
|
|
85
|
+
|
|
86
|
+
Raises:
|
|
87
|
+
UnexpectedModelBehavior: If attempting to apply text content to a part that is
|
|
88
|
+
not a TextPart.
|
|
89
|
+
"""
|
|
90
|
+
existing_text_part_and_index: tuple[TextPart, int] | None = None
|
|
91
|
+
|
|
92
|
+
if vendor_part_id is None:
|
|
93
|
+
# If the vendor_part_id is None, check if the latest part is a TextPart to update
|
|
94
|
+
if self._parts:
|
|
95
|
+
part_index = len(self._parts) - 1
|
|
96
|
+
latest_part = self._parts[part_index]
|
|
97
|
+
if isinstance(latest_part, TextPart):
|
|
98
|
+
existing_text_part_and_index = latest_part, part_index
|
|
99
|
+
else:
|
|
100
|
+
# Otherwise, attempt to look up an existing TextPart by vendor_part_id
|
|
101
|
+
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
102
|
+
if part_index is not None:
|
|
103
|
+
existing_part = self._parts[part_index]
|
|
104
|
+
if not isinstance(existing_part, TextPart):
|
|
105
|
+
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
|
|
106
|
+
existing_text_part_and_index = existing_part, part_index
|
|
107
|
+
|
|
108
|
+
if existing_text_part_and_index is None:
|
|
109
|
+
# There is no existing text part that should be updated, so create a new one
|
|
110
|
+
new_part_index = len(self._parts)
|
|
111
|
+
part = TextPart(content=content)
|
|
112
|
+
if vendor_part_id is not None:
|
|
113
|
+
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
|
|
114
|
+
self._parts.append(part)
|
|
115
|
+
return PartStartEvent(index=new_part_index, part=part)
|
|
116
|
+
else:
|
|
117
|
+
# Update the existing TextPart with the new content delta
|
|
118
|
+
existing_text_part, part_index = existing_text_part_and_index
|
|
119
|
+
part_delta = TextPartDelta(content_delta=content)
|
|
120
|
+
self._parts[part_index] = part_delta.apply(existing_text_part)
|
|
121
|
+
return PartDeltaEvent(index=part_index, delta=part_delta)
|
|
122
|
+
|
|
123
|
+
def handle_tool_call_delta(
|
|
124
|
+
self,
|
|
125
|
+
*,
|
|
126
|
+
vendor_part_id: Hashable | None,
|
|
127
|
+
tool_name: str | None,
|
|
128
|
+
args: str | dict[str, Any] | None,
|
|
129
|
+
tool_call_id: str | None,
|
|
130
|
+
) -> ModelResponseStreamEvent | None:
|
|
131
|
+
"""Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.
|
|
132
|
+
|
|
133
|
+
Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which
|
|
134
|
+
point they are upgraded to `ToolCallPart`s.
|
|
135
|
+
|
|
136
|
+
If `vendor_part_id` is None, updates the latest matching ToolCallPart (or ToolCallPartDelta)
|
|
137
|
+
if any. Otherwise, a new part (or delta) may be created.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
vendor_part_id: The ID the vendor uses for this tool call.
|
|
141
|
+
If None, the latest matching tool call may be updated.
|
|
142
|
+
tool_name: The name of the tool. If None, the manager does not enforce
|
|
143
|
+
a name match when `vendor_part_id` is None.
|
|
144
|
+
args: Arguments for the tool call, either as a string or a dictionary of key-value pairs.
|
|
145
|
+
tool_call_id: An optional string representing an identifier for this tool call.
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
- A `PartStartEvent` if a new (fully realized) ToolCallPart is created.
|
|
149
|
+
- A `PartDeltaEvent` if an existing part is updated.
|
|
150
|
+
- `None` if no new event is emitted (e.g., the part is still incomplete).
|
|
151
|
+
|
|
152
|
+
Raises:
|
|
153
|
+
UnexpectedModelBehavior: If attempting to apply a tool call delta to a part that is not
|
|
154
|
+
a ToolCallPart or ToolCallPartDelta.
|
|
155
|
+
"""
|
|
156
|
+
existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None
|
|
157
|
+
|
|
158
|
+
if vendor_part_id is None:
|
|
159
|
+
# vendor_part_id is None, so check if the latest part is a matching tool call or delta to update
|
|
160
|
+
# When the vendor_part_id is None, if the tool_name is _not_ None, assume this should be a new part rather
|
|
161
|
+
# than a delta on an existing one. We can change this behavior in the future if necessary for some model.
|
|
162
|
+
if tool_name is None and self._parts:
|
|
163
|
+
part_index = len(self._parts) - 1
|
|
164
|
+
latest_part = self._parts[part_index]
|
|
165
|
+
if isinstance(latest_part, (ToolCallPart, ToolCallPartDelta)):
|
|
166
|
+
existing_matching_part_and_index = latest_part, part_index
|
|
167
|
+
else:
|
|
168
|
+
# vendor_part_id is provided, so look up the corresponding part or delta
|
|
169
|
+
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
170
|
+
if part_index is not None:
|
|
171
|
+
existing_part = self._parts[part_index]
|
|
172
|
+
if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)):
|
|
173
|
+
raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}')
|
|
174
|
+
existing_matching_part_and_index = existing_part, part_index
|
|
175
|
+
|
|
176
|
+
if existing_matching_part_and_index is None:
|
|
177
|
+
# No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed)
|
|
178
|
+
delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
|
|
179
|
+
part = delta.as_part() or delta
|
|
180
|
+
if vendor_part_id is not None:
|
|
181
|
+
self._vendor_id_to_part_index[vendor_part_id] = len(self._parts)
|
|
182
|
+
new_part_index = len(self._parts)
|
|
183
|
+
self._parts.append(part)
|
|
184
|
+
# Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart
|
|
185
|
+
if isinstance(part, ToolCallPart):
|
|
186
|
+
return PartStartEvent(index=new_part_index, part=part)
|
|
187
|
+
else:
|
|
188
|
+
# Update the existing part or delta with the new information
|
|
189
|
+
existing_part, part_index = existing_matching_part_and_index
|
|
190
|
+
delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id)
|
|
191
|
+
updated_part = delta.apply(existing_part)
|
|
192
|
+
self._parts[part_index] = updated_part
|
|
193
|
+
if isinstance(updated_part, ToolCallPart):
|
|
194
|
+
if isinstance(existing_part, ToolCallPartDelta):
|
|
195
|
+
# We just upgraded a delta to a full part, so emit a PartStartEvent
|
|
196
|
+
return PartStartEvent(index=part_index, part=updated_part)
|
|
197
|
+
else:
|
|
198
|
+
# We updated an existing part, so emit a PartDeltaEvent
|
|
199
|
+
return PartDeltaEvent(index=part_index, delta=delta)
|
|
200
|
+
|
|
201
|
+
def handle_tool_call_part(
|
|
202
|
+
self,
|
|
203
|
+
*,
|
|
204
|
+
vendor_part_id: Hashable | None,
|
|
205
|
+
tool_name: str,
|
|
206
|
+
args: str | dict[str, Any],
|
|
207
|
+
tool_call_id: str | None = None,
|
|
208
|
+
) -> ModelResponseStreamEvent:
|
|
209
|
+
"""Immediately create or fully-overwrite a ToolCallPart with the given information.
|
|
210
|
+
|
|
211
|
+
This does not apply a delta; it directly sets the tool call part contents.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
vendor_part_id: The vendor's ID for this tool call part. If not
|
|
215
|
+
None and an existing part is found, that part is overwritten.
|
|
216
|
+
tool_name: The name of the tool being invoked.
|
|
217
|
+
args: The arguments for the tool call, either as a string or a dictionary.
|
|
218
|
+
tool_call_id: An optional string identifier for this tool call.
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
|
|
222
|
+
has been added to the manager, or replaced an existing part.
|
|
223
|
+
"""
|
|
224
|
+
new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
|
|
225
|
+
if vendor_part_id is None:
|
|
226
|
+
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
|
|
227
|
+
new_part_index = len(self._parts)
|
|
228
|
+
self._parts.append(new_part)
|
|
229
|
+
else:
|
|
230
|
+
# vendor_part_id is provided, so find and overwrite or create a new ToolCallPart.
|
|
231
|
+
maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
232
|
+
if maybe_part_index is not None:
|
|
233
|
+
new_part_index = maybe_part_index
|
|
234
|
+
self._parts[new_part_index] = new_part
|
|
235
|
+
else:
|
|
236
|
+
new_part_index = len(self._parts)
|
|
237
|
+
self._parts.append(new_part)
|
|
238
|
+
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
|
|
239
|
+
return PartStartEvent(index=new_part_index, part=new_part)
|
|
@@ -20,7 +20,7 @@ from ._griffe import doc_descriptions
|
|
|
20
20
|
from ._utils import check_object_json_schema, is_model_like
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
from .tools import ObjectJsonSchema
|
|
23
|
+
from .tools import DocstringFormat, ObjectJsonSchema
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
__all__ = ('function_schema',)
|
|
@@ -38,12 +38,19 @@ class FunctionSchema(TypedDict):
|
|
|
38
38
|
var_positional_field: str | None
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def function_schema(
|
|
41
|
+
def function_schema( # noqa: C901
|
|
42
|
+
function: Callable[..., Any],
|
|
43
|
+
takes_ctx: bool,
|
|
44
|
+
docstring_format: DocstringFormat,
|
|
45
|
+
require_parameter_descriptions: bool,
|
|
46
|
+
) -> FunctionSchema:
|
|
42
47
|
"""Build a Pydantic validator and JSON schema from a tool function.
|
|
43
48
|
|
|
44
49
|
Args:
|
|
45
50
|
function: The function to build a validator and JSON schema for.
|
|
46
51
|
takes_ctx: Whether the function takes a `RunContext` first argument.
|
|
52
|
+
docstring_format: The docstring format to use.
|
|
53
|
+
require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
|
|
47
54
|
|
|
48
55
|
Returns:
|
|
49
56
|
A `FunctionSchema` instance.
|
|
@@ -62,7 +69,13 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
|
|
|
62
69
|
var_positional_field: str | None = None
|
|
63
70
|
errors: list[str] = []
|
|
64
71
|
decorators = _decorators.DecoratorInfos()
|
|
65
|
-
|
|
72
|
+
|
|
73
|
+
description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
|
|
74
|
+
|
|
75
|
+
if require_parameter_descriptions:
|
|
76
|
+
if len(field_descriptions) != len(sig.parameters):
|
|
77
|
+
missing_params = set(sig.parameters) - set(field_descriptions)
|
|
78
|
+
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')
|
|
66
79
|
|
|
67
80
|
for index, (name, p) in enumerate(sig.parameters.items()):
|
|
68
81
|
if p.annotation is sig.empty:
|
|
@@ -15,7 +15,7 @@ from pydantic.json_schema import JsonSchemaValue
|
|
|
15
15
|
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
|
-
from .
|
|
18
|
+
from . import messages as _messages
|
|
19
19
|
from .tools import ObjectJsonSchema
|
|
20
20
|
|
|
21
21
|
_P = ParamSpec('_P')
|
|
@@ -136,7 +136,7 @@ class Either(Generic[Left, Right]):
|
|
|
136
136
|
|
|
137
137
|
@asynccontextmanager
|
|
138
138
|
async def group_by_temporal(
|
|
139
|
-
|
|
139
|
+
aiterable: AsyncIterable[T], soft_max_interval: float | None
|
|
140
140
|
) -> AsyncIterator[AsyncIterable[list[T]]]:
|
|
141
141
|
"""Group items from an async iterable into lists based on time interval between them.
|
|
142
142
|
|
|
@@ -154,18 +154,18 @@ async def group_by_temporal(
|
|
|
154
154
|
```
|
|
155
155
|
|
|
156
156
|
Args:
|
|
157
|
-
|
|
157
|
+
aiterable: The async iterable to group.
|
|
158
158
|
soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
|
|
159
159
|
a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
|
|
160
160
|
as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed
|
|
161
161
|
|
|
162
162
|
Returns:
|
|
163
|
-
A context manager usable as an
|
|
163
|
+
A context manager usable as an async iterable of lists of items produced by the input async iterable.
|
|
164
164
|
"""
|
|
165
165
|
if soft_max_interval is None:
|
|
166
166
|
|
|
167
167
|
async def async_iter_groups_noop() -> AsyncIterator[list[T]]:
|
|
168
|
-
async for item in
|
|
168
|
+
async for item in aiterable:
|
|
169
169
|
yield [item]
|
|
170
170
|
|
|
171
171
|
yield async_iter_groups_noop()
|
|
@@ -181,6 +181,7 @@ async def group_by_temporal(
|
|
|
181
181
|
buffer: list[T] = []
|
|
182
182
|
group_start_time = time.monotonic()
|
|
183
183
|
|
|
184
|
+
aiterator = aiterable.__aiter__()
|
|
184
185
|
while True:
|
|
185
186
|
if group_start_time is None:
|
|
186
187
|
# group hasn't started, we just wait for the maximum interval
|
|
@@ -193,7 +194,7 @@ async def group_by_temporal(
|
|
|
193
194
|
if task is None:
|
|
194
195
|
# aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
|
|
195
196
|
# so far, this doesn't seem to be a problem
|
|
196
|
-
task = asyncio.create_task(
|
|
197
|
+
task = asyncio.create_task(aiterator.__anext__()) # pyright: ignore[reportArgumentType]
|
|
197
198
|
|
|
198
199
|
# we use asyncio.wait to avoid cancelling the coroutine if it's not done
|
|
199
200
|
done, _ = await asyncio.wait((task,), timeout=wait_time)
|
|
@@ -232,16 +233,6 @@ async def group_by_temporal(
|
|
|
232
233
|
await task
|
|
233
234
|
|
|
234
235
|
|
|
235
|
-
def add_optional(a: str | None, b: str | None) -> str | None:
|
|
236
|
-
"""Add two optional strings."""
|
|
237
|
-
if a is None:
|
|
238
|
-
return b
|
|
239
|
-
elif b is None:
|
|
240
|
-
return a
|
|
241
|
-
else:
|
|
242
|
-
return a + b
|
|
243
|
-
|
|
244
|
-
|
|
245
236
|
def sync_anext(iterator: Iterator[T]) -> T:
|
|
246
237
|
"""Get the next item from a sync iterator, raising `StopAsyncIteration` if it's exhausted.
|
|
247
238
|
|
|
@@ -257,7 +248,79 @@ def now_utc() -> datetime:
|
|
|
257
248
|
return datetime.now(tz=timezone.utc)
|
|
258
249
|
|
|
259
250
|
|
|
260
|
-
def guard_tool_call_id(
|
|
251
|
+
def guard_tool_call_id(
|
|
252
|
+
t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart, model_source: str
|
|
253
|
+
) -> str:
|
|
261
254
|
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
|
|
262
255
|
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
|
|
263
256
|
return t.tool_call_id
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class PeekableAsyncStream(Generic[T]):
|
|
260
|
+
"""Wraps an async iterable of type T and allows peeking at the *next* item without consuming it.
|
|
261
|
+
|
|
262
|
+
We only buffer one item at a time (the next item). Once that item is yielded, it is discarded.
|
|
263
|
+
This is a single-pass stream.
|
|
264
|
+
"""
|
|
265
|
+
|
|
266
|
+
def __init__(self, source: AsyncIterable[T]):
|
|
267
|
+
self._source = source
|
|
268
|
+
self._source_iter: AsyncIterator[T] | None = None
|
|
269
|
+
self._buffer: T | Unset = UNSET
|
|
270
|
+
self._exhausted = False
|
|
271
|
+
|
|
272
|
+
async def peek(self) -> T | Unset:
|
|
273
|
+
"""Returns the next item that would be yielded without consuming it.
|
|
274
|
+
|
|
275
|
+
Returns `UNSET` if the stream is exhausted.
|
|
276
|
+
"""
|
|
277
|
+
if self._exhausted:
|
|
278
|
+
return UNSET
|
|
279
|
+
|
|
280
|
+
# If we already have a buffered item, just return it.
|
|
281
|
+
if not isinstance(self._buffer, Unset):
|
|
282
|
+
return self._buffer
|
|
283
|
+
|
|
284
|
+
# Otherwise, we need to fetch the next item from the underlying iterator.
|
|
285
|
+
if self._source_iter is None:
|
|
286
|
+
self._source_iter = self._source.__aiter__()
|
|
287
|
+
|
|
288
|
+
try:
|
|
289
|
+
self._buffer = await self._source_iter.__anext__()
|
|
290
|
+
except StopAsyncIteration:
|
|
291
|
+
self._exhausted = True
|
|
292
|
+
return UNSET
|
|
293
|
+
|
|
294
|
+
return self._buffer
|
|
295
|
+
|
|
296
|
+
async def is_exhausted(self) -> bool:
|
|
297
|
+
"""Returns True if the stream is exhausted, False otherwise."""
|
|
298
|
+
return isinstance(await self.peek(), Unset)
|
|
299
|
+
|
|
300
|
+
def __aiter__(self) -> AsyncIterator[T]:
|
|
301
|
+
# For a single-pass iteration, we can return self as the iterator.
|
|
302
|
+
return self
|
|
303
|
+
|
|
304
|
+
async def __anext__(self) -> T:
|
|
305
|
+
"""Yields the buffered item if present, otherwise fetches the next item from the underlying source.
|
|
306
|
+
|
|
307
|
+
Raises StopAsyncIteration if the stream is exhausted.
|
|
308
|
+
"""
|
|
309
|
+
if self._exhausted:
|
|
310
|
+
raise StopAsyncIteration
|
|
311
|
+
|
|
312
|
+
# If we have a buffered item, yield it.
|
|
313
|
+
if not isinstance(self._buffer, Unset):
|
|
314
|
+
item = self._buffer
|
|
315
|
+
self._buffer = UNSET
|
|
316
|
+
return item
|
|
317
|
+
|
|
318
|
+
# Otherwise, fetch the next item from the source.
|
|
319
|
+
if self._source_iter is None:
|
|
320
|
+
self._source_iter = self._source.__aiter__()
|
|
321
|
+
|
|
322
|
+
try:
|
|
323
|
+
return await self._source_iter.__anext__()
|
|
324
|
+
except StopAsyncIteration:
|
|
325
|
+
self._exhausted = True
|
|
326
|
+
raise
|