weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Structural typing primitives shared across prompt tooling."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Mapping
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar
|
|
19
|
+
|
|
20
|
+
from ..deadlines import Deadline
|
|
21
|
+
from ._overrides_protocols import PromptOverridesStoreProtocol
|
|
22
|
+
from ._types import SupportsDataclass
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING: # pragma: no cover - typing only
|
|
25
|
+
from ..runtime.events._types import EventBus
|
|
26
|
+
from ..runtime.session.protocols import SessionProtocol
|
|
27
|
+
|
|
28
|
+
# Declared covariant: in this module these parameters appear only in return
# positions (or are not referenced by any member at all).
PromptOutputT = TypeVar("PromptOutputT", covariant=True)
RenderedOutputT = TypeVar("RenderedOutputT", covariant=True)
# Invariant: ProviderAdapterProtocol uses it both as an input (the prompt's
# output type) and in its return type.
AdapterOutputT = TypeVar("AdapterOutputT")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PromptResponseProtocol(Protocol[AdapterOutputT]):
    """Structural shape of the response returned by an adapter's ``evaluate``.

    Any object exposing these five attributes satisfies the protocol; concrete
    response types may carry additional fields.
    """

    prompt_name: str
    # Textual response, or ``None`` when absent.
    text: str | None
    # Typed output (parameterised by the adapter), or ``None`` when absent.
    output: AdapterOutputT | None
    # Tool results are typed loosely here; their concrete type is defined elsewhere.
    tool_results: tuple[object, ...]
    # Provider-specific payload, if the adapter exposes one.
    provider_payload: Mapping[str, Any] | None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RenderedPromptProtocol(Protocol[RenderedOutputT]):
    """Read-only view satisfied by a rendered prompt snapshot.

    Every member is a property so that frozen/immutable implementations can
    satisfy the protocol.
    """

    @property
    def text(self) -> str:
        """The rendered prompt text."""

    @property
    def output_type(self) -> type[Any] | None:
        """Output type associated with the rendered prompt, or ``None``."""

    @property
    def container(self) -> Literal["object", "array"] | None:
        """Declared output container (``"object"`` or ``"array"``), or ``None``."""

    @property
    def allow_extra_keys(self) -> bool | None:
        """Extra-keys policy for the output; ``None`` presumably means unspecified."""

    @property
    def deadline(self) -> Deadline | None:
        """Deadline attached to the render, if any."""

    @property
    def tools(self) -> tuple[object, ...]:
        """Tools exposed by the rendered prompt (loosely typed here)."""

    @property
    def tool_param_descriptions(self) -> Mapping[str, Mapping[str, str]]:
        """Nested string mapping of tool parameter descriptions.

        Presumably keyed by tool name, then parameter name — confirm against the
        concrete implementation.
        """
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class PromptProtocol(Protocol[PromptOutputT]):
    """The slice of ``Prompt`` state and behaviour that tools may rely on."""

    # Identifying attributes; their exact semantics belong to the concrete Prompt.
    ns: str
    key: str
    name: str | None

    def render(
        self,
        *params: SupportsDataclass,
        overrides_store: PromptOverridesStoreProtocol | None = None,
        tag: str = "latest",
        inject_output_instructions: bool | None = None,
    ) -> RenderedPromptProtocol[PromptOutputT]:
        """Render the prompt with the given dataclass *params*.

        Args:
            *params: Dataclass instances supplying section parameters.
            overrides_store: Optional store of prompt overrides.
            tag: Override tag to use; defaults to ``"latest"``.
            inject_output_instructions: Optional toggle for output instructions;
                ``None`` leaves the decision to the implementation.

        Returns:
            A rendered snapshot satisfying :class:`RenderedPromptProtocol`.
        """
        ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ProviderAdapterProtocol(Protocol[AdapterOutputT]):
    """The slice of provider-adapter behaviour that tools may rely on."""

    def evaluate(
        self,
        prompt: PromptProtocol[AdapterOutputT],
        *params: SupportsDataclass,
        parse_output: bool = True,
        bus: EventBus,
        session: SessionProtocol,
        deadline: Deadline | None = None,
        overrides_store: PromptOverridesStoreProtocol | None = None,
        overrides_tag: str = "latest",
    ) -> PromptResponseProtocol[AdapterOutputT]:
        """Evaluate *prompt* with *params* and return the adapter's response.

        Everything after ``*params`` is keyword-only; ``bus`` and ``session``
        are required keywords.

        Args:
            prompt: Prompt to evaluate.
            *params: Dataclass instances supplying section parameters.
            parse_output: Whether to parse the output; defaults to ``True``.
            bus: Event bus passed to the adapter.
            session: Session passed to the adapter.
            deadline: Optional deadline.
            overrides_store: Optional prompt overrides store.
            overrides_tag: Override tag; defaults to ``"latest"``.

        Returns:
            A response satisfying :class:`PromptResponseProtocol`.
        """
        ...
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Public protocols exported by this module (kept alphabetical).
__all__ = [
    "PromptProtocol",
    "PromptResponseProtocol",
    "ProviderAdapterProtocol",
    "RenderedPromptProtocol",
]
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
"""Section registration helpers for :mod:`weakincentives.prompt`."""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
|
18
|
+
from dataclasses import dataclass, fields, is_dataclass, replace
|
|
19
|
+
from types import MappingProxyType
|
|
20
|
+
from typing import Any, cast
|
|
21
|
+
|
|
22
|
+
from ..dbc import invariant
|
|
23
|
+
from ._types import SupportsDataclass, SupportsToolResult
|
|
24
|
+
from .errors import PromptRenderError, PromptValidationError, SectionPath
|
|
25
|
+
from .section import Section
|
|
26
|
+
from .tool import Tool
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True, slots=True)
|
|
30
|
+
class SectionNode[ParamsT: SupportsDataclass]:
|
|
31
|
+
"""Flattened view of a section within a prompt."""
|
|
32
|
+
|
|
33
|
+
section: Section[ParamsT]
|
|
34
|
+
depth: int
|
|
35
|
+
path: SectionPath
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True)
class RegistrySnapshot:
    """Frozen, read-only picture of a :class:`PromptRegistry`'s state."""

    # All registered nodes, as captured by ``PromptRegistry.snapshot``.
    sections: tuple[SectionNode[SupportsDataclass], ...]
    # Param dataclass -> nodes whose section declares that param type.
    params_registry: Mapping[
        type[SupportsDataclass], tuple[SectionNode[SupportsDataclass], ...]
    ]
    # Section path -> default params declared by that section.
    defaults_by_path: Mapping[SectionPath, SupportsDataclass]
    # Param dataclass -> default params shared by every section of that type.
    defaults_by_type: Mapping[type[SupportsDataclass], SupportsDataclass]
    # Section path -> placeholder names used by the section.
    placeholders: Mapping[SectionPath, frozenset[str]]
    # Tool name -> path of the section that declared it.
    tool_name_registry: Mapping[str, SectionPath]

    def resolve_section_params(
        self,
        node: SectionNode[SupportsDataclass],
        param_lookup: MutableMapping[type[SupportsDataclass], SupportsDataclass],
    ) -> SupportsDataclass | None:
        """Pick the params instance a section should render with.

        Resolution order:

        1. The instance supplied in *param_lookup* for the section's param type.
        2. A clone of the default registered for the section's path.
        3. A clone of the default registered for the param type.
        4. A fresh instance built by calling the param type with no arguments.

        Returns:
            ``None`` when the section takes no params, else the resolved instance.

        Raises:
            PromptRenderError: when zero-argument construction raises
                ``TypeError``, or the resolved value is not a dataclass.
        """
        params_type = node.section.param_type
        if params_type is None:
            return None

        chosen: SupportsDataclass | None = param_lookup.get(params_type)
        if chosen is None:
            # Path-specific defaults win over type-wide ones.
            template = self.defaults_by_path.get(node.path)
            if template is None:
                template = self.defaults_by_type.get(params_type)

            if template is not None:
                # Clone so that callers never share (and mutate) stored defaults.
                chosen = clone_dataclass(template)
            else:
                try:
                    factory = cast(
                        Callable[[], SupportsDataclass | None], params_type
                    )
                    chosen = factory()
                except TypeError as error:
                    # Typically the dataclass has required fields without defaults.
                    raise PromptRenderError(
                        "Missing parameters for section.",
                        section_path=node.path,
                        dataclass_type=params_type,
                    ) from error

        # ``is_dataclass(None)`` is False, so one check covers both failure modes.
        if chosen is None or not is_dataclass(chosen):
            raise PromptRenderError(
                "Section constructor must return a dataclass instance.",
                section_path=node.path,
                dataclass_type=params_type,
            )

        return chosen

    @property
    def param_types(self) -> set[type[SupportsDataclass]]:
        """Fresh mutable set of every param dataclass registered for sections."""

        return set(self.params_registry)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _registry_paths_are_registered(
|
|
110
|
+
registry: PromptRegistry,
|
|
111
|
+
) -> tuple[bool, str] | bool:
|
|
112
|
+
"""Ensure internal registries only reference known section nodes."""
|
|
113
|
+
|
|
114
|
+
section_nodes = registry._section_nodes # pyright: ignore[reportPrivateUsage]
|
|
115
|
+
node_by_path = {node.path: node for node in section_nodes}
|
|
116
|
+
defaults_by_path = registry._defaults_by_path # pyright: ignore[reportPrivateUsage]
|
|
117
|
+
placeholders = registry._placeholders # pyright: ignore[reportPrivateUsage]
|
|
118
|
+
tool_name_registry = registry._tool_name_registry # pyright: ignore[reportPrivateUsage]
|
|
119
|
+
defaults_by_type = registry._defaults_by_type # pyright: ignore[reportPrivateUsage]
|
|
120
|
+
|
|
121
|
+
unknown_default_paths = [
|
|
122
|
+
path for path in defaults_by_path if path not in node_by_path
|
|
123
|
+
]
|
|
124
|
+
if unknown_default_paths:
|
|
125
|
+
return (
|
|
126
|
+
False,
|
|
127
|
+
f"defaults reference unknown paths: {sorted(unknown_default_paths)!r}",
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
for path, default in defaults_by_path.items():
|
|
131
|
+
node = node_by_path[path]
|
|
132
|
+
params_type = node.section.param_type
|
|
133
|
+
if params_type is None:
|
|
134
|
+
return False, f"section at {path!r} does not accept params but has defaults"
|
|
135
|
+
if type(default) is not params_type:
|
|
136
|
+
return False, (
|
|
137
|
+
"default params type mismatch for path "
|
|
138
|
+
f"{path!r}: expected {params_type.__name__}, got {type(default).__name__}"
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
unknown_placeholder_paths = [
|
|
142
|
+
path for path in placeholders if path not in node_by_path
|
|
143
|
+
]
|
|
144
|
+
if unknown_placeholder_paths:
|
|
145
|
+
return False, (
|
|
146
|
+
"placeholders reference unknown paths: "
|
|
147
|
+
f"{sorted(unknown_placeholder_paths)!r}"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
unknown_tool_paths = [
|
|
151
|
+
path for path in tool_name_registry.values() if path not in node_by_path
|
|
152
|
+
]
|
|
153
|
+
if unknown_tool_paths:
|
|
154
|
+
return False, f"tools reference unknown paths: {sorted(unknown_tool_paths)!r}"
|
|
155
|
+
|
|
156
|
+
for params_type, default in defaults_by_type.items():
|
|
157
|
+
if type(default) is not params_type:
|
|
158
|
+
return False, (
|
|
159
|
+
"default by type mismatch for "
|
|
160
|
+
f"{params_type.__name__}: got {type(default).__name__}"
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return True
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _params_registry_is_consistent(
|
|
167
|
+
registry: PromptRegistry,
|
|
168
|
+
) -> tuple[bool, str] | bool:
|
|
169
|
+
"""Ensure params registry entries point at known nodes with matching types."""
|
|
170
|
+
|
|
171
|
+
section_nodes = list(registry._section_nodes) # pyright: ignore[reportPrivateUsage]
|
|
172
|
+
params_registry = registry._params_registry # pyright: ignore[reportPrivateUsage]
|
|
173
|
+
for params_type, nodes in params_registry.items():
|
|
174
|
+
for node in nodes:
|
|
175
|
+
if node not in section_nodes:
|
|
176
|
+
return False, (
|
|
177
|
+
"params registry references unknown node at path "
|
|
178
|
+
f"{node.path!r} for {params_type.__name__}"
|
|
179
|
+
)
|
|
180
|
+
node_params_type = node.section.param_type
|
|
181
|
+
if node_params_type is None:
|
|
182
|
+
return False, (
|
|
183
|
+
"params registry references section without params at path "
|
|
184
|
+
f"{node.path!r}"
|
|
185
|
+
)
|
|
186
|
+
if node_params_type is not params_type:
|
|
187
|
+
return False, (
|
|
188
|
+
"params registry type mismatch for path "
|
|
189
|
+
f"{node.path!r}: expected {params_type.__name__}, "
|
|
190
|
+
f"found {node_params_type.__name__}"
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
return True
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@invariant(
|
|
197
|
+
_registry_paths_are_registered,
|
|
198
|
+
_params_registry_is_consistent,
|
|
199
|
+
)
|
|
200
|
+
class PromptRegistry:
|
|
201
|
+
"""Collect and validate prompt sections prior to rendering."""
|
|
202
|
+
|
|
203
|
+
def __init__(self) -> None:
|
|
204
|
+
super().__init__()
|
|
205
|
+
self._section_nodes: list[SectionNode[SupportsDataclass]] = []
|
|
206
|
+
self._params_registry: dict[
|
|
207
|
+
type[SupportsDataclass], list[SectionNode[SupportsDataclass]]
|
|
208
|
+
] = {}
|
|
209
|
+
self._defaults_by_path: dict[SectionPath, SupportsDataclass] = {}
|
|
210
|
+
self._defaults_by_type: dict[type[SupportsDataclass], SupportsDataclass] = {}
|
|
211
|
+
self._placeholders: dict[SectionPath, set[str]] = {}
|
|
212
|
+
self._tool_name_registry: dict[str, SectionPath] = {}
|
|
213
|
+
|
|
214
|
+
def register_sections(self, sections: Sequence[Section[SupportsDataclass]]) -> None:
|
|
215
|
+
"""Register the provided root sections."""
|
|
216
|
+
|
|
217
|
+
for section in sections:
|
|
218
|
+
self._register_section(section, path=(section.key,), depth=0)
|
|
219
|
+
|
|
220
|
+
def register_section(
|
|
221
|
+
self,
|
|
222
|
+
section: Section[SupportsDataclass],
|
|
223
|
+
*,
|
|
224
|
+
path: SectionPath,
|
|
225
|
+
depth: int,
|
|
226
|
+
) -> None:
|
|
227
|
+
"""Register a single section at the supplied path and depth."""
|
|
228
|
+
|
|
229
|
+
self._register_section(section, path=path, depth=depth)
|
|
230
|
+
|
|
231
|
+
def _register_section(
|
|
232
|
+
self,
|
|
233
|
+
section: Section[SupportsDataclass],
|
|
234
|
+
*,
|
|
235
|
+
path: SectionPath,
|
|
236
|
+
depth: int,
|
|
237
|
+
) -> None:
|
|
238
|
+
params_type = section.param_type
|
|
239
|
+
if params_type is not None and not is_dataclass(params_type):
|
|
240
|
+
raise PromptValidationError(
|
|
241
|
+
"Section params must be a dataclass.",
|
|
242
|
+
section_path=path,
|
|
243
|
+
dataclass_type=params_type,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
node: SectionNode[SupportsDataclass] = SectionNode(
|
|
247
|
+
section=section, depth=depth, path=path
|
|
248
|
+
)
|
|
249
|
+
self._section_nodes.append(node)
|
|
250
|
+
|
|
251
|
+
if params_type is not None:
|
|
252
|
+
self._params_registry.setdefault(params_type, []).append(node)
|
|
253
|
+
|
|
254
|
+
if params_type is not None and section.default_params is not None:
|
|
255
|
+
default_value = section.default_params
|
|
256
|
+
if isinstance(default_value, type) or not is_dataclass(default_value):
|
|
257
|
+
raise PromptValidationError(
|
|
258
|
+
"Section defaults must be dataclass instances.",
|
|
259
|
+
section_path=path,
|
|
260
|
+
dataclass_type=params_type,
|
|
261
|
+
)
|
|
262
|
+
if type(default_value) is not params_type:
|
|
263
|
+
raise PromptValidationError(
|
|
264
|
+
"Section defaults must match section params type.",
|
|
265
|
+
section_path=path,
|
|
266
|
+
dataclass_type=params_type,
|
|
267
|
+
)
|
|
268
|
+
self._defaults_by_path[path] = default_value
|
|
269
|
+
_ = self._defaults_by_type.setdefault(params_type, default_value)
|
|
270
|
+
|
|
271
|
+
section_placeholders = section.placeholder_names()
|
|
272
|
+
self._placeholders[path] = set(section_placeholders)
|
|
273
|
+
if params_type is None:
|
|
274
|
+
if section_placeholders:
|
|
275
|
+
placeholder = sorted(section_placeholders)[0]
|
|
276
|
+
raise PromptValidationError(
|
|
277
|
+
"Section does not accept parameters but declares placeholders.",
|
|
278
|
+
section_path=path,
|
|
279
|
+
placeholder=placeholder,
|
|
280
|
+
)
|
|
281
|
+
else:
|
|
282
|
+
param_fields = {field.name for field in fields(params_type)}
|
|
283
|
+
unknown_placeholders = section_placeholders - param_fields
|
|
284
|
+
if unknown_placeholders:
|
|
285
|
+
placeholder = sorted(unknown_placeholders)[0]
|
|
286
|
+
raise PromptValidationError(
|
|
287
|
+
"Template references unknown placeholder.",
|
|
288
|
+
section_path=path,
|
|
289
|
+
dataclass_type=params_type,
|
|
290
|
+
placeholder=placeholder,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
section_tools = cast(tuple[object, ...], section.tools())
|
|
294
|
+
if section_tools:
|
|
295
|
+
for tool in section_tools:
|
|
296
|
+
if not isinstance(tool, Tool):
|
|
297
|
+
raise PromptValidationError(
|
|
298
|
+
"Section tools must be Tool instances.",
|
|
299
|
+
section_path=path,
|
|
300
|
+
dataclass_type=params_type,
|
|
301
|
+
)
|
|
302
|
+
typed_tool = cast(Tool[SupportsDataclass, SupportsToolResult], tool)
|
|
303
|
+
self._register_section_tools(
|
|
304
|
+
typed_tool,
|
|
305
|
+
path,
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
for child in section.children:
|
|
309
|
+
child_path = (*path, child.key)
|
|
310
|
+
self._register_section(child, path=child_path, depth=depth + 1)
|
|
311
|
+
|
|
312
|
+
def _register_section_tools[
|
|
313
|
+
ParamsT: SupportsDataclass,
|
|
314
|
+
ResultT: SupportsToolResult,
|
|
315
|
+
](
|
|
316
|
+
self,
|
|
317
|
+
tool: Tool[ParamsT, ResultT],
|
|
318
|
+
path: SectionPath,
|
|
319
|
+
) -> None:
|
|
320
|
+
params_type = tool.params_type
|
|
321
|
+
if not is_dataclass(params_type):
|
|
322
|
+
raise PromptValidationError(
|
|
323
|
+
"Tool parameters must be dataclass types.",
|
|
324
|
+
section_path=path,
|
|
325
|
+
dataclass_type=params_type,
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
existing_path = self._tool_name_registry.get(tool.name)
|
|
329
|
+
if existing_path is not None:
|
|
330
|
+
raise PromptValidationError(
|
|
331
|
+
"Duplicate tool name registered for prompt.",
|
|
332
|
+
section_path=path,
|
|
333
|
+
dataclass_type=tool.params_type,
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
self._tool_name_registry[tool.name] = path
|
|
337
|
+
|
|
338
|
+
def snapshot(self) -> RegistrySnapshot:
|
|
339
|
+
"""Return an immutable snapshot of the registered sections."""
|
|
340
|
+
|
|
341
|
+
params_registry: dict[
|
|
342
|
+
type[SupportsDataclass], tuple[SectionNode[SupportsDataclass], ...]
|
|
343
|
+
] = {
|
|
344
|
+
params_type: tuple(nodes)
|
|
345
|
+
for params_type, nodes in self._params_registry.items()
|
|
346
|
+
}
|
|
347
|
+
defaults_by_path = MappingProxyType(dict(self._defaults_by_path))
|
|
348
|
+
defaults_by_type = MappingProxyType(dict(self._defaults_by_type))
|
|
349
|
+
placeholders = MappingProxyType(
|
|
350
|
+
{path: frozenset(names) for path, names in self._placeholders.items()}
|
|
351
|
+
)
|
|
352
|
+
tool_name_registry = MappingProxyType(dict(self._tool_name_registry))
|
|
353
|
+
|
|
354
|
+
return RegistrySnapshot(
|
|
355
|
+
sections=tuple(self._section_nodes),
|
|
356
|
+
params_registry=MappingProxyType(params_registry),
|
|
357
|
+
defaults_by_path=defaults_by_path,
|
|
358
|
+
defaults_by_type=defaults_by_type,
|
|
359
|
+
placeholders=placeholders,
|
|
360
|
+
tool_name_registry=tool_name_registry,
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def clone_dataclass(instance: SupportsDataclass) -> SupportsDataclass:
    """Return a new instance equal to *instance*, built via ``dataclasses.replace``.

    The copy is shallow: field values are shared with the source object. As
    with ``replace``, fields declared ``init=False`` are re-initialised rather
    than copied.
    """

    duplicate = replace(cast(Any, instance))
    return cast("SupportsDataclass", duplicate)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# Public API of the registry module (kept alphabetical).
__all__ = [
    "PromptRegistry",
    "RegistrySnapshot",
    "SectionNode",
    "clone_dataclass",
]
|