weakincentives 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +15 -2
- weakincentives/adapters/__init__.py +30 -0
- weakincentives/adapters/core.py +85 -0
- weakincentives/adapters/openai.py +361 -0
- weakincentives/prompts/__init__.py +45 -0
- weakincentives/prompts/_types.py +27 -0
- weakincentives/prompts/errors.py +57 -0
- weakincentives/prompts/prompt.py +440 -0
- weakincentives/prompts/response_format.py +54 -0
- weakincentives/prompts/section.py +120 -0
- weakincentives/prompts/structured.py +140 -0
- weakincentives/prompts/text.py +89 -0
- weakincentives/prompts/tool.py +236 -0
- weakincentives/serde/__init__.py +31 -0
- weakincentives/serde/dataclass_serde.py +1016 -0
- weakincentives-0.2.0.dist-info/METADATA +173 -0
- weakincentives-0.2.0.dist-info/RECORD +20 -0
- weakincentives-0.1.0.dist-info/METADATA +0 -21
- weakincentives-0.1.0.dist-info/RECORD +0 -6
- {weakincentives-0.1.0.dist-info → weakincentives-0.2.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.1.0.dist-info → weakincentives-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import re
|
|
17
|
+
from collections.abc import Mapping
|
|
18
|
+
from typing import Any, Literal, cast
|
|
19
|
+
|
|
20
|
+
from ..serde.dataclass_serde import parse as parse_dataclass
|
|
21
|
+
from .prompt import RenderedPrompt
|
|
22
|
+
|
|
23
|
+
# Public API of this module.
__all__ = ["OutputParseError", "parse_output"]

# Matches a ```json fenced code block, case-insensitively: the opening fence,
# optional whitespace, a newline, then a lazy capture (group 1) of everything up
# to the next ``` fence. DOTALL lets the captured body span multiple lines.
_JSON_FENCE_PATTERN = re.compile(r"```json\s*\n(.*?)```", re.IGNORECASE | re.DOTALL)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OutputParseError(Exception):
|
|
29
|
+
"""Raised when structured output parsing fails."""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
message: str,
|
|
34
|
+
*,
|
|
35
|
+
dataclass_type: type[Any] | None = None,
|
|
36
|
+
) -> None:
|
|
37
|
+
super().__init__(message)
|
|
38
|
+
self.message = message
|
|
39
|
+
self.dataclass_type = dataclass_type
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def parse_output[PayloadT](
|
|
43
|
+
output_text: str, rendered: RenderedPrompt[PayloadT]
|
|
44
|
+
) -> PayloadT:
|
|
45
|
+
"""Parse a model response into the structured output type declared by the prompt."""
|
|
46
|
+
|
|
47
|
+
dataclass_type = rendered.output_type
|
|
48
|
+
container = rendered.output_container
|
|
49
|
+
allow_extra_keys = rendered.allow_extra_keys
|
|
50
|
+
|
|
51
|
+
if dataclass_type is None or container is None:
|
|
52
|
+
raise OutputParseError("Prompt does not declare structured output.")
|
|
53
|
+
|
|
54
|
+
payload = _extract_json_payload(output_text, dataclass_type)
|
|
55
|
+
extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
|
|
56
|
+
|
|
57
|
+
if container == "object":
|
|
58
|
+
if not isinstance(payload, Mapping):
|
|
59
|
+
raise OutputParseError(
|
|
60
|
+
"Expected top-level JSON object.",
|
|
61
|
+
dataclass_type=dataclass_type,
|
|
62
|
+
)
|
|
63
|
+
try:
|
|
64
|
+
mapping_payload = cast(Mapping[str, object], payload)
|
|
65
|
+
parsed = parse_dataclass(
|
|
66
|
+
dataclass_type,
|
|
67
|
+
mapping_payload,
|
|
68
|
+
extra=extra_mode,
|
|
69
|
+
)
|
|
70
|
+
except (TypeError, ValueError) as error:
|
|
71
|
+
raise OutputParseError(str(error), dataclass_type=dataclass_type) from error
|
|
72
|
+
return cast(PayloadT, parsed)
|
|
73
|
+
|
|
74
|
+
if container == "array":
|
|
75
|
+
if not isinstance(payload, list):
|
|
76
|
+
raise OutputParseError(
|
|
77
|
+
"Expected top-level JSON array.",
|
|
78
|
+
dataclass_type=dataclass_type,
|
|
79
|
+
)
|
|
80
|
+
payload_list = cast(list[object], payload)
|
|
81
|
+
parsed_items: list[Any] = []
|
|
82
|
+
for index, item in enumerate(payload_list):
|
|
83
|
+
if not isinstance(item, Mapping):
|
|
84
|
+
raise OutputParseError(
|
|
85
|
+
f"Array item at index {index} is not an object.",
|
|
86
|
+
dataclass_type=dataclass_type,
|
|
87
|
+
)
|
|
88
|
+
try:
|
|
89
|
+
mapping_item = cast(Mapping[str, object], item)
|
|
90
|
+
parsed_item = parse_dataclass(
|
|
91
|
+
dataclass_type,
|
|
92
|
+
mapping_item,
|
|
93
|
+
extra=extra_mode,
|
|
94
|
+
)
|
|
95
|
+
except (TypeError, ValueError) as error:
|
|
96
|
+
raise OutputParseError(
|
|
97
|
+
str(error), dataclass_type=dataclass_type
|
|
98
|
+
) from error
|
|
99
|
+
parsed_items.append(parsed_item)
|
|
100
|
+
return cast(PayloadT, parsed_items)
|
|
101
|
+
|
|
102
|
+
raise OutputParseError( # pragma: no cover - defensive guard
|
|
103
|
+
"Unknown output container declared.",
|
|
104
|
+
dataclass_type=dataclass_type,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _extract_json_payload(text: str, dataclass_type: type[Any]) -> object:
    """Locate and decode the JSON value carried by an assistant message.

    Strategies are tried in order:

    1. The first ```json fenced block. If one exists, its contents must
       decode; there is no fallback to the other strategies.
    2. The whole message with surrounding whitespace stripped (any JSON value,
       including scalars, is accepted).
    3. A left-to-right scan: the first ``{`` or ``[`` from which a complete JSON
       value decodes wins, and any trailing text is ignored.

    Args:
        text: Raw assistant message.
        dataclass_type: Target type, attached to any raised error for context.

    Returns:
        The decoded JSON value.

    Raises:
        OutputParseError: If the fenced block is invalid JSON, or no strategy
            yields a value.
    """
    fence = _JSON_FENCE_PATTERN.search(text)
    if fence is not None:
        try:
            return json.loads(fence.group(1).strip())
        except json.JSONDecodeError as error:
            raise OutputParseError(
                "Failed to decode JSON from fenced code block.",
                dataclass_type=dataclass_type,
            ) from error

    candidate = text.strip()
    if candidate:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass  # Not pure JSON; fall through to scanning for an embedded value.

    scanner = json.JSONDecoder()
    starts = (pos for pos, char in enumerate(text) if char in "{[")
    for start in starts:
        try:
            value, _end = scanner.raw_decode(text, start)
        except json.JSONDecodeError:  # pragma: no cover - defensive fallback
            continue
        return value

    raise OutputParseError(
        "No JSON object or array found in assistant message.",
        dataclass_type=dataclass_type,
    )
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import textwrap
|
|
16
|
+
from collections.abc import Callable, Sequence
|
|
17
|
+
from dataclasses import fields, is_dataclass
|
|
18
|
+
from string import Template
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from ._types import SupportsDataclass
|
|
22
|
+
from .errors import PromptRenderError
|
|
23
|
+
from .section import Section
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TextSection[ParamsT: SupportsDataclass](Section[ParamsT]):
|
|
27
|
+
"""Render markdown text content using string.Template."""
|
|
28
|
+
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
*,
|
|
32
|
+
title: str,
|
|
33
|
+
body: str,
|
|
34
|
+
defaults: ParamsT | None = None,
|
|
35
|
+
children: Sequence[object] | None = None,
|
|
36
|
+
enabled: Callable[[ParamsT], bool] | None = None,
|
|
37
|
+
tools: Sequence[object] | None = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
super().__init__(
|
|
40
|
+
title=title,
|
|
41
|
+
defaults=defaults,
|
|
42
|
+
children=children,
|
|
43
|
+
enabled=enabled,
|
|
44
|
+
tools=tools,
|
|
45
|
+
)
|
|
46
|
+
self.body = body
|
|
47
|
+
|
|
48
|
+
def render(self, params: ParamsT, depth: int) -> str:
|
|
49
|
+
heading_level = "#" * (depth + 2)
|
|
50
|
+
heading = f"{heading_level} {self.title.strip()}"
|
|
51
|
+
template = Template(textwrap.dedent(self.body).strip())
|
|
52
|
+
try:
|
|
53
|
+
normalized_params = self._normalize_params(params)
|
|
54
|
+
rendered_body = template.substitute(normalized_params)
|
|
55
|
+
except KeyError as error: # pragma: no cover - handled at prompt level
|
|
56
|
+
missing = error.args[0]
|
|
57
|
+
raise PromptRenderError(
|
|
58
|
+
"Missing placeholder during render.",
|
|
59
|
+
placeholder=str(missing),
|
|
60
|
+
) from error
|
|
61
|
+
if rendered_body:
|
|
62
|
+
return f"{heading}\n\n{rendered_body.strip()}"
|
|
63
|
+
return heading
|
|
64
|
+
|
|
65
|
+
def placeholder_names(self) -> set[str]:
|
|
66
|
+
template = Template(textwrap.dedent(self.body).strip())
|
|
67
|
+
placeholders: set[str] = set()
|
|
68
|
+
for match in template.pattern.finditer(template.template):
|
|
69
|
+
named = match.group("named")
|
|
70
|
+
if named:
|
|
71
|
+
placeholders.add(named)
|
|
72
|
+
continue
|
|
73
|
+
braced = match.group("braced")
|
|
74
|
+
if braced:
|
|
75
|
+
placeholders.add(braced)
|
|
76
|
+
return placeholders
|
|
77
|
+
|
|
78
|
+
@staticmethod
|
|
79
|
+
def _normalize_params(params: ParamsT) -> dict[str, Any]:
|
|
80
|
+
if not is_dataclass(params) or isinstance(params, type):
|
|
81
|
+
raise PromptRenderError(
|
|
82
|
+
"Section params must be a dataclass instance.",
|
|
83
|
+
dataclass_type=type(params),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
return {field.name: getattr(params, field.name) for field in fields(params)}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Public API of this module.
__all__ = ["TextSection"]
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
import re
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from dataclasses import dataclass, field, is_dataclass
|
|
19
|
+
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
|
20
|
+
|
|
21
|
+
from ._types import SupportsDataclass
|
|
22
|
+
from .errors import PromptValidationError
|
|
23
|
+
|
|
24
|
+
# Valid tool names: 1-64 characters drawn from lowercase ASCII letters, digits
# and underscores. The pattern is anchored, so the whole name must match.
_NAME_PATTERN = re.compile(r"^[a-z0-9_]{1,64}$")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(slots=True)
|
|
28
|
+
class ToolResult[ResultPayloadT]:
|
|
29
|
+
"""Structured response emitted by a tool handler."""
|
|
30
|
+
|
|
31
|
+
message: str
|
|
32
|
+
payload: ResultPayloadT
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(slots=True)
|
|
36
|
+
class Tool[ParamsT: SupportsDataclass, ResultT: SupportsDataclass]:
|
|
37
|
+
"""Describe a callable tool exposed by prompt sections."""
|
|
38
|
+
|
|
39
|
+
name: str
|
|
40
|
+
description: str
|
|
41
|
+
handler: Callable[[ParamsT], ToolResult[ResultT]] | None
|
|
42
|
+
params_type: type[Any] = field(init=False, repr=False)
|
|
43
|
+
result_type: type[Any] = field(init=False, repr=False)
|
|
44
|
+
|
|
45
|
+
def __post_init__(self) -> None:
|
|
46
|
+
params_type = cast(
|
|
47
|
+
type[SupportsDataclass] | None, getattr(self, "params_type", None)
|
|
48
|
+
)
|
|
49
|
+
result_type = cast(
|
|
50
|
+
type[SupportsDataclass] | None, getattr(self, "result_type", None)
|
|
51
|
+
)
|
|
52
|
+
if params_type is None or result_type is None:
|
|
53
|
+
origin = getattr(self, "__orig_class__", None)
|
|
54
|
+
if origin is not None: # pragma: no cover - defensive fallback
|
|
55
|
+
args = get_args(origin)
|
|
56
|
+
if len(args) == 2 and all(isinstance(arg, type) for arg in args):
|
|
57
|
+
params_type = cast(type[SupportsDataclass], args[0])
|
|
58
|
+
result_type = cast(type[SupportsDataclass], args[1])
|
|
59
|
+
if params_type is None or result_type is None:
|
|
60
|
+
raise PromptValidationError(
|
|
61
|
+
"Tool must be instantiated with concrete type arguments.",
|
|
62
|
+
placeholder="type_arguments",
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
if not is_dataclass(params_type):
|
|
66
|
+
raise PromptValidationError(
|
|
67
|
+
"Tool ParamsT must be a dataclass type.",
|
|
68
|
+
dataclass_type=params_type,
|
|
69
|
+
placeholder="ParamsT",
|
|
70
|
+
)
|
|
71
|
+
if not is_dataclass(result_type):
|
|
72
|
+
raise PromptValidationError(
|
|
73
|
+
"Tool ResultT must be a dataclass type.",
|
|
74
|
+
dataclass_type=result_type,
|
|
75
|
+
placeholder="ResultT",
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
self.params_type = params_type
|
|
79
|
+
self.result_type = result_type
|
|
80
|
+
|
|
81
|
+
raw_name = self.name
|
|
82
|
+
stripped_name = raw_name.strip()
|
|
83
|
+
if raw_name != stripped_name:
|
|
84
|
+
normalized_name = stripped_name
|
|
85
|
+
raise PromptValidationError(
|
|
86
|
+
"Tool name must not contain surrounding whitespace.",
|
|
87
|
+
dataclass_type=params_type,
|
|
88
|
+
placeholder=normalized_name,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
name_clean = raw_name
|
|
92
|
+
if not name_clean:
|
|
93
|
+
raise PromptValidationError(
|
|
94
|
+
"Tool name must be non-empty lowercase ASCII up to 64 characters.",
|
|
95
|
+
dataclass_type=params_type,
|
|
96
|
+
placeholder=stripped_name,
|
|
97
|
+
)
|
|
98
|
+
if len(name_clean) > 64 or not _NAME_PATTERN.fullmatch(name_clean):
|
|
99
|
+
raise PromptValidationError(
|
|
100
|
+
"Tool name must use lowercase ASCII letters, digits, or underscores.",
|
|
101
|
+
dataclass_type=params_type,
|
|
102
|
+
placeholder=name_clean,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
description_clean = self.description.strip()
|
|
106
|
+
if not description_clean or len(description_clean) > 200:
|
|
107
|
+
raise PromptValidationError(
|
|
108
|
+
"Tool description must be 1-200 ASCII characters.",
|
|
109
|
+
dataclass_type=params_type,
|
|
110
|
+
placeholder="description",
|
|
111
|
+
)
|
|
112
|
+
try:
|
|
113
|
+
description_clean.encode("ascii")
|
|
114
|
+
except UnicodeEncodeError as error:
|
|
115
|
+
raise PromptValidationError(
|
|
116
|
+
"Tool description must be ASCII.",
|
|
117
|
+
dataclass_type=params_type,
|
|
118
|
+
placeholder="description",
|
|
119
|
+
) from error
|
|
120
|
+
|
|
121
|
+
handler = self.handler
|
|
122
|
+
if handler is not None:
|
|
123
|
+
self._validate_handler(handler, params_type, result_type)
|
|
124
|
+
|
|
125
|
+
self.name = name_clean
|
|
126
|
+
self.description = description_clean
|
|
127
|
+
|
|
128
|
+
def _validate_handler(
|
|
129
|
+
self,
|
|
130
|
+
handler: Callable[[ParamsT], ToolResult[ResultT]],
|
|
131
|
+
params_type: type[SupportsDataclass],
|
|
132
|
+
result_type: type[SupportsDataclass],
|
|
133
|
+
) -> None:
|
|
134
|
+
if not callable(handler):
|
|
135
|
+
raise PromptValidationError(
|
|
136
|
+
"Tool handler must be callable.",
|
|
137
|
+
dataclass_type=params_type,
|
|
138
|
+
placeholder="handler",
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
signature = inspect.signature(handler)
|
|
142
|
+
parameters = list(signature.parameters.values())
|
|
143
|
+
|
|
144
|
+
if len(parameters) != 1:
|
|
145
|
+
raise PromptValidationError(
|
|
146
|
+
"Tool handler must accept exactly one argument.",
|
|
147
|
+
dataclass_type=params_type,
|
|
148
|
+
placeholder="handler",
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
parameter = parameters[0]
|
|
152
|
+
if parameter.kind not in (
|
|
153
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
154
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
155
|
+
):
|
|
156
|
+
raise PromptValidationError(
|
|
157
|
+
"Tool handler parameter must be positional.",
|
|
158
|
+
dataclass_type=params_type,
|
|
159
|
+
placeholder="handler",
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
try:
|
|
163
|
+
hints = get_type_hints(handler, include_extras=True)
|
|
164
|
+
except Exception: # pragma: no cover - fallback for invalid hints
|
|
165
|
+
hints = {}
|
|
166
|
+
|
|
167
|
+
annotation = hints.get(parameter.name, parameter.annotation)
|
|
168
|
+
if annotation is inspect.Parameter.empty:
|
|
169
|
+
raise PromptValidationError(
|
|
170
|
+
"Tool handler parameter must be annotated with ParamsT.",
|
|
171
|
+
dataclass_type=params_type,
|
|
172
|
+
placeholder="handler",
|
|
173
|
+
)
|
|
174
|
+
if get_origin(annotation) is Annotated:
|
|
175
|
+
annotation = get_args(annotation)[0]
|
|
176
|
+
if annotation is not params_type:
|
|
177
|
+
raise PromptValidationError(
|
|
178
|
+
"Tool handler parameter annotation must match ParamsT.",
|
|
179
|
+
dataclass_type=params_type,
|
|
180
|
+
placeholder="handler",
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
return_annotation = hints.get("return", signature.return_annotation)
|
|
184
|
+
if return_annotation is inspect.Signature.empty:
|
|
185
|
+
raise PromptValidationError(
|
|
186
|
+
"Tool handler must annotate its return value with ToolResult[ResultT].",
|
|
187
|
+
dataclass_type=params_type,
|
|
188
|
+
placeholder="return",
|
|
189
|
+
)
|
|
190
|
+
if get_origin(return_annotation) is Annotated:
|
|
191
|
+
return_annotation = get_args(return_annotation)[0]
|
|
192
|
+
|
|
193
|
+
origin = get_origin(return_annotation)
|
|
194
|
+
if origin is ToolResult:
|
|
195
|
+
result_args_raw = get_args(return_annotation)
|
|
196
|
+
if result_args_raw == (result_type,):
|
|
197
|
+
return
|
|
198
|
+
raise PromptValidationError(
|
|
199
|
+
"Tool handler return annotation must be ToolResult[ResultT].",
|
|
200
|
+
dataclass_type=params_type,
|
|
201
|
+
placeholder="return",
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
@classmethod
|
|
205
|
+
def __class_getitem__(
|
|
206
|
+
cls, item: object
|
|
207
|
+
) -> type[Tool[SupportsDataclass, SupportsDataclass]]:
|
|
208
|
+
if not isinstance(item, tuple):
|
|
209
|
+
raise TypeError("Tool[...] expects two type arguments (ParamsT, ResultT).")
|
|
210
|
+
typed_item = cast(tuple[Any, Any], item)
|
|
211
|
+
try:
|
|
212
|
+
params_candidate, result_candidate = typed_item
|
|
213
|
+
except ValueError as error:
|
|
214
|
+
raise TypeError(
|
|
215
|
+
"Tool[...] expects two type arguments (ParamsT, ResultT)."
|
|
216
|
+
) from error
|
|
217
|
+
if not isinstance(params_candidate, type) or not isinstance(
|
|
218
|
+
result_candidate, type
|
|
219
|
+
):
|
|
220
|
+
raise TypeError("Tool[...] type arguments must be types.")
|
|
221
|
+
params_type = cast(type[SupportsDataclass], params_candidate)
|
|
222
|
+
result_type = cast(type[SupportsDataclass], result_candidate)
|
|
223
|
+
|
|
224
|
+
class _SpecializedTool(cls): # type: ignore[misc]
|
|
225
|
+
def __post_init__(self) -> None: # type: ignore[override]
|
|
226
|
+
self.params_type = params_type
|
|
227
|
+
self.result_type = result_type
|
|
228
|
+
super().__post_init__()
|
|
229
|
+
|
|
230
|
+
_SpecializedTool.__name__ = cls.__name__
|
|
231
|
+
_SpecializedTool.__qualname__ = cls.__qualname__
|
|
232
|
+
_SpecializedTool.__module__ = cls.__module__
|
|
233
|
+
return _SpecializedTool # type: ignore[return-value]
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# Public API of this module.
__all__ = ["Tool", "ToolResult"]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
# Copyright 2025 weak incentives
|
|
14
|
+
#
|
|
15
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
16
|
+
# you may not use this file except in compliance with the License.
|
|
17
|
+
# You may obtain a copy of the License at
|
|
18
|
+
#
|
|
19
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
20
|
+
#
|
|
21
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
22
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
23
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
24
|
+
# See the License for the specific language governing permissions and
|
|
25
|
+
# limitations under the License.
|
|
26
|
+
|
|
27
|
+
"""Stdlib dataclass serde utilities.

Re-exports ``parse``, ``dump``, ``clone`` and ``schema`` from the
``dataclass_serde`` submodule.
"""

from .dataclass_serde import clone, dump, parse, schema

# Public API of the serde package.
__all__ = ["parse", "dump", "clone", "schema"]
|