weakincentives 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weakincentives/__init__.py +67 -0
- weakincentives/adapters/__init__.py +37 -0
- weakincentives/adapters/_names.py +32 -0
- weakincentives/adapters/_provider_protocols.py +69 -0
- weakincentives/adapters/_tool_messages.py +80 -0
- weakincentives/adapters/core.py +102 -0
- weakincentives/adapters/litellm.py +254 -0
- weakincentives/adapters/openai.py +254 -0
- weakincentives/adapters/shared.py +1021 -0
- weakincentives/cli/__init__.py +23 -0
- weakincentives/cli/wink.py +58 -0
- weakincentives/dbc/__init__.py +412 -0
- weakincentives/deadlines.py +58 -0
- weakincentives/prompt/__init__.py +105 -0
- weakincentives/prompt/_generic_params_specializer.py +64 -0
- weakincentives/prompt/_normalization.py +48 -0
- weakincentives/prompt/_overrides_protocols.py +33 -0
- weakincentives/prompt/_types.py +34 -0
- weakincentives/prompt/chapter.py +146 -0
- weakincentives/prompt/composition.py +281 -0
- weakincentives/prompt/errors.py +57 -0
- weakincentives/prompt/markdown.py +108 -0
- weakincentives/prompt/overrides/__init__.py +59 -0
- weakincentives/prompt/overrides/_fs.py +164 -0
- weakincentives/prompt/overrides/inspection.py +141 -0
- weakincentives/prompt/overrides/local_store.py +275 -0
- weakincentives/prompt/overrides/validation.py +534 -0
- weakincentives/prompt/overrides/versioning.py +269 -0
- weakincentives/prompt/prompt.py +353 -0
- weakincentives/prompt/protocols.py +103 -0
- weakincentives/prompt/registry.py +375 -0
- weakincentives/prompt/rendering.py +288 -0
- weakincentives/prompt/response_format.py +60 -0
- weakincentives/prompt/section.py +166 -0
- weakincentives/prompt/structured_output.py +179 -0
- weakincentives/prompt/tool.py +397 -0
- weakincentives/prompt/tool_result.py +30 -0
- weakincentives/py.typed +0 -0
- weakincentives/runtime/__init__.py +82 -0
- weakincentives/runtime/events/__init__.py +126 -0
- weakincentives/runtime/events/_types.py +110 -0
- weakincentives/runtime/logging.py +284 -0
- weakincentives/runtime/session/__init__.py +46 -0
- weakincentives/runtime/session/_slice_types.py +24 -0
- weakincentives/runtime/session/_types.py +55 -0
- weakincentives/runtime/session/dataclasses.py +29 -0
- weakincentives/runtime/session/protocols.py +34 -0
- weakincentives/runtime/session/reducer_context.py +40 -0
- weakincentives/runtime/session/reducers.py +82 -0
- weakincentives/runtime/session/selectors.py +56 -0
- weakincentives/runtime/session/session.py +387 -0
- weakincentives/runtime/session/snapshots.py +310 -0
- weakincentives/serde/__init__.py +19 -0
- weakincentives/serde/_utils.py +240 -0
- weakincentives/serde/dataclass_serde.py +55 -0
- weakincentives/serde/dump.py +189 -0
- weakincentives/serde/parse.py +417 -0
- weakincentives/serde/schema.py +260 -0
- weakincentives/tools/__init__.py +154 -0
- weakincentives/tools/_context.py +38 -0
- weakincentives/tools/asteval.py +853 -0
- weakincentives/tools/errors.py +26 -0
- weakincentives/tools/planning.py +831 -0
- weakincentives/tools/podman.py +1655 -0
- weakincentives/tools/subagents.py +346 -0
- weakincentives/tools/vfs.py +1390 -0
- weakincentives/types/__init__.py +35 -0
- weakincentives/types/json.py +45 -0
- weakincentives-0.9.0.dist-info/METADATA +775 -0
- weakincentives-0.9.0.dist-info/RECORD +73 -0
- weakincentives-0.9.0.dist-info/WHEEL +4 -0
- weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
- weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from collections.abc import Iterable, Mapping
|
|
16
|
+
from dataclasses import fields, is_dataclass
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Literal, cast, overload
|
|
19
|
+
|
|
20
|
+
from ...runtime.logging import StructuredLogger, get_logger
|
|
21
|
+
from ...types import JSONValue
|
|
22
|
+
from .versioning import (
|
|
23
|
+
HexDigest,
|
|
24
|
+
PromptDescriptor,
|
|
25
|
+
PromptLike,
|
|
26
|
+
PromptOverride,
|
|
27
|
+
PromptOverridesError,
|
|
28
|
+
SectionDescriptor,
|
|
29
|
+
SectionOverride,
|
|
30
|
+
ToolContractProtocol,
|
|
31
|
+
ToolDescriptor,
|
|
32
|
+
ToolOverride,
|
|
33
|
+
ensure_hex_digest,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Override file format version that ``validate_header`` requires.
FORMAT_VERSION = 1

# Module logger; every record it emits carries the "prompt_overrides" component.
_LOGGER: StructuredLogger = get_logger(
    __name__,
    context={"component": "prompt_overrides"},
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def validate_header(
    payload: Mapping[str, JSONValue],
    descriptor: PromptDescriptor,
    tag: str,
    file_path: Path,
) -> None:
    """Check that an override file's header matches what the caller expects.

    Args:
        payload: Decoded top-level JSON object of the override file.
        descriptor: Descriptor of the prompt the file should target.
        tag: Override tag the file should carry.
        file_path: Location of the file; used only in error messages.

    Raises:
        PromptOverridesError: If ``version`` is not exactly the integer
            ``FORMAT_VERSION``, or if ``ns`` / ``prompt_key`` / ``tag`` do not
            match ``descriptor.ns`` / ``descriptor.key`` / *tag*.
    """
    version = payload.get("version")
    # ``True == 1`` and ``1.0 == 1`` in Python, so plain equality would accept
    # ``"version": true`` or ``"version": 1.0``; require a real (non-bool) int.
    if (
        isinstance(version, bool)
        or not isinstance(version, int)
        or version != FORMAT_VERSION
    ):
        raise PromptOverridesError(
            f"Unsupported override file version {version!r} in {file_path}."
        )
    ns = payload.get("ns")
    prompt_key = payload.get("prompt_key")
    tag_value = payload.get("tag")
    if ns != descriptor.ns or prompt_key != descriptor.key or tag_value != tag:
        raise PromptOverridesError(
            "Override file metadata does not match descriptor inputs "
            f"in {file_path}."
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _section_descriptor_index(
|
|
63
|
+
descriptor: PromptDescriptor,
|
|
64
|
+
) -> dict[tuple[str, ...], SectionDescriptor]:
|
|
65
|
+
return {section.path: section for section in descriptor.sections}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _tool_descriptor_index(
|
|
69
|
+
descriptor: PromptDescriptor,
|
|
70
|
+
) -> dict[str, ToolDescriptor]:
|
|
71
|
+
return {tool.name: tool for tool in descriptor.tools}
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _format_section_path(path: tuple[str, ...]) -> str:
|
|
75
|
+
return "/".join(path)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _log_mismatched_override_metadata(
    descriptor: PromptDescriptor,
    override: PromptOverride,
) -> None:
    """Record at debug level that *override* targets a different prompt.

    The log context pairs the descriptor's namespace/key with the override's
    namespace/prompt key so the mismatch can be diagnosed.
    """
    mismatch_context = {
        "expected_ns": descriptor.ns,
        "expected_key": descriptor.key,
        "override_ns": override.ns,
        "override_key": override.prompt_key,
    }
    _LOGGER.debug(
        "Skipping override due to descriptor metadata mismatch.",
        event="prompt_override_mismatched_descriptor",
        context=mismatch_context,
    )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@overload
def _normalize_section_override(
    *,
    path: tuple[str, ...],
    descriptor_section: SectionDescriptor | None,
    expected_hash: JSONValue,
    body: JSONValue,
    strict: Literal[True],
    path_display: str,
    body_error_message: str,
) -> SectionOverride: ...


@overload
def _normalize_section_override(
    *,
    path: tuple[str, ...],
    descriptor_section: SectionDescriptor | None,
    expected_hash: JSONValue,
    body: JSONValue,
    strict: Literal[False],
    path_display: str,
    body_error_message: str,
) -> SectionOverride | None: ...


def _normalize_section_override(
    *,
    path: tuple[str, ...],
    descriptor_section: SectionDescriptor | None,
    expected_hash: JSONValue,
    body: JSONValue,
    strict: bool,
    path_display: str,
    body_error_message: str,
) -> SectionOverride | None:
    """Check one section override against the descriptor entry for its path.

    Args:
        path: Section path tuple (not consulted by the checks below).
        descriptor_section: Descriptor entry for the path, or ``None`` when
            the prompt has no such section.
        expected_hash: Raw hash stored with the override; normalized through
            ``ensure_hex_digest``.
        body: Raw override body; must be a string.
        strict: When true, unknown paths and hash mismatches raise; when
            false they are logged at debug level and ``None`` is returned.
        path_display: Human-readable path used in messages and log context.
        body_error_message: Message for the error raised on a non-string body.

    Returns:
        The normalized ``SectionOverride``, or ``None`` for an entry skipped
        in non-strict mode.

    Raises:
        PromptOverridesError: For unknown or stale entries when *strict*, and
            for a non-string body regardless of *strict*.
    """
    if descriptor_section is None:
        if not strict:
            _LOGGER.debug(
                "Skipping unknown override section path.",
                event="prompt_override_unknown_section",
                context={"path": path_display},
            )
            return None
        raise PromptOverridesError(
            f"Unknown section path for override: {path_display}"
        )

    digest = ensure_hex_digest(
        cast(HexDigest | str, expected_hash),
        field_name="Section expected_hash",
    )
    current_hash = descriptor_section.content_hash
    if digest == current_hash:
        if isinstance(body, str):
            return SectionOverride(expected_hash=digest, body=body)
        raise PromptOverridesError(body_error_message)

    # The stored hash no longer matches the descriptor's content hash.
    if not strict:
        _LOGGER.debug(
            "Skipping stale section override.",
            event="prompt_override_stale_section",
            context={
                "path": path_display,
                "expected_hash": str(current_hash),
                "found_hash": str(digest),
            },
        )
        return None
    raise PromptOverridesError(f"Hash mismatch for section {path_display}.")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@overload
def _normalize_tool_override(
    *,
    name: str,
    descriptor_tool: ToolDescriptor | None,
    expected_hash: JSONValue,
    description: JSONValue,
    param_descriptions: JSONValue,
    strict: Literal[True],
    description_error_message: str,
    param_mapping_error_message: str,
    param_entry_error_message: str,
) -> ToolOverride: ...


@overload
def _normalize_tool_override(
    *,
    name: str,
    descriptor_tool: ToolDescriptor | None,
    expected_hash: JSONValue,
    description: JSONValue,
    param_descriptions: JSONValue,
    strict: Literal[False],
    description_error_message: str,
    param_mapping_error_message: str,
    param_entry_error_message: str,
) -> ToolOverride | None: ...


def _normalize_tool_override(
    *,
    name: str,
    descriptor_tool: ToolDescriptor | None,
    expected_hash: JSONValue,
    description: JSONValue,
    param_descriptions: JSONValue,
    strict: bool,
    description_error_message: str,
    param_mapping_error_message: str,
    param_entry_error_message: str,
) -> ToolOverride | None:
    """Check one tool override against the descriptor entry for its tool.

    Args:
        name: Tool name; used as the override's name and in messages.
        descriptor_tool: Descriptor entry for the tool, or ``None`` when the
            prompt exposes no such tool.
        expected_hash: Raw contract hash stored with the override; normalized
            through ``ensure_hex_digest``.
        description: Optional replacement description (string or ``None``).
        param_descriptions: Optional mapping of parameter name to description
            string; ``None`` is treated as an empty mapping.
        strict: When true, unknown tools and hash mismatches raise; when false
            they are logged at debug level and ``None`` is returned.
        description_error_message: Message raised for a non-string description.
        param_mapping_error_message: Message raised when *param_descriptions*
            is not a mapping.
        param_entry_error_message: Message raised when a parameter entry has a
            non-string key or value.

    Returns:
        The normalized ``ToolOverride``, or ``None`` for an entry skipped in
        non-strict mode.

    Raises:
        PromptOverridesError: For unknown or stale tools when *strict*, and
            for malformed description / parameter data regardless of *strict*.
    """
    if descriptor_tool is None:
        if strict:
            raise PromptOverridesError(f"Unknown tool override: {name}")
        _LOGGER.debug(
            "Skipping unknown tool override.",
            event="prompt_override_unknown_tool",
            context={"tool": name},
        )
        return None
    expected_digest = ensure_hex_digest(
        cast(HexDigest | str, expected_hash),
        field_name="Tool expected_contract_hash",
    )
    if expected_digest != descriptor_tool.contract_hash:
        if strict:
            raise PromptOverridesError(f"Hash mismatch for tool override: {name}.")
        _LOGGER.debug(
            "Skipping stale tool override.",
            event="prompt_override_stale_tool",
            context={
                "tool": name,
                "expected_hash": str(descriptor_tool.contract_hash),
                "found_hash": str(expected_digest),
            },
        )
        return None
    if description is not None and not isinstance(description, str):
        raise PromptOverridesError(description_error_message)
    raw_params = {} if param_descriptions is None else param_descriptions
    if not isinstance(raw_params, Mapping):
        raise PromptOverridesError(param_mapping_error_message)
    normalized_params: dict[str, str] = {}
    for key, value in cast(Mapping[object, JSONValue], raw_params).items():
        # JSON-decoded payloads always have string keys, but in-memory
        # ToolOverride objects validated for writing may not; enforce both
        # halves of the "strings to strings" contract.
        if not isinstance(key, str) or not isinstance(value, str):
            raise PromptOverridesError(param_entry_error_message)
        normalized_params[key] = value
    return ToolOverride(
        name=name,
        expected_contract_hash=expected_digest,
        description=cast("str | None", description),
        param_descriptions=normalized_params,
    )
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def load_sections(
    payload: JSONValue | None,
    descriptor: PromptDescriptor,
) -> dict[tuple[str, ...], SectionOverride]:
    """Parse the ``sections`` object of an override file.

    Keys are slash-separated section paths (empty segments are ignored) and
    values are objects with ``expected_hash`` and ``body`` fields. Entries
    for unknown sections, or whose hash is stale, are dropped; structural
    problems raise ``PromptOverridesError``.

    Returns:
        Mapping of section path tuple to the accepted override.
    """
    if payload is None:
        return {}
    if not isinstance(payload, Mapping):
        raise PromptOverridesError("Sections payload must be a mapping.")
    if not payload:
        return {}
    raw_mapping = cast(Mapping[object, JSONValue], payload)
    entries = cast(Iterable[tuple[object, JSONValue]], raw_mapping.items())
    known_sections = _section_descriptor_index(descriptor)
    loaded: dict[tuple[str, ...], SectionOverride] = {}
    for raw_key, raw_section in entries:
        if not isinstance(raw_key, str):
            raise PromptOverridesError("Section keys must be strings.")
        section_path = tuple(filter(None, raw_key.split("/")))
        if not isinstance(raw_section, Mapping):
            raise PromptOverridesError("Section payload must be an object.")
        section_fields = cast(Mapping[str, JSONValue], raw_section)
        normalized = _normalize_section_override(
            path=section_path,
            descriptor_section=known_sections.get(section_path),
            expected_hash=section_fields.get("expected_hash"),
            body=section_fields.get("body"),
            strict=False,
            path_display=raw_key,
            body_error_message="Section body must be a string.",
        )
        if normalized is None:
            continue
        loaded[section_path] = normalized
    return loaded
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def filter_override_for_descriptor(
    descriptor: PromptDescriptor,
    override: PromptOverride,
) -> tuple[dict[tuple[str, ...], SectionOverride], dict[str, ToolOverride]]:
    """Keep only the parts of *override* that still apply to *descriptor*.

    If the override's namespace or prompt key differs from the descriptor,
    the whole override is discarded (logged at debug level). Otherwise each
    section and tool entry is re-checked in non-strict mode: unknown or stale
    entries are dropped, while malformed bodies or descriptions still raise
    ``PromptOverridesError``.

    Returns:
        A ``(sections, tools)`` pair containing the surviving overrides.
    """
    same_target = (
        override.ns == descriptor.ns and override.prompt_key == descriptor.key
    )
    if not same_target:
        _log_mismatched_override_metadata(descriptor, override)
        return {}, {}

    sections_by_path = _section_descriptor_index(descriptor)
    tools_by_name = _tool_descriptor_index(descriptor)

    kept_sections: dict[tuple[str, ...], SectionOverride] = {}
    for section_path, section_candidate in override.sections.items():
        kept_section = _normalize_section_override(
            path=section_path,
            descriptor_section=sections_by_path.get(section_path),
            expected_hash=section_candidate.expected_hash,
            body=section_candidate.body,
            strict=False,
            path_display=_format_section_path(section_path),
            body_error_message="Section override body must be a string.",
        )
        if kept_section is None:
            continue
        kept_sections[section_path] = kept_section

    kept_tools: dict[str, ToolOverride] = {}
    for tool_name, tool_candidate in override.tool_overrides.items():
        kept_tool = _normalize_tool_override(
            name=tool_name,
            descriptor_tool=tools_by_name.get(tool_name),
            expected_hash=tool_candidate.expected_contract_hash,
            description=tool_candidate.description,
            param_descriptions=tool_candidate.param_descriptions,
            strict=False,
            description_error_message="Tool description override must be a string when set.",
            param_mapping_error_message="Tool param_descriptions must be a mapping when provided.",
            param_entry_error_message="Tool param description entries must be strings.",
        )
        if kept_tool is None:
            continue
        kept_tools[tool_name] = kept_tool

    return kept_sections, kept_tools
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def load_tools(
    payload: JSONValue | None,
    descriptor: PromptDescriptor,
) -> dict[str, ToolOverride]:
    """Parse the ``tools`` object of an override file.

    Keys are tool names and values are objects with ``expected_contract_hash``,
    optional ``description`` and optional ``param_descriptions`` fields.
    Entries for unknown tools, or whose contract hash is stale, are dropped;
    structural problems raise ``PromptOverridesError``.

    Returns:
        Mapping of tool name to the accepted override.
    """
    if payload is None:
        return {}
    if not isinstance(payload, Mapping):
        raise PromptOverridesError("Tools payload must be a mapping.")
    if not payload:
        return {}
    raw_mapping = cast(Mapping[object, JSONValue], payload)
    entries = cast(Iterable[tuple[object, JSONValue]], raw_mapping.items())
    known_tools = _tool_descriptor_index(descriptor)
    loaded: dict[str, ToolOverride] = {}
    for raw_name, raw_tool in entries:
        if not isinstance(raw_name, str):
            raise PromptOverridesError("Tool names must be strings.")
        if not isinstance(raw_tool, Mapping):
            raise PromptOverridesError("Tool payload must be an object.")
        tool_fields = cast(Mapping[str, JSONValue], raw_tool)
        normalized = _normalize_tool_override(
            name=raw_name,
            descriptor_tool=known_tools.get(raw_name),
            expected_hash=tool_fields.get("expected_contract_hash"),
            description=tool_fields.get("description"),
            param_descriptions=tool_fields.get("param_descriptions"),
            strict=False,
            description_error_message="Tool description must be a string when set.",
            param_mapping_error_message="Tool param_descriptions must be a mapping when provided.",
            param_entry_error_message="Tool param description entries must be strings.",
        )
        if normalized is None:
            continue
        loaded[raw_name] = normalized
    return loaded
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def validate_sections_for_write(
    sections: Mapping[tuple[str, ...], SectionOverride],
    descriptor: PromptDescriptor,
) -> dict[tuple[str, ...], SectionOverride]:
    """Strictly validate section overrides before they are persisted.

    Every entry must target a known section with a matching content hash and
    carry a string body; otherwise ``PromptOverridesError`` is raised.

    Returns:
        A new mapping of path tuple to normalized ``SectionOverride``.
    """
    known_sections = _section_descriptor_index(descriptor)
    result: dict[tuple[str, ...], SectionOverride] = {}
    for section_path, candidate in sections.items():
        display = _format_section_path(section_path)
        result[section_path] = _normalize_section_override(
            path=section_path,
            descriptor_section=known_sections.get(section_path),
            expected_hash=candidate.expected_hash,
            body=candidate.body,
            strict=True,
            path_display=display,
            body_error_message=(
                f"Section override body must be a string for {display}."
            ),
        )
    return result
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def validate_tools_for_write(
    tools: Mapping[str, ToolOverride],
    descriptor: PromptDescriptor,
) -> dict[str, ToolOverride]:
    """Strictly validate tool overrides before they are persisted.

    Every entry must target a known tool with a matching contract hash and
    carry well-typed description / parameter data; otherwise
    ``PromptOverridesError`` is raised.

    Returns:
        A new mapping of tool name to normalized ``ToolOverride``.
    """
    if not tools:
        return {}
    known_tools = _tool_descriptor_index(descriptor)
    result: dict[str, ToolOverride] = {}
    for tool_name, candidate in tools.items():
        result[tool_name] = _normalize_tool_override(
            name=tool_name,
            descriptor_tool=known_tools.get(tool_name),
            expected_hash=candidate.expected_contract_hash,
            description=candidate.description,
            param_descriptions=candidate.param_descriptions,
            strict=True,
            description_error_message=(
                f"Tool description override must be a string for {tool_name}."
            ),
            param_mapping_error_message=(
                f"Tool parameter descriptions must be a mapping for {tool_name}."
            ),
            param_entry_error_message=(
                "Tool parameter descriptions must map strings to strings "
                f"for {tool_name}."
            ),
        )
    return result
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def serialize_sections(
    sections: Mapping[tuple[str, ...], SectionOverride],
) -> dict[str, dict[str, str]]:
    """Convert section overrides into their JSON-ready file representation.

    Each path tuple becomes a slash-joined key; the value holds the
    stringified ``expected_hash`` and the raw ``body``.
    """
    return {
        "/".join(section_path): {
            "expected_hash": str(entry.expected_hash),
            "body": entry.body,
        }
        for section_path, entry in sections.items()
    }
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def serialize_tools(
    tools: Mapping[str, ToolOverride],
) -> dict[str, dict[str, JSONValue]]:
    """Convert tool overrides into their JSON-ready file representation.

    The contract hash is stringified and ``param_descriptions`` is copied
    into a fresh ``dict`` so the output does not alias the override.
    """
    return {
        tool_name: {
            "expected_contract_hash": str(entry.expected_contract_hash),
            "description": entry.description,
            "param_descriptions": dict(entry.param_descriptions),
        }
        for tool_name, entry in tools.items()
    }
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def seed_sections(
    prompt: PromptLike,
    descriptor: PromptDescriptor,
) -> dict[tuple[str, ...], SectionOverride]:
    """Build baseline section overrides from *prompt*'s original templates.

    Every section listed in *descriptor* receives an override whose body is
    the live section's ``original_body_template()`` and whose expected hash
    is the descriptor's ``content_hash``.

    Raises:
        PromptOverridesError: If a descriptor section has no counterpart in
            ``prompt.sections`` or its original template is ``None``.
    """
    live_sections = {node.path: node.section for node in prompt.sections}
    result: dict[tuple[str, ...], SectionOverride] = {}
    for entry in descriptor.sections:
        live_section = live_sections.get(entry.path)
        if live_section is None:
            joined_path = "/".join(entry.path)
            raise PromptOverridesError(
                f"Prompt missing section for descriptor path {joined_path}."
            )
        body_template = live_section.original_body_template()
        if body_template is None:
            raise PromptOverridesError(
                "Cannot seed override for section without template."
            )
        result[entry.path] = SectionOverride(
            expected_hash=entry.content_hash,
            body=body_template,
        )
    return result
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def seed_tools(
    prompt: PromptLike,
    descriptor: PromptDescriptor,
) -> dict[str, ToolOverride]:
    """Build baseline tool overrides from the tools exposed by *prompt*.

    Every tool listed in *descriptor* receives an override carrying the live
    tool's description and the parameter descriptions found in its dataclass
    field metadata. When several sections expose the same tool name, the
    last one encountered is used.

    Raises:
        PromptOverridesError: If a descriptor tool is not exposed by any
            section of *prompt*.
    """
    if not descriptor.tools:
        return {}
    live_tools: dict[str, ToolContractProtocol] = {
        contract.name: contract
        for node in prompt.sections
        for contract in node.section.tools()
    }
    result: dict[str, ToolOverride] = {}
    for entry in descriptor.tools:
        live_tool = live_tools.get(entry.name)
        if live_tool is None:
            raise PromptOverridesError(
                f"Prompt missing tool for descriptor entry {entry.name}."
            )
        collected = _collect_param_descriptions(live_tool)
        result[entry.name] = ToolOverride(
            name=entry.name,
            expected_contract_hash=entry.contract_hash,
            description=live_tool.description,
            param_descriptions=collected,
        )
    return result
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _collect_param_descriptions(tool: ToolContractProtocol) -> dict[str, str]:
|
|
511
|
+
params_type = getattr(tool, "params_type", None)
|
|
512
|
+
if not isinstance(params_type, type) or not is_dataclass(params_type):
|
|
513
|
+
return {}
|
|
514
|
+
descriptions: dict[str, str] = {}
|
|
515
|
+
for field in fields(params_type):
|
|
516
|
+
description = field.metadata.get("description") if field.metadata else None
|
|
517
|
+
if isinstance(description, str) and description:
|
|
518
|
+
descriptions[field.name] = description
|
|
519
|
+
return descriptions
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
# Public API of this module, listed in sorted order.
__all__ = [
    "FORMAT_VERSION",
    "filter_override_for_descriptor",
    "load_sections",
    "load_tools",
    "seed_sections",
    "seed_tools",
    "serialize_sections",
    "serialize_tools",
    "validate_header",
    "validate_sections_for_write",
    "validate_tools_for_write",
]
|