weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Rendering helpers for :mod:`weakincentives.prompt`."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
|
18
|
+
from dataclasses import dataclass, field, is_dataclass, replace
|
|
19
|
+
from types import MappingProxyType
|
|
20
|
+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, override
|
|
21
|
+
|
|
22
|
+
from ..deadlines import Deadline
|
|
23
|
+
from ._types import SupportsDataclass, SupportsToolResult
|
|
24
|
+
from .errors import PromptRenderError, PromptValidationError, SectionPath
|
|
25
|
+
from .registry import RegistrySnapshot, SectionNode
|
|
26
|
+
from .response_format import ResponseFormatSection
|
|
27
|
+
from .structured_output import StructuredOutputConfig
|
|
28
|
+
from .tool import Tool
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING: # pragma: no cover - typing only
|
|
31
|
+
from .overrides import ToolOverride
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
_EMPTY_TOOL_PARAM_DESCRIPTIONS: Mapping[str, Mapping[str, str]] = MappingProxyType({})
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
OutputT_co = TypeVar("OutputT_co", covariant=True)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True, slots=True)
|
|
41
|
+
class RenderedPrompt[OutputT_co]:
|
|
42
|
+
"""Rendered prompt text paired with structured output metadata."""
|
|
43
|
+
|
|
44
|
+
text: str
|
|
45
|
+
structured_output: StructuredOutputConfig[SupportsDataclass] | None = None
|
|
46
|
+
deadline: Deadline | None = None
|
|
47
|
+
_tools: tuple[Tool[SupportsDataclass, SupportsToolResult], ...] = field(
|
|
48
|
+
default_factory=tuple
|
|
49
|
+
)
|
|
50
|
+
_tool_param_descriptions: Mapping[str, Mapping[str, str]] = field(
|
|
51
|
+
default=_EMPTY_TOOL_PARAM_DESCRIPTIONS
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
@override
|
|
55
|
+
def __str__(self) -> str: # pragma: no cover - delegated behaviour
|
|
56
|
+
return self.text
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
def tools(self) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
|
|
60
|
+
"""Tools contributed by enabled sections in traversal order."""
|
|
61
|
+
|
|
62
|
+
return self._tools
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def tool_param_descriptions(
|
|
66
|
+
self,
|
|
67
|
+
) -> Mapping[str, Mapping[str, str]]:
|
|
68
|
+
"""Description patches keyed by tool name."""
|
|
69
|
+
|
|
70
|
+
return self._tool_param_descriptions
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def output_type(self) -> type[SupportsDataclass] | None:
|
|
74
|
+
"""Return the declared dataclass type for structured output."""
|
|
75
|
+
|
|
76
|
+
if self.structured_output is None:
|
|
77
|
+
return None
|
|
78
|
+
return self.structured_output.dataclass_type
|
|
79
|
+
|
|
80
|
+
@property
|
|
81
|
+
def container(self) -> Literal["object", "array"] | None:
|
|
82
|
+
"""Return the declared container for structured output."""
|
|
83
|
+
|
|
84
|
+
if self.structured_output is None:
|
|
85
|
+
return None
|
|
86
|
+
return self.structured_output.container
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def allow_extra_keys(self) -> bool | None:
|
|
90
|
+
"""Return whether extra keys are allowed in structured output."""
|
|
91
|
+
|
|
92
|
+
if self.structured_output is None:
|
|
93
|
+
return None
|
|
94
|
+
return self.structured_output.allow_extra_keys
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _freeze_tool_param_descriptions(
    descriptions: Mapping[str, dict[str, str]],
) -> Mapping[str, Mapping[str, str]]:
    """Return a read-only snapshot of per-tool parameter descriptions.

    Both the outer mapping and every inner mapping are copied and wrapped in
    ``MappingProxyType``. Later mutation of ``descriptions`` therefore does not
    leak into the result, and the result cannot be modified.
    """
    return MappingProxyType(
        {
            tool_name: MappingProxyType(dict(param_patches))
            for tool_name, param_patches in descriptions.items()
        }
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
OutputT = TypeVar("OutputT")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class PromptRenderer[OutputT]:
|
|
112
|
+
"""Render prompts using a registry snapshot."""
|
|
113
|
+
|
|
114
|
+
def __init__(
|
|
115
|
+
self,
|
|
116
|
+
*,
|
|
117
|
+
registry: RegistrySnapshot,
|
|
118
|
+
structured_output: StructuredOutputConfig[SupportsDataclass] | None,
|
|
119
|
+
response_section: ResponseFormatSection | None,
|
|
120
|
+
) -> None:
|
|
121
|
+
super().__init__()
|
|
122
|
+
self._registry = registry
|
|
123
|
+
self._structured_output = structured_output
|
|
124
|
+
self._response_section: ResponseFormatSection | None = response_section
|
|
125
|
+
|
|
126
|
+
def build_param_lookup(
|
|
127
|
+
self, params: tuple[SupportsDataclass, ...]
|
|
128
|
+
) -> dict[type[SupportsDataclass], SupportsDataclass]:
|
|
129
|
+
lookup: dict[type[SupportsDataclass], SupportsDataclass] = {}
|
|
130
|
+
for value in params:
|
|
131
|
+
if isinstance(value, type):
|
|
132
|
+
provided_type: type[Any] = value
|
|
133
|
+
else:
|
|
134
|
+
provided_type = type(value)
|
|
135
|
+
if isinstance(value, type) or not is_dataclass(value):
|
|
136
|
+
raise PromptValidationError(
|
|
137
|
+
"Prompt expects dataclass instances.",
|
|
138
|
+
dataclass_type=provided_type,
|
|
139
|
+
)
|
|
140
|
+
params_type = cast(type[SupportsDataclass], provided_type)
|
|
141
|
+
if params_type in lookup:
|
|
142
|
+
raise PromptValidationError(
|
|
143
|
+
"Duplicate params type supplied to prompt.",
|
|
144
|
+
dataclass_type=params_type,
|
|
145
|
+
)
|
|
146
|
+
if params_type not in self._registry.param_types:
|
|
147
|
+
raise PromptValidationError(
|
|
148
|
+
"Unexpected params type supplied to prompt.",
|
|
149
|
+
dataclass_type=params_type,
|
|
150
|
+
)
|
|
151
|
+
lookup[params_type] = value
|
|
152
|
+
return lookup
|
|
153
|
+
|
|
154
|
+
def render(
|
|
155
|
+
self,
|
|
156
|
+
param_lookup: Mapping[type[SupportsDataclass], SupportsDataclass],
|
|
157
|
+
overrides: Mapping[SectionPath, str] | None = None,
|
|
158
|
+
tool_overrides: Mapping[str, ToolOverride] | None = None,
|
|
159
|
+
*,
|
|
160
|
+
inject_output_instructions: bool | None = None,
|
|
161
|
+
) -> RenderedPrompt[OutputT]:
|
|
162
|
+
rendered_sections: list[str] = []
|
|
163
|
+
collected_tools: list[Tool[SupportsDataclass, SupportsToolResult]] = []
|
|
164
|
+
override_lookup = dict(overrides or {})
|
|
165
|
+
tool_override_lookup = dict(tool_overrides or {})
|
|
166
|
+
field_description_patches: dict[str, dict[str, str]] = {}
|
|
167
|
+
|
|
168
|
+
for node, section_params in self._iter_enabled_sections(
|
|
169
|
+
dict(param_lookup),
|
|
170
|
+
inject_output_instructions=inject_output_instructions,
|
|
171
|
+
):
|
|
172
|
+
override_body = (
|
|
173
|
+
override_lookup.get(node.path)
|
|
174
|
+
if getattr(node.section, "accepts_overrides", True)
|
|
175
|
+
else None
|
|
176
|
+
)
|
|
177
|
+
rendered = self._render_section(node, section_params, override_body)
|
|
178
|
+
|
|
179
|
+
section_tools = node.section.tools()
|
|
180
|
+
if section_tools:
|
|
181
|
+
for tool in section_tools:
|
|
182
|
+
override = (
|
|
183
|
+
tool_override_lookup.get(tool.name)
|
|
184
|
+
if tool.accepts_overrides
|
|
185
|
+
else None
|
|
186
|
+
)
|
|
187
|
+
patched_tool = tool
|
|
188
|
+
if override is not None:
|
|
189
|
+
if (
|
|
190
|
+
override.description is not None
|
|
191
|
+
and override.description != tool.description
|
|
192
|
+
):
|
|
193
|
+
patched_tool = replace(
|
|
194
|
+
tool, description=override.description
|
|
195
|
+
)
|
|
196
|
+
if override.param_descriptions:
|
|
197
|
+
field_description_patches[tool.name] = dict(
|
|
198
|
+
override.param_descriptions
|
|
199
|
+
)
|
|
200
|
+
collected_tools.append(patched_tool)
|
|
201
|
+
|
|
202
|
+
if rendered:
|
|
203
|
+
rendered_sections.append(rendered)
|
|
204
|
+
|
|
205
|
+
text = "\n\n".join(rendered_sections)
|
|
206
|
+
|
|
207
|
+
return RenderedPrompt[OutputT](
|
|
208
|
+
text=text,
|
|
209
|
+
structured_output=self._structured_output,
|
|
210
|
+
_tools=tuple(collected_tools),
|
|
211
|
+
_tool_param_descriptions=_freeze_tool_param_descriptions(
|
|
212
|
+
field_description_patches
|
|
213
|
+
),
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
def _iter_enabled_sections(
|
|
217
|
+
self,
|
|
218
|
+
param_lookup: MutableMapping[type[SupportsDataclass], SupportsDataclass],
|
|
219
|
+
*,
|
|
220
|
+
inject_output_instructions: bool | None = None,
|
|
221
|
+
) -> Iterator[tuple[SectionNode[SupportsDataclass], SupportsDataclass | None]]:
|
|
222
|
+
skip_depth: int | None = None
|
|
223
|
+
|
|
224
|
+
for node in self._registry.sections:
|
|
225
|
+
if skip_depth is not None:
|
|
226
|
+
if node.depth > skip_depth:
|
|
227
|
+
continue
|
|
228
|
+
skip_depth = None
|
|
229
|
+
|
|
230
|
+
section_params = self._registry.resolve_section_params(node, param_lookup)
|
|
231
|
+
|
|
232
|
+
if node.section is self._response_section and (
|
|
233
|
+
inject_output_instructions is not None
|
|
234
|
+
):
|
|
235
|
+
enabled = inject_output_instructions
|
|
236
|
+
else:
|
|
237
|
+
try:
|
|
238
|
+
enabled = node.section.is_enabled(section_params)
|
|
239
|
+
except Exception as error: # pragma: no cover - defensive
|
|
240
|
+
raise PromptRenderError(
|
|
241
|
+
"Section enabled predicate failed.",
|
|
242
|
+
section_path=node.path,
|
|
243
|
+
dataclass_type=node.section.param_type,
|
|
244
|
+
) from error
|
|
245
|
+
|
|
246
|
+
if not enabled:
|
|
247
|
+
skip_depth = node.depth
|
|
248
|
+
continue
|
|
249
|
+
|
|
250
|
+
yield node, section_params
|
|
251
|
+
|
|
252
|
+
def _render_section(
|
|
253
|
+
self,
|
|
254
|
+
node: SectionNode[SupportsDataclass],
|
|
255
|
+
section_params: SupportsDataclass | None,
|
|
256
|
+
override_body: str | None,
|
|
257
|
+
) -> str:
|
|
258
|
+
params_type = node.section.param_type
|
|
259
|
+
try:
|
|
260
|
+
render_override = getattr(node.section, "render_with_template", None)
|
|
261
|
+
if override_body is not None and callable(render_override):
|
|
262
|
+
override_renderer = cast(
|
|
263
|
+
Callable[[str, SupportsDataclass | None, int], str],
|
|
264
|
+
render_override,
|
|
265
|
+
)
|
|
266
|
+
rendered = override_renderer(override_body, section_params, node.depth)
|
|
267
|
+
else:
|
|
268
|
+
rendered = node.section.render(section_params, node.depth)
|
|
269
|
+
except PromptRenderError as error:
|
|
270
|
+
if error.section_path and error.dataclass_type:
|
|
271
|
+
raise
|
|
272
|
+
raise PromptRenderError(
|
|
273
|
+
error.message,
|
|
274
|
+
section_path=node.path,
|
|
275
|
+
dataclass_type=params_type,
|
|
276
|
+
placeholder=error.placeholder,
|
|
277
|
+
) from error
|
|
278
|
+
except Exception as error: # pragma: no cover - defensive
|
|
279
|
+
raise PromptRenderError(
|
|
280
|
+
"Section rendering failed.",
|
|
281
|
+
section_path=node.path,
|
|
282
|
+
dataclass_type=params_type,
|
|
283
|
+
) from error
|
|
284
|
+
|
|
285
|
+
return rendered
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
__all__ = ["PromptRenderer", "RenderedPrompt"]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Final, Literal
|
|
18
|
+
|
|
19
|
+
from ._types import SupportsDataclass
|
|
20
|
+
from .markdown import MarkdownSection
|
|
21
|
+
|
|
22
|
+
# Public API of the response-format module.
__all__ = ["ResponseFormatParams", "ResponseFormatSection"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(slots=True)
|
|
26
|
+
class ResponseFormatParams:
|
|
27
|
+
"""Parameter payload for the auto-generated response format section."""
|
|
28
|
+
|
|
29
|
+
article: Literal["a", "an"]
|
|
30
|
+
container: Literal["object", "array"]
|
|
31
|
+
extra_clause: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
_RESPONSE_FORMAT_BODY: Final[
|
|
35
|
+
str
|
|
36
|
+
] = """Return ONLY a single fenced JSON code block. Do not include any text
|
|
37
|
+
before or after the block.
|
|
38
|
+
|
|
39
|
+
The top-level JSON value MUST be ${article} ${container} that matches the fields
|
|
40
|
+
of the expected schema${extra_clause}"""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ResponseFormatSection(MarkdownSection[ResponseFormatParams]):
    """Internal section that appends JSON-only response instructions.

    Renders the fixed response-format template under the "Response Format"
    title with key ``response-format``. Unlike the ``Section`` base default,
    ``accepts_overrides`` defaults to ``False`` here.
    """

    def __init__(
        self,
        *,
        params: ResponseFormatParams,
        enabled: Callable[[SupportsDataclass], bool] | None = None,
        accepts_overrides: bool = False,
    ) -> None:
        """Build the section.

        Args:
            params: Default values for the template placeholders.
            enabled: Optional predicate controlling whether the section renders.
            accepts_overrides: Whether body overrides may replace the template.
        """
        super().__init__(
            key="response-format",
            title="Response Format",
            template=_RESPONSE_FORMAT_BODY,
            default_params=params,
            accepts_overrides=accepts_overrides,
            enabled=enabled,
        )
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from collections.abc import Callable, Sequence
|
|
18
|
+
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from .tool import Tool
|
|
22
|
+
|
|
23
|
+
from ._generic_params_specializer import GenericParamsSpecializer
|
|
24
|
+
from ._normalization import normalize_component_key
|
|
25
|
+
from ._types import SupportsDataclass, SupportsToolResult
|
|
26
|
+
|
|
27
|
+
SectionParamsT = TypeVar("SectionParamsT", bound=SupportsDataclass, covariant=True)

# An ``enabled`` predicate either receives the section params or takes no
# arguments. Zero-argument predicates are only adapted for sections that have
# no params type (see ``Section._normalize_enabled``).
EnabledPredicate = Callable[[SupportsDataclass], bool] | Callable[[], bool]


class Section(GenericParamsSpecializer[SectionParamsT], ABC):
    """Abstract building block for prompt content.

    Subclasses implement :meth:`render`. Each section carries a normalised
    key, optional default params, validated child sections, an optional
    ``enabled`` predicate and the tools it contributes.
    """

    _generic_owner_name: ClassVar[str | None] = "Section"

    def __init__(
        self,
        *,
        title: str,
        key: str,
        default_params: SectionParamsT | None = None,
        children: Sequence[Section[SupportsDataclass]] | None = None,
        enabled: EnabledPredicate | None = None,
        tools: Sequence[object] | None = None,
        accepts_overrides: bool = True,
    ) -> None:
        """Initialise and validate the section.

        Raises:
            TypeError: If ``default_params`` is given for a section without a
                params type, if a child is not a ``Section``, or if a tool is
                not a ``Tool``.
        """
        super().__init__()
        # ``_params_type`` is read from the class (presumably populated when
        # the generic is specialised — TODO confirm in the specializer).
        # Anything that is not a class means "no params".
        declared = getattr(self.__class__, "_params_type", None)
        params_type = cast(
            type[SupportsDataclass] | None,
            declared if isinstance(declared, type) else None,
        )

        self.params_type: type[SectionParamsT] | None = cast(
            type[SectionParamsT] | None, params_type
        )
        # Alias of ``params_type``; both attributes hold the same value.
        self.param_type: type[SectionParamsT] | None = self.params_type
        self.title = title
        self.key = self._normalize_key(key)
        self.default_params = default_params
        self.accepts_overrides = accepts_overrides

        if default_params is not None and self.params_type is None:
            raise TypeError("Section without parameters cannot define default_params.")

        validated_children: list[Section[SupportsDataclass]] = []
        for child in cast(Sequence[object], children or ()):
            if not isinstance(child, Section):
                raise TypeError("Section children must be Section instances.")
            validated_children.append(cast(Section[SupportsDataclass], child))
        self.children = tuple(validated_children)

        self._enabled: Callable[[SupportsDataclass | None], bool] | None = (
            self._normalize_enabled(enabled, params_type)
        )
        self._tools = self._normalize_tools(tools)

    def is_enabled(self, params: SupportsDataclass | None) -> bool:
        """Return whether the section should render for ``params``.

        A section without an ``enabled`` predicate is always enabled.
        """
        predicate = self._enabled
        return True if predicate is None else bool(predicate(params))

    @abstractmethod
    def render(self, params: SupportsDataclass | None, depth: int) -> str:
        """Return the markdown for this section rendered at ``depth``."""

    def placeholder_names(self) -> set[str]:
        """Return template placeholder names; the base section defines none."""
        return set()

    def tools(self) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
        """Return the validated tools attached to this section."""
        return self._tools

    def original_body_template(self) -> str | None:
        """Return the template text used for hashing, or ``None`` if absent."""
        return None

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Normalise ``key`` using the shared component-key rules."""
        return normalize_component_key(key, owner="Section")

    @staticmethod
    def _normalize_tools(
        tools: Sequence[object] | None,
    ) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
        """Validate ``tools`` and return them as a tuple.

        Raises:
            TypeError: If any entry is not a ``Tool`` instance.
        """
        if not tools:
            return ()

        # Runtime import; the module-level import is TYPE_CHECKING-only.
        from .tool import Tool

        checked: list[Tool[SupportsDataclass, SupportsToolResult]] = []
        for candidate in tools:
            if not isinstance(candidate, Tool):
                raise TypeError("Section tools must be Tool instances.")
            checked.append(cast(Tool[SupportsDataclass, SupportsToolResult], candidate))
        return tuple(checked)

    @staticmethod
    def _normalize_enabled(
        enabled: EnabledPredicate | None,
        params_type: type[SupportsDataclass] | None,
    ) -> Callable[[SupportsDataclass | None], bool] | None:
        """Adapt ``enabled`` to a uniform ``(params | None) -> bool`` callable.

        A predicate is wrapped to be called without arguments only when the
        section has no params type and the predicate has no required
        positional parameter. Every other predicate is called with the
        section params.
        """
        if enabled is None:
            return None

        if params_type is not None or _callable_requires_positional_argument(enabled):
            params_predicate = cast(Callable[[SupportsDataclass], bool], enabled)

            def _call_with_params(value: SupportsDataclass | None) -> bool:
                return bool(params_predicate(cast(SupportsDataclass, value)))

            return _call_with_params

        thunk = cast(Callable[[], bool], enabled)

        def _call_without_params(_: SupportsDataclass | None) -> bool:
            return bool(thunk())

        return _call_without_params
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _callable_requires_positional_argument(callback: EnabledPredicate) -> bool:
    """Return True when ``callback`` declares a positional parameter without a default.

    Keyword-only and variadic parameters do not count. If the signature
    cannot be inspected, the callable is conservatively assumed to require
    an argument.
    """
    try:
        parameters = inspect.signature(callback).parameters.values()
    except (TypeError, ValueError):
        return True

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return any(
        parameter.default is inspect.Parameter.empty
        and parameter.kind in positional_kinds
        for parameter in parameters
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Public API of the section module.
__all__ = ["Section"]
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import re
|
|
17
|
+
from collections.abc import Mapping, Sequence
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Final, Literal, Protocol, TypeVar, cast
|
|
20
|
+
|
|
21
|
+
from ..serde.parse import parse as parse_dataclass
|
|
22
|
+
from ..types import JSONValue, ParseableDataclassT
|
|
23
|
+
from ._types import SupportsDataclass
|
|
24
|
+
|
|
25
|
+
# Public API of the structured output module.
__all__ = [
    "ARRAY_WRAPPER_KEY",
    "OutputParseError",
    "StructuredOutputConfig",
    "parse_structured_output",
]
|
|
31
|
+
|
|
32
|
+
# Key under which an "array" payload may arrive wrapped inside a JSON object.
ARRAY_WRAPPER_KEY: Final[str] = "items"

# Captures the body of a ```json fenced block (tag matched case-insensitively).
# The tag must be followed by optional whitespace and a newline, and the body
# may span multiple lines.
_JSON_FENCE_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"```json\s*\n(.*?)```",
    flags=re.DOTALL | re.IGNORECASE,
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Module-level type variable for dataclass payload types, bound to
# SupportsDataclass. NOTE(review): StructuredOutputConfig uses PEP 695 class
# syntax, which declares its own type parameter of the same name; this
# TypeVar is not the one bound to that class.
DataclassT = TypeVar("DataclassT", bound=SupportsDataclass)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True, slots=True)
|
|
43
|
+
class StructuredOutputConfig[DataclassT]:
|
|
44
|
+
"""Resolved structured output declaration for a prompt."""
|
|
45
|
+
|
|
46
|
+
dataclass_type: type[DataclassT]
|
|
47
|
+
container: Literal["object", "array"]
|
|
48
|
+
allow_extra_keys: bool
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class StructuredRenderedPrompt[PayloadT](Protocol):
|
|
52
|
+
@property
|
|
53
|
+
def structured_output(self) -> StructuredOutputConfig[SupportsDataclass] | None:
|
|
54
|
+
"""Structured output metadata declared by the prompt."""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class OutputParseError(Exception):
    """Raised when a model response cannot be parsed into structured output.

    Attributes:
        message: Human-readable description of the failure (also ``str(err)``).
        dataclass_type: Dataclass type that parsing targeted, when known.
    """

    def __init__(
        self,
        message: str,
        *,
        dataclass_type: type[SupportsDataclass] | None = None,
    ) -> None:
        super().__init__(message)
        self.dataclass_type = dataclass_type
        self.message = message
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def parse_structured_output[PayloadT](
|
|
72
|
+
output_text: str, rendered: StructuredRenderedPrompt[PayloadT]
|
|
73
|
+
) -> PayloadT:
|
|
74
|
+
"""Parse a model response into the structured output type declared by the prompt."""
|
|
75
|
+
|
|
76
|
+
config = rendered.structured_output
|
|
77
|
+
if config is None:
|
|
78
|
+
raise OutputParseError("Prompt does not declare structured output.")
|
|
79
|
+
|
|
80
|
+
dataclass_type = config.dataclass_type
|
|
81
|
+
container = config.container
|
|
82
|
+
allow_extra_keys = config.allow_extra_keys
|
|
83
|
+
payload = _extract_json_payload(output_text, dataclass_type)
|
|
84
|
+
try:
|
|
85
|
+
parsed = parse_dataclass_payload(
|
|
86
|
+
dataclass_type,
|
|
87
|
+
container,
|
|
88
|
+
payload,
|
|
89
|
+
allow_extra_keys=allow_extra_keys,
|
|
90
|
+
object_error="Expected top-level JSON object.",
|
|
91
|
+
array_error="Expected top-level JSON array.",
|
|
92
|
+
array_item_error="Array item at index {index} is not an object.",
|
|
93
|
+
)
|
|
94
|
+
except (TypeError, ValueError) as error:
|
|
95
|
+
raise OutputParseError(str(error), dataclass_type=dataclass_type) from error
|
|
96
|
+
|
|
97
|
+
return cast(PayloadT, parsed)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _extract_json_payload(
    text: str, dataclass_type: type[SupportsDataclass]
) -> JSONValue:
    """Locate and decode the JSON value inside an assistant message.

    Strategies are tried in order:

    1. The first ```json fenced block. A decode failure here is fatal.
    2. The whole message with surrounding whitespace stripped.
    3. The first ``{`` or ``[`` from which a complete JSON value decodes.

    Raises:
        OutputParseError: If the fenced block is invalid JSON, or if no
            strategy yields a value.
    """
    fence = _JSON_FENCE_PATTERN.search(text)
    if fence is not None:
        try:
            return json.loads(fence.group(1).strip())
        except json.JSONDecodeError as error:
            raise OutputParseError(
                "Failed to decode JSON from fenced code block.",
                dataclass_type=dataclass_type,
            ) from error

    whole_message = text.strip()
    if whole_message:
        try:
            return json.loads(whole_message)
        except json.JSONDecodeError:
            pass  # Not pure JSON; fall back to scanning for an embedded value.

    decoder = json.JSONDecoder()
    opener_positions = (
        position for position, char in enumerate(text) if char in "{["
    )
    for start in opener_positions:
        try:
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        return cast(JSONValue, value)

    raise OutputParseError(
        "No JSON object or array found in assistant message.",
        dataclass_type=dataclass_type,
    )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def parse_dataclass_payload(
    dataclass_type: type[ParseableDataclassT],
    container: Literal["object", "array"],
    payload: JSONValue,
    *,
    allow_extra_keys: bool,
    object_error: str,
    array_error: str,
    array_item_error: str,
) -> ParseableDataclassT | list[ParseableDataclassT]:
    """Convert a decoded JSON payload into one dataclass instance or a list.

    Args:
        dataclass_type: Dataclass that each JSON object is parsed into.
        container: ``"object"`` for a single object, ``"array"`` for a list of
            objects. An array may also arrive as ``{ARRAY_WRAPPER_KEY: [...]}``.
        payload: Decoded JSON value.
        allow_extra_keys: If true, the parser receives ``extra="ignore"``;
            otherwise it receives ``extra="forbid"``.
        object_error: Message used when an object is expected but missing.
        array_error: Message used when an array is expected but missing.
        array_item_error: Format string with an ``{index}`` field, used when
            an array item is not an object.

    Returns:
        A single instance for ``"object"``; a list of instances for ``"array"``.

    Raises:
        TypeError: On an unknown container or a payload of the wrong shape.
            Errors from the underlying dataclass parser propagate unchanged.
    """
    if container not in {"object", "array"}:
        raise TypeError("Unknown output container declared.")

    extra_mode: Literal["ignore", "forbid"] = (
        "forbid" if not allow_extra_keys else "ignore"
    )

    if container == "object":
        if not isinstance(payload, Mapping):
            raise TypeError(object_error)
        return parse_dataclass(
            dataclass_type, cast(Mapping[str, JSONValue], payload), extra=extra_mode
        )

    items: JSONValue = payload
    if isinstance(items, Mapping):
        wrapper = cast(Mapping[str, JSONValue], items)
        if ARRAY_WRAPPER_KEY not in wrapper:
            raise TypeError(array_error)
        items = wrapper[ARRAY_WRAPPER_KEY]
    # Strings and bytes are Sequences but never a valid JSON array here.
    if isinstance(items, (str, bytes, bytearray)) or not isinstance(items, Sequence):
        raise TypeError(array_error)

    results: list[ParseableDataclassT] = []
    for position, entry in enumerate(cast(Sequence[JSONValue], items)):
        if not isinstance(entry, Mapping):
            raise TypeError(array_item_error.format(index=position))
        results.append(
            parse_dataclass(
                dataclass_type,
                cast(Mapping[str, JSONValue], entry),
                extra=extra_mode,
            )
        )
    return results
|