pydantic-ai-slim 0.0.11__tar.gz → 0.0.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/.gitignore +1 -2
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/PKG-INFO +7 -3
- pydantic_ai_slim-0.0.16/pydantic_ai/__init__.py +19 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_griffe.py +1 -2
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_pydantic.py +13 -29
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_result.py +52 -41
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_system_prompt.py +2 -2
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_utils.py +20 -8
- pydantic_ai_slim-0.0.16/pydantic_ai/agent.py +1196 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/exceptions.py +20 -2
- pydantic_ai_slim-0.0.16/pydantic_ai/messages.py +264 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/__init__.py +68 -50
- pydantic_ai_slim-0.0.16/pydantic_ai/models/anthropic.py +344 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/function.py +69 -53
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/gemini.py +184 -136
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/groq.py +139 -124
- pydantic_ai_slim-0.0.16/pydantic_ai/models/mistral.py +663 -0
- pydantic_ai_slim-0.0.16/pydantic_ai/models/ollama.py +119 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/openai.py +159 -130
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/test.py +137 -88
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/vertexai.py +14 -9
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/result.py +130 -78
- pydantic_ai_slim-0.0.16/pydantic_ai/settings.py +141 -0
- pydantic_ai_slim-0.0.16/pydantic_ai/tools.py +338 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pyproject.toml +3 -1
- pydantic_ai_slim-0.0.11/pydantic_ai/__init__.py +0 -8
- pydantic_ai_slim-0.0.11/pydantic_ai/agent.py +0 -857
- pydantic_ai_slim-0.0.11/pydantic_ai/messages.py +0 -200
- pydantic_ai_slim-0.0.11/pydantic_ai/tools.py +0 -240
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/README.md +0 -0
- {pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/py.typed +0 -0
{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/PKG-INFO

@@ -1,9 +1,9 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.11
+Version: 0.0.16
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
-License: MIT
+License-Expression: MIT
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Environment :: MacOS X

@@ -29,10 +29,14 @@ Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27.2
 Requires-Dist: logfire-api>=1.2.0
 Requires-Dist: pydantic>=2.10
+Provides-Extra: anthropic
+Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
 Provides-Extra: groq
 Requires-Dist: groq>=0.12.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
+Provides-Extra: mistral
+Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.54.3; extra == 'openai'
 Provides-Extra: vertexai
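The two new extras correspond to the optional model backends added in this release (`pydantic_ai/models/anthropic.py` and `pydantic_ai/models/mistral.py` in the file list), installable with e.g. `pip install 'pydantic-ai-slim[anthropic,mistral]'`. A minimal sketch of wiring them up follows; the class names and constructor arguments are assumptions based on the naming pattern of the existing backends (`OpenAIModel`, `GroqModel`), not taken from this diff:

```python
# Sketch only: AnthropicModel / MistralModel and their constructor arguments are
# assumed to follow the pattern of the existing OpenAIModel / GroqModel backends.
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.mistral import MistralModel

# API keys are typically read from ANTHROPIC_API_KEY / MISTRAL_API_KEY unless passed explicitly.
anthropic_agent = Agent(AnthropicModel('claude-3-5-sonnet-latest'))
mistral_agent = Agent(MistralModel('mistral-large-latest'))
```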
pydantic_ai_slim-0.0.16/pydantic_ai/__init__.py (new file)

@@ -0,0 +1,19 @@
+from importlib.metadata import version
+
+from .agent import Agent, capture_run_messages
+from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+from .tools import RunContext, Tool
+
+__all__ = (
+    'Agent',
+    'capture_run_messages',
+    'RunContext',
+    'Tool',
+    'AgentRunError',
+    'ModelRetry',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'UserError',
+    '__version__',
+)
+__version__ = version('pydantic_ai_slim')
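These exports define the new top-level public API. Below is a minimal sketch of using them, run against the bundled `TestModel` so no provider key is needed; the `TestModel`, `run_sync` and `.data` usage are assumptions based on the rest of the package (not shown in this hunk):

```python
from pydantic_ai import Agent, capture_run_messages, __version__
from pydantic_ai.models.test import TestModel  # stub model included in the package

print(__version__)  # e.g. '0.0.16'

agent = Agent(TestModel(), system_prompt='Reply briefly.')

# capture_run_messages records the messages exchanged during the run for inspection
with capture_run_messages() as messages:
    result = agent.run_sync('Hello')

print(result.data)
print(messages)
```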
{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_griffe.py

@@ -4,8 +4,7 @@ import re
 from inspect import Signature
 from typing import Any, Callable, Literal, cast

-from
-from _griffe.models import Docstring, Object as GriffeObject
+from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

 DocstringStyle = Literal['google', 'numpy', 'sphinx']

{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_pydantic.py

@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
 from inspect import Parameter, signature
 from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin

-from pydantic import ConfigDict, TypeAdapter
+from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
 from pydantic._internal._config import ConfigWrapper
 from pydantic.fields import FieldInfo

@@ -17,13 +17,13 @@ from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema

 from ._griffe import doc_descriptions
-from ._utils import
+from ._utils import check_object_json_schema, is_model_like

 if TYPE_CHECKING:
-
+    from .tools import ObjectJsonSchema


-__all__ = 'function_schema',
+__all__ = ('function_schema',)


 class FunctionSchema(TypedDict):

@@ -138,14 +138,14 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema:
     json_schema = GenerateJsonSchema().generate(schema)

     # workaround for https://github.com/pydantic/pydantic/issues/10785
-    # if we build a custom TypeDict schema (matches when `single_arg_name
+    # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
     # `additionalProperties` in the JSON Schema
     if single_arg_name is None:
         json_schema['additionalProperties'] = bool(var_kwargs_schema)
-
-
-
-
+    elif not description:
+        # if the tool description is not set, and we have a single parameter, take the description from that
+        # and set it on the tool
+        description = json_schema.pop('description', None)

     return FunctionSchema(
         description=description,

@@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
     """
     sig = signature(function)
     try:
-
+        first_param_name = next(iter(sig.parameters.keys()))
     except StopIteration:
         return False
     else:
-
+        type_hints = _typing_extra.get_function_type_hints(function)
+        annotation = type_hints[first_param_name]
+        return annotation is not sig.empty and _is_call_ctx(annotation)


 def _build_schema(

@@ -212,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
     return annotation is RunContext or (
         _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
     )
-
-
-if TYPE_CHECKING:
-    LazyTypeAdapter = TypeAdapter
-else:
-
-    class LazyTypeAdapter:
-        __slots__ = '_args', '_kwargs', '_type_adapter'
-
-        def __init__(self, *args, **kwargs):
-            self._args = args
-            self._kwargs = kwargs
-            self._type_adapter = None
-
-        def __getattr__(self, item):
-            if self._type_adapter is None:
-                self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
-            return getattr(self._type_adapter, item)
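The rewritten `takes_ctx` now resolves the first parameter's annotation through `_typing_extra.get_function_type_hints`, so string annotations (e.g. under `from __future__ import annotations`) are evaluated before checking for `RunContext`. A hedged illustration of the behaviour this supports on the agent side; the `@agent.tool` / `@agent.tool_plain` decorators and `deps_type` parameter are assumptions from the wider public API, not part of this hunk:

```python
from __future__ import annotations

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(), deps_type=str)

@agent.tool
def with_ctx(ctx: RunContext[str], x: int) -> str:
    # first parameter annotated as RunContext -> detected as a context-taking tool,
    # even though the annotation is a plain string at runtime here
    return f'{ctx.deps}:{x}'

@agent.tool_plain
def without_ctx(x: int) -> str:
    # no RunContext parameter -> registered as a plain tool
    return str(x)
```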
{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_result.py

@@ -3,18 +3,17 @@ from __future__ import annotations as _annotations
 import inspect
 import sys
 import types
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Iterable
 from dataclasses import dataclass, field
 from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

 from pydantic import TypeAdapter, ValidationError
 from typing_extensions import Self, TypeAliasType, TypedDict

-from . import _utils, messages
+from . import _utils, messages as _messages
 from .exceptions import ModelRetry
-from .
-from .
-from .tools import AgentDeps, ResultValidatorFunc, RunContext
+from .result import ResultData, ResultValidatorFunc
+from .tools import AgentDeps, RunContext, ToolDefinition


 @dataclass

@@ -28,21 +27,24 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
         self._is_async = inspect.iscoroutinefunction(self.function)

     async def validate(
-        self,
+        self,
+        result: ResultData,
+        tool_call: _messages.ToolCallPart | None,
+        run_context: RunContext[AgentDeps],
     ) -> ResultData:
         """Validate a result but calling the function.

         Args:
             result: The result data after Pydantic validation the message content.
-            deps: The agent dependencies.
-            retry: The current retry number.
             tool_call: The original tool call message, `None` if there was no tool call.
+            run_context: The current run context.

         Returns:
             Result of either the validated result data (ok) or a retry message (Err).
         """
         if self._takes_ctx:
-
+            ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
+            args = ctx, result
         else:
             args = (result,)

@@ -54,10 +56,10 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
                 function = cast(Callable[[Any], ResultData], self.function)
                 result_data = await _utils.run_in_executor(function, *args)
         except ModelRetry as r:
-            m =
+            m = _messages.RetryPromptPart(content=r.message)
             if tool_call is not None:
                 m.tool_name = tool_call.tool_name
-                m.
+                m.tool_call_id = tool_call.tool_call_id
             raise ToolRetryError(m) from r
         else:
             return result_data

@@ -66,7 +68,7 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
 class ToolRetryError(Exception):
     """Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""

-    def __init__(self, tool_retry:
+    def __init__(self, tool_retry: _messages.RetryPromptPart):
         self.tool_retry = tool_retry
         super().__init__()

@@ -94,10 +96,7 @@ class ResultSchema(Generic[ResultData]):
         allow_text_result = False

         def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
-            return cast(
-                ResultTool[ResultData],
-                ResultTool.build(a, tool_name_, description, multiple),  # pyright: ignore[reportUnknownMemberType]
-            )
+            return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))

         tools: dict[str, ResultTool[ResultData]] = {}
         if args := get_union_args(response_type):

@@ -111,48 +110,61 @@ class ResultSchema(Generic[ResultData]):

         return cls(tools=tools, allow_text_result=allow_text_result)

-    def
+    def find_named_tool(
+        self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
+    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
+        """Find a tool that matches one of the calls, with a specific name."""
+        for part in parts:
+            if isinstance(part, _messages.ToolCallPart):
+                if part.tool_name == tool_name:
+                    return part, self.tools[tool_name]
+
+    def find_tool(
+        self,
+        parts: Iterable[_messages.ModelResponsePart],
+    ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
         """Find a tool that matches one of the calls."""
-        for
-            if
-
+        for part in parts:
+            if isinstance(part, _messages.ToolCallPart):
+                if result := self.tools.get(part.tool_name):
+                    return part, result

     def tool_names(self) -> list[str]:
         """Return the names of the tools."""
         return list(self.tools.keys())

+    def tool_defs(self) -> list[ToolDefinition]:
+        """Get tool definitions to register with the model."""
+        return [t.tool_def for t in self.tools.values()]
+

 DEFAULT_DESCRIPTION = 'The final response which ends this conversation'


-@dataclass
+@dataclass(init=False)
 class ResultTool(Generic[ResultData]):
-
-    description: str
+    tool_def: ToolDefinition
     type_adapter: TypeAdapter[Any]
-    json_schema: _utils.ObjectJsonSchema
-    outer_typed_dict_key: str | None

-    @classmethod
-    def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
+    def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
         """Build a ResultTool dataclass from a response type."""
         assert response_type is not str, 'ResultTool does not support str as a response type'

         if _utils.is_model_like(response_type):
-            type_adapter = TypeAdapter(response_type)
+            self.type_adapter = TypeAdapter(response_type)
             outer_typed_dict_key: str | None = None
             # noinspection PyArgumentList
-
+            parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
         else:
             response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type})  # noqa
-            type_adapter = TypeAdapter(response_data_typed_dict)
+            self.type_adapter = TypeAdapter(response_data_typed_dict)
             outer_typed_dict_key = 'response'
             # noinspection PyArgumentList
-
+            parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
             # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
-
+            parameters_json_schema.pop('title')

-        if json_schema_description :=
+        if json_schema_description := parameters_json_schema.pop('description', None):
             if description is None:
                 tool_description = json_schema_description
             else:

@@ -162,16 +174,15 @@ class ResultTool(Generic[ResultData]):
         if multiple:
             tool_description = f'{union_arg_name(response_type)}: {tool_description}'

-
+        self.tool_def = ToolDefinition(
             name=name,
             description=tool_description,
-
-            json_schema=json_schema,
+            parameters_json_schema=parameters_json_schema,
             outer_typed_dict_key=outer_typed_dict_key,
         )

     def validate(
-        self, tool_call:
+        self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
     ) -> ResultData:
         """Validate a result message.


@@ -185,7 +196,7 @@ class ResultTool(Generic[ResultData]):
         """
         try:
             pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
-            if isinstance(tool_call.args,
+            if isinstance(tool_call.args, _messages.ArgsJson):
                 result = self.type_adapter.validate_json(
                     tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
                 )

@@ -195,16 +206,16 @@ class ResultTool(Generic[ResultData]):
                 )
         except ValidationError as e:
             if wrap_validation_errors:
-                m =
+                m = _messages.RetryPromptPart(
                     tool_name=tool_call.tool_name,
                     content=e.errors(include_url=False),
-
+                    tool_call_id=tool_call.tool_call_id,
                 )
                 raise ToolRetryError(m) from e
             else:
                 raise
         else:
-            if k := self.outer_typed_dict_key:
+            if k := self.tool_def.outer_typed_dict_key:
                 result = result[k]
             return result
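`ResultValidator.validate` now receives the full `RunContext` (instead of separate `deps`/`retry` arguments) plus the originating tool call, and validation failures carry the `tool_call_id` back in the retry prompt. From the agent-facing side a result validator still looks roughly like the sketch below; the `@agent.result_validator` decorator and the `result_type`/`deps_type` parameters are assumptions from the public API, not shown in this hunk:

```python
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(), result_type=int, deps_type=str)

@agent.result_validator
def check_non_negative(ctx: RunContext[str], value: int) -> int:
    # validators may take the RunContext first; raising ModelRetry sends a
    # RetryPromptPart (with the originating tool_call_id) back to the model
    if value < 0:
        raise ModelRetry('Please return a non-negative number.')
    return value
```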
{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_system_prompt.py

@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
         self._is_async = inspect.iscoroutinefunction(self.function)

-    async def run(self,
+    async def run(self, run_context: RunContext[AgentDeps]) -> str:
         if self._takes_ctx:
-            args = (
+            args = (run_context,)
         else:
             args = ()

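`SystemPromptRunner.run` now takes the `RunContext` rather than bare dependencies, so dynamic system prompt functions can read anything on the context. A small sketch of the corresponding user-facing pattern; the `@agent.system_prompt` decorator and `deps=` argument are assumptions from the public API, not part of this hunk:

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(), deps_type=str)

@agent.system_prompt
def personalised_prompt(ctx: RunContext[str]) -> str:
    # the decorated function may accept the RunContext, or take no arguments at all
    return f'The user is called {ctx.deps}.'

result = agent.run_sync('Hi', deps='Alice')
```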
{pydantic_ai_slim-0.0.11 → pydantic_ai_slim-0.0.16}/pydantic_ai/_utils.py

@@ -8,12 +8,16 @@ from dataclasses import dataclass, is_dataclass
 from datetime import datetime, timezone
 from functools import partial
 from types import GenericAlias
-from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload

 from pydantic import BaseModel
 from pydantic.json_schema import JsonSchemaValue
 from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict

+if TYPE_CHECKING:
+    from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
+    from .tools import ObjectJsonSchema
+
 _P = ParamSpec('_P')
 _R = TypeVar('_R')

@@ -39,10 +43,6 @@ def is_model_like(type_: Any) -> bool:
     )


-# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
-ObjectJsonSchema: TypeAlias = dict[str, Any]
-
-
 def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
     from .exceptions import UserError

@@ -88,7 +88,7 @@ class Either(Generic[Left, Right]):

     Usage:

-    ```
+    ```python
     if left_thing := either.left:
         use_left(left_thing.value)
     else:

@@ -127,6 +127,12 @@ class Either(Generic[Left, Right]):
     def whichever(self) -> Left | Right:
         return self._left.value if self._left is not None else self.right

+    def __repr__(self):
+        if left := self._left:
+            return f'Either(left={left.value!r})'
+        else:
+            return f'Either(right={self.right!r})'
+

 @asynccontextmanager
 async def group_by_temporal(

@@ -141,7 +147,7 @@ async def group_by_temporal(

     Usage:

-    ```
+    ```python
     async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
         async for groups in groups_iter:
             print(groups)

@@ -218,7 +224,7 @@ async def group_by_temporal(

     try:
         yield async_iter_groups()
-    finally:
+    finally:  # pragma: no cover
         # after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
         if task:
             task.cancel('Cancelling due to error in iterator')

@@ -249,3 +255,9 @@ def sync_anext(iterator: Iterator[T]) -> T:

 def now_utc() -> datetime:
     return datetime.now(tz=timezone.utc)
+
+
+def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
+    """Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
+    assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
+    return t.tool_call_id