weakincentives 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +26 -2
- weakincentives/adapters/__init__.py +6 -5
- weakincentives/adapters/core.py +7 -17
- weakincentives/adapters/litellm.py +594 -0
- weakincentives/adapters/openai.py +286 -57
- weakincentives/events.py +103 -0
- weakincentives/examples/__init__.py +67 -0
- weakincentives/examples/code_review_prompt.py +118 -0
- weakincentives/examples/code_review_session.py +171 -0
- weakincentives/examples/code_review_tools.py +376 -0
- weakincentives/{prompts → prompt}/__init__.py +6 -8
- weakincentives/{prompts → prompt}/_types.py +1 -1
- weakincentives/{prompts/text.py → prompt/markdown.py} +19 -9
- weakincentives/{prompts → prompt}/prompt.py +216 -66
- weakincentives/{prompts → prompt}/response_format.py +9 -6
- weakincentives/{prompts → prompt}/section.py +25 -4
- weakincentives/{prompts/structured.py → prompt/structured_output.py} +16 -5
- weakincentives/{prompts → prompt}/tool.py +6 -6
- weakincentives/prompt/versioning.py +144 -0
- weakincentives/serde/__init__.py +0 -14
- weakincentives/serde/dataclass_serde.py +3 -17
- weakincentives/session/__init__.py +31 -0
- weakincentives/session/reducers.py +60 -0
- weakincentives/session/selectors.py +45 -0
- weakincentives/session/session.py +168 -0
- weakincentives/tools/__init__.py +69 -0
- weakincentives/tools/errors.py +22 -0
- weakincentives/tools/planning.py +538 -0
- weakincentives/tools/vfs.py +590 -0
- weakincentives-0.3.0.dist-info/METADATA +231 -0
- weakincentives-0.3.0.dist-info/RECORD +35 -0
- weakincentives-0.2.0.dist-info/METADATA +0 -173
- weakincentives-0.2.0.dist-info/RECORD +0 -20
- /weakincentives/{prompts → prompt}/errors.py +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.2.0.dist-info → weakincentives-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -10,7 +10,7 @@
|
|
|
10
10
|
# See the License for the specific language governing permissions and
|
|
11
11
|
# limitations under the License.
|
|
12
12
|
|
|
13
|
-
"""Internal typing helpers for the
|
|
13
|
+
"""Internal typing helpers for the :mod:`weakincentives.prompt` package."""
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
@@ -23,32 +23,39 @@ from .errors import PromptRenderError
|
|
|
23
23
|
from .section import Section
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class
|
|
27
|
-
"""Render markdown
|
|
26
|
+
class MarkdownSection[ParamsT: SupportsDataclass](Section[ParamsT]):
|
|
27
|
+
"""Render markdown content using :class:`string.Template`."""
|
|
28
28
|
|
|
29
29
|
def __init__(
|
|
30
30
|
self,
|
|
31
31
|
*,
|
|
32
32
|
title: str,
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
template: str,
|
|
34
|
+
key: str,
|
|
35
|
+
default_params: ParamsT | None = None,
|
|
35
36
|
children: Sequence[object] | None = None,
|
|
36
37
|
enabled: Callable[[ParamsT], bool] | None = None,
|
|
37
38
|
tools: Sequence[object] | None = None,
|
|
38
39
|
) -> None:
|
|
40
|
+
self.template = template
|
|
39
41
|
super().__init__(
|
|
40
42
|
title=title,
|
|
41
|
-
|
|
43
|
+
key=key,
|
|
44
|
+
default_params=default_params,
|
|
42
45
|
children=children,
|
|
43
46
|
enabled=enabled,
|
|
44
47
|
tools=tools,
|
|
45
48
|
)
|
|
46
|
-
self.body = body
|
|
47
49
|
|
|
48
50
|
def render(self, params: ParamsT, depth: int) -> str:
|
|
51
|
+
return self.render_with_template(self.template, params, depth)
|
|
52
|
+
|
|
53
|
+
def render_with_template(
|
|
54
|
+
self, template_text: str, params: ParamsT, depth: int
|
|
55
|
+
) -> str:
|
|
49
56
|
heading_level = "#" * (depth + 2)
|
|
50
57
|
heading = f"{heading_level} {self.title.strip()}"
|
|
51
|
-
template = Template(textwrap.dedent(
|
|
58
|
+
template = Template(textwrap.dedent(template_text).strip())
|
|
52
59
|
try:
|
|
53
60
|
normalized_params = self._normalize_params(params)
|
|
54
61
|
rendered_body = template.substitute(normalized_params)
|
|
@@ -63,7 +70,7 @@ class TextSection[ParamsT: SupportsDataclass](Section[ParamsT]):
|
|
|
63
70
|
return heading
|
|
64
71
|
|
|
65
72
|
def placeholder_names(self) -> set[str]:
|
|
66
|
-
template = Template(textwrap.dedent(self.
|
|
73
|
+
template = Template(textwrap.dedent(self.template).strip())
|
|
67
74
|
placeholders: set[str] = set()
|
|
68
75
|
for match in template.pattern.finditer(template.template):
|
|
69
76
|
named = match.group("named")
|
|
@@ -85,5 +92,8 @@ class TextSection[ParamsT: SupportsDataclass](Section[ParamsT]):
|
|
|
85
92
|
|
|
86
93
|
return {field.name: getattr(params, field.name) for field in fields(params)}
|
|
87
94
|
|
|
95
|
+
def original_body_template(self) -> str:
|
|
96
|
+
return self.template
|
|
97
|
+
|
|
88
98
|
|
|
89
|
-
__all__ = ["
|
|
99
|
+
__all__ = ["MarkdownSection"]
|
|
@@ -12,8 +12,9 @@
|
|
|
12
12
|
|
|
13
13
|
from __future__ import annotations
|
|
14
14
|
|
|
15
|
-
from collections.abc import Callable, Iterator, Sequence
|
|
15
|
+
from collections.abc import Callable, Iterator, Mapping, Sequence
|
|
16
16
|
from dataclasses import dataclass, field, fields, is_dataclass, replace
|
|
17
|
+
from types import MappingProxyType
|
|
17
18
|
from typing import Any, ClassVar, Literal, cast, get_args, get_origin
|
|
18
19
|
|
|
19
20
|
from ._types import SupportsDataclass
|
|
@@ -25,19 +26,25 @@ from .errors import (
|
|
|
25
26
|
from .response_format import ResponseFormatParams, ResponseFormatSection
|
|
26
27
|
from .section import Section
|
|
27
28
|
from .tool import Tool
|
|
29
|
+
from .versioning import PromptVersionStore, ToolOverride
|
|
30
|
+
|
|
31
|
+
_EMPTY_TOOL_PARAM_DESCRIPTIONS: Mapping[str, Mapping[str, str]] = MappingProxyType({})
|
|
28
32
|
|
|
29
33
|
|
|
30
34
|
@dataclass(frozen=True, slots=True)
|
|
31
|
-
class RenderedPrompt[OutputT
|
|
35
|
+
class RenderedPrompt[OutputT]:
|
|
32
36
|
"""Rendered prompt text paired with structured output metadata."""
|
|
33
37
|
|
|
34
38
|
text: str
|
|
35
39
|
output_type: type[Any] | None
|
|
36
|
-
|
|
40
|
+
container: Literal["object", "array"] | None
|
|
37
41
|
allow_extra_keys: bool | None
|
|
38
42
|
_tools: tuple[Tool[SupportsDataclass, SupportsDataclass], ...] = field(
|
|
39
43
|
default_factory=tuple
|
|
40
44
|
)
|
|
45
|
+
_tool_param_descriptions: Mapping[str, Mapping[str, str]] = field(
|
|
46
|
+
default=_EMPTY_TOOL_PARAM_DESCRIPTIONS
|
|
47
|
+
)
|
|
41
48
|
|
|
42
49
|
def __str__(self) -> str: # pragma: no cover - convenience for logging
|
|
43
50
|
return self.text
|
|
@@ -48,6 +55,14 @@ class RenderedPrompt[OutputT = Any]:
|
|
|
48
55
|
|
|
49
56
|
return self._tools
|
|
50
57
|
|
|
58
|
+
@property
|
|
59
|
+
def tool_param_descriptions(
|
|
60
|
+
self,
|
|
61
|
+
) -> Mapping[str, Mapping[str, str]]:
|
|
62
|
+
"""Description patches keyed by tool name."""
|
|
63
|
+
|
|
64
|
+
return self._tool_param_descriptions
|
|
65
|
+
|
|
51
66
|
|
|
52
67
|
def _clone_dataclass(instance: SupportsDataclass) -> SupportsDataclass:
|
|
53
68
|
"""Return a shallow copy of the provided dataclass instance."""
|
|
@@ -64,7 +79,7 @@ def _format_specialization_argument(argument: object | None) -> str:
|
|
|
64
79
|
|
|
65
80
|
|
|
66
81
|
@dataclass(frozen=True, slots=True)
|
|
67
|
-
class
|
|
82
|
+
class SectionNode[ParamsT: SupportsDataclass]:
|
|
68
83
|
"""Flattened view of a section within a prompt."""
|
|
69
84
|
|
|
70
85
|
section: Section[ParamsT]
|
|
@@ -72,7 +87,7 @@ class PromptSectionNode[ParamsT: SupportsDataclass]:
|
|
|
72
87
|
path: SectionPath
|
|
73
88
|
|
|
74
89
|
|
|
75
|
-
class Prompt[OutputT
|
|
90
|
+
class Prompt[OutputT]:
|
|
76
91
|
"""Coordinate prompt sections and their parameter bindings."""
|
|
77
92
|
|
|
78
93
|
_output_container_spec: ClassVar[Literal["object", "array"] | None] = None
|
|
@@ -103,19 +118,29 @@ class Prompt[OutputT = Any]:
|
|
|
103
118
|
def __init__(
|
|
104
119
|
self,
|
|
105
120
|
*,
|
|
121
|
+
ns: str,
|
|
122
|
+
key: str,
|
|
106
123
|
name: str | None = None,
|
|
107
124
|
sections: Sequence[Section[Any]] | None = None,
|
|
108
125
|
inject_output_instructions: bool = True,
|
|
109
126
|
allow_extra_keys: bool = False,
|
|
110
127
|
) -> None:
|
|
128
|
+
stripped_ns = ns.strip()
|
|
129
|
+
if not stripped_ns:
|
|
130
|
+
raise PromptValidationError("Prompt namespace must be a non-empty string.")
|
|
131
|
+
stripped_key = key.strip()
|
|
132
|
+
if not stripped_key:
|
|
133
|
+
raise PromptValidationError("Prompt key must be a non-empty string.")
|
|
134
|
+
self.ns = stripped_ns
|
|
135
|
+
self.key = stripped_key
|
|
111
136
|
self.name = name
|
|
112
137
|
base_sections: list[Section[SupportsDataclass]] = [
|
|
113
138
|
cast(Section[SupportsDataclass], section) for section in sections or ()
|
|
114
139
|
]
|
|
115
140
|
self._sections: tuple[Section[SupportsDataclass], ...] = tuple(base_sections)
|
|
116
|
-
self._section_nodes: list[
|
|
141
|
+
self._section_nodes: list[SectionNode[SupportsDataclass]] = []
|
|
117
142
|
self._params_registry: dict[
|
|
118
|
-
type[SupportsDataclass], list[
|
|
143
|
+
type[SupportsDataclass], list[SectionNode[SupportsDataclass]]
|
|
119
144
|
] = {}
|
|
120
145
|
self._defaults_by_path: dict[SectionPath, SupportsDataclass] = {}
|
|
121
146
|
self._defaults_by_type: dict[type[SupportsDataclass], SupportsDataclass] = {}
|
|
@@ -134,7 +159,7 @@ class Prompt[OutputT = Any]:
|
|
|
134
159
|
self.inject_output_instructions = inject_output_instructions
|
|
135
160
|
|
|
136
161
|
for section in base_sections:
|
|
137
|
-
self._register_section(section, path=(section.
|
|
162
|
+
self._register_section(section, path=(section.key,), depth=0)
|
|
138
163
|
|
|
139
164
|
self._response_section: ResponseFormatSection | None = None
|
|
140
165
|
if self._output_type is not None and self._output_container is not None:
|
|
@@ -148,52 +173,68 @@ class Prompt[OutputT = Any]:
|
|
|
148
173
|
self._sections += (section_for_registry,)
|
|
149
174
|
self._register_section(
|
|
150
175
|
section_for_registry,
|
|
151
|
-
path=(response_section.
|
|
176
|
+
path=(response_section.key,),
|
|
152
177
|
depth=0,
|
|
153
178
|
)
|
|
154
179
|
|
|
155
|
-
def render(
|
|
180
|
+
def render(
|
|
181
|
+
self,
|
|
182
|
+
*params: SupportsDataclass,
|
|
183
|
+
inject_output_instructions: bool | None = None,
|
|
184
|
+
) -> RenderedPrompt[OutputT]:
|
|
156
185
|
"""Render the prompt using provided parameter dataclass instances."""
|
|
157
186
|
|
|
158
187
|
param_lookup = self._collect_param_lookup(params)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
params_type = node.section.params
|
|
164
|
-
try:
|
|
165
|
-
rendered = node.section.render(section_params, node.depth)
|
|
166
|
-
except PromptRenderError as error:
|
|
167
|
-
if error.section_path and error.dataclass_type:
|
|
168
|
-
raise
|
|
169
|
-
raise PromptRenderError(
|
|
170
|
-
error.message,
|
|
171
|
-
section_path=node.path,
|
|
172
|
-
dataclass_type=params_type,
|
|
173
|
-
placeholder=error.placeholder,
|
|
174
|
-
) from error
|
|
175
|
-
except Exception as error: # pragma: no cover - defensive guard
|
|
176
|
-
raise PromptRenderError(
|
|
177
|
-
"Section rendering failed.",
|
|
178
|
-
section_path=node.path,
|
|
179
|
-
dataclass_type=params_type,
|
|
180
|
-
) from error
|
|
181
|
-
|
|
182
|
-
section_tools = node.section.tools()
|
|
183
|
-
if section_tools:
|
|
184
|
-
collected_tools.extend(section_tools)
|
|
185
|
-
|
|
186
|
-
if rendered:
|
|
187
|
-
rendered_sections.append(rendered)
|
|
188
|
+
return self._render_internal(
|
|
189
|
+
param_lookup,
|
|
190
|
+
inject_output_instructions=inject_output_instructions,
|
|
191
|
+
)
|
|
188
192
|
|
|
189
|
-
|
|
193
|
+
def render_with_overrides(
|
|
194
|
+
self,
|
|
195
|
+
*params: SupportsDataclass,
|
|
196
|
+
version_store: PromptVersionStore,
|
|
197
|
+
tag: str = "latest",
|
|
198
|
+
inject_output_instructions: bool | None = None,
|
|
199
|
+
) -> RenderedPrompt[OutputT]:
|
|
200
|
+
"""Render the prompt using overrides supplied by a version store."""
|
|
201
|
+
|
|
202
|
+
from .versioning import PromptDescriptor
|
|
203
|
+
|
|
204
|
+
descriptor = PromptDescriptor.from_prompt(self)
|
|
205
|
+
override = version_store.resolve(descriptor=descriptor, tag=tag)
|
|
206
|
+
|
|
207
|
+
overrides: dict[SectionPath, str] = {}
|
|
208
|
+
tool_overrides: dict[str, ToolOverride] = {}
|
|
209
|
+
if (
|
|
210
|
+
override is not None
|
|
211
|
+
and override.ns == descriptor.ns
|
|
212
|
+
and override.prompt_key == descriptor.key
|
|
213
|
+
):
|
|
214
|
+
descriptor_index = {
|
|
215
|
+
section.path: section.content_hash for section in descriptor.sections
|
|
216
|
+
}
|
|
217
|
+
for path, body in override.overrides.items():
|
|
218
|
+
if path in descriptor_index:
|
|
219
|
+
overrides[path] = body
|
|
220
|
+
if override.tool_overrides:
|
|
221
|
+
descriptor_tool_index = {
|
|
222
|
+
tool.name: tool.contract_hash for tool in descriptor.tools
|
|
223
|
+
}
|
|
224
|
+
for name, tool_override in override.tool_overrides.items():
|
|
225
|
+
descriptor_hash = descriptor_tool_index.get(name)
|
|
226
|
+
if (
|
|
227
|
+
descriptor_hash is not None
|
|
228
|
+
and tool_override.expected_contract_hash == descriptor_hash
|
|
229
|
+
):
|
|
230
|
+
tool_overrides[name] = tool_override
|
|
190
231
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
232
|
+
param_lookup = self._collect_param_lookup(params)
|
|
233
|
+
return self._render_internal(
|
|
234
|
+
param_lookup,
|
|
235
|
+
overrides,
|
|
236
|
+
tool_overrides,
|
|
237
|
+
inject_output_instructions=inject_output_instructions,
|
|
197
238
|
)
|
|
198
239
|
|
|
199
240
|
def _register_section(
|
|
@@ -203,7 +244,7 @@ class Prompt[OutputT = Any]:
|
|
|
203
244
|
path: SectionPath,
|
|
204
245
|
depth: int,
|
|
205
246
|
) -> None:
|
|
206
|
-
params_type = section.
|
|
247
|
+
params_type = section.param_type
|
|
207
248
|
if not is_dataclass(params_type):
|
|
208
249
|
raise PromptValidationError(
|
|
209
250
|
"Section params must be a dataclass.",
|
|
@@ -211,14 +252,14 @@ class Prompt[OutputT = Any]:
|
|
|
211
252
|
dataclass_type=params_type,
|
|
212
253
|
)
|
|
213
254
|
|
|
214
|
-
node:
|
|
255
|
+
node: SectionNode[SupportsDataclass] = SectionNode(
|
|
215
256
|
section=section, depth=depth, path=path
|
|
216
257
|
)
|
|
217
258
|
self._section_nodes.append(node)
|
|
218
259
|
self._params_registry.setdefault(params_type, []).append(node)
|
|
219
260
|
|
|
220
|
-
if section.
|
|
221
|
-
default_value = section.
|
|
261
|
+
if section.default_params is not None:
|
|
262
|
+
default_value = section.default_params
|
|
222
263
|
if isinstance(default_value, type) or not is_dataclass(default_value):
|
|
223
264
|
raise PromptValidationError(
|
|
224
265
|
"Section defaults must be dataclass instances.",
|
|
@@ -250,15 +291,15 @@ class Prompt[OutputT = Any]:
|
|
|
250
291
|
self._register_section_tools(section, path)
|
|
251
292
|
|
|
252
293
|
for child in section.children:
|
|
253
|
-
child_path = path + (child.
|
|
294
|
+
child_path = path + (child.key,)
|
|
254
295
|
self._register_section(child, path=child_path, depth=depth + 1)
|
|
255
296
|
|
|
256
297
|
@property
|
|
257
|
-
def sections(self) -> tuple[
|
|
298
|
+
def sections(self) -> tuple[SectionNode[SupportsDataclass], ...]:
|
|
258
299
|
return tuple(self._section_nodes)
|
|
259
300
|
|
|
260
301
|
@property
|
|
261
|
-
def
|
|
302
|
+
def param_types(self) -> set[type[SupportsDataclass]]:
|
|
262
303
|
return set(self._params_registry.keys())
|
|
263
304
|
|
|
264
305
|
def _resolve_output_spec(
|
|
@@ -335,12 +376,103 @@ class Prompt[OutputT = Any]:
|
|
|
335
376
|
lookup[params_type] = value
|
|
336
377
|
return lookup
|
|
337
378
|
|
|
379
|
+
def _render_internal(
|
|
380
|
+
self,
|
|
381
|
+
param_lookup: Mapping[type[SupportsDataclass], SupportsDataclass],
|
|
382
|
+
overrides: Mapping[SectionPath, str] | None = None,
|
|
383
|
+
tool_overrides: Mapping[str, ToolOverride] | None = None,
|
|
384
|
+
*,
|
|
385
|
+
inject_output_instructions: bool | None = None,
|
|
386
|
+
) -> RenderedPrompt[OutputT]:
|
|
387
|
+
rendered_sections: list[str] = []
|
|
388
|
+
collected_tools: list[Tool[SupportsDataclass, SupportsDataclass]] = []
|
|
389
|
+
override_lookup = dict(overrides or {})
|
|
390
|
+
tool_override_lookup = dict(tool_overrides or {})
|
|
391
|
+
field_description_patches: dict[str, dict[str, str]] = {}
|
|
392
|
+
|
|
393
|
+
for node, section_params in self._iter_enabled_sections(
|
|
394
|
+
dict(param_lookup),
|
|
395
|
+
inject_output_instructions=inject_output_instructions,
|
|
396
|
+
):
|
|
397
|
+
override_body = override_lookup.get(node.path)
|
|
398
|
+
rendered = self._render_section(node, section_params, override_body)
|
|
399
|
+
|
|
400
|
+
section_tools = node.section.tools()
|
|
401
|
+
if section_tools:
|
|
402
|
+
for tool in section_tools:
|
|
403
|
+
override = tool_override_lookup.get(tool.name)
|
|
404
|
+
patched_tool = tool
|
|
405
|
+
if override is not None:
|
|
406
|
+
if (
|
|
407
|
+
override.description is not None
|
|
408
|
+
and override.description != tool.description
|
|
409
|
+
):
|
|
410
|
+
patched_tool = replace(
|
|
411
|
+
tool, description=override.description
|
|
412
|
+
)
|
|
413
|
+
if override.param_descriptions:
|
|
414
|
+
field_description_patches[tool.name] = dict(
|
|
415
|
+
override.param_descriptions
|
|
416
|
+
)
|
|
417
|
+
collected_tools.append(patched_tool)
|
|
418
|
+
|
|
419
|
+
if rendered:
|
|
420
|
+
rendered_sections.append(rendered)
|
|
421
|
+
|
|
422
|
+
text = "\n\n".join(rendered_sections)
|
|
423
|
+
|
|
424
|
+
return RenderedPrompt(
|
|
425
|
+
text=text,
|
|
426
|
+
output_type=self._output_type,
|
|
427
|
+
container=self._output_container,
|
|
428
|
+
allow_extra_keys=self._allow_extra_keys,
|
|
429
|
+
_tools=tuple(collected_tools),
|
|
430
|
+
_tool_param_descriptions=_freeze_tool_param_descriptions(
|
|
431
|
+
field_description_patches
|
|
432
|
+
),
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
def _render_section(
|
|
436
|
+
self,
|
|
437
|
+
node: SectionNode[SupportsDataclass],
|
|
438
|
+
section_params: SupportsDataclass,
|
|
439
|
+
override_body: str | None,
|
|
440
|
+
) -> str:
|
|
441
|
+
params_type = node.section.param_type
|
|
442
|
+
try:
|
|
443
|
+
render_override = getattr(node.section, "render_with_template", None)
|
|
444
|
+
if override_body is not None and callable(render_override):
|
|
445
|
+
override_renderer = cast(
|
|
446
|
+
Callable[[str, SupportsDataclass, int], str],
|
|
447
|
+
render_override,
|
|
448
|
+
)
|
|
449
|
+
rendered = override_renderer(override_body, section_params, node.depth)
|
|
450
|
+
else:
|
|
451
|
+
rendered = node.section.render(section_params, node.depth)
|
|
452
|
+
except PromptRenderError as error:
|
|
453
|
+
if error.section_path and error.dataclass_type:
|
|
454
|
+
raise
|
|
455
|
+
raise PromptRenderError(
|
|
456
|
+
error.message,
|
|
457
|
+
section_path=node.path,
|
|
458
|
+
dataclass_type=params_type,
|
|
459
|
+
placeholder=error.placeholder,
|
|
460
|
+
) from error
|
|
461
|
+
except Exception as error: # pragma: no cover - defensive guard
|
|
462
|
+
raise PromptRenderError(
|
|
463
|
+
"Section rendering failed.",
|
|
464
|
+
section_path=node.path,
|
|
465
|
+
dataclass_type=params_type,
|
|
466
|
+
) from error
|
|
467
|
+
|
|
468
|
+
return rendered
|
|
469
|
+
|
|
338
470
|
def _resolve_section_params(
|
|
339
471
|
self,
|
|
340
|
-
node:
|
|
472
|
+
node: SectionNode[SupportsDataclass],
|
|
341
473
|
param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
|
|
342
474
|
) -> SupportsDataclass:
|
|
343
|
-
params_type = node.section.
|
|
475
|
+
params_type = node.section.param_type
|
|
344
476
|
section_params: SupportsDataclass | None = param_lookup.get(params_type)
|
|
345
477
|
|
|
346
478
|
if section_params is None:
|
|
@@ -367,7 +499,9 @@ class Prompt[OutputT = Any]:
|
|
|
367
499
|
def _iter_enabled_sections(
|
|
368
500
|
self,
|
|
369
501
|
param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
|
|
370
|
-
|
|
502
|
+
*,
|
|
503
|
+
inject_output_instructions: bool | None = None,
|
|
504
|
+
) -> Iterator[tuple[SectionNode[SupportsDataclass], SupportsDataclass]]:
|
|
371
505
|
skip_depth: int | None = None
|
|
372
506
|
|
|
373
507
|
for node in self._section_nodes:
|
|
@@ -378,14 +512,19 @@ class Prompt[OutputT = Any]:
|
|
|
378
512
|
|
|
379
513
|
section_params = self._resolve_section_params(node, param_lookup)
|
|
380
514
|
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
515
|
+
if node.section is self._response_section and (
|
|
516
|
+
inject_output_instructions is not None
|
|
517
|
+
):
|
|
518
|
+
enabled = inject_output_instructions
|
|
519
|
+
else:
|
|
520
|
+
try:
|
|
521
|
+
enabled = node.section.is_enabled(section_params)
|
|
522
|
+
except Exception as error: # pragma: no cover - defensive guard
|
|
523
|
+
raise PromptRenderError(
|
|
524
|
+
"Section enabled predicate failed.",
|
|
525
|
+
section_path=node.path,
|
|
526
|
+
dataclass_type=node.section.param_type,
|
|
527
|
+
) from error
|
|
389
528
|
|
|
390
529
|
if not enabled:
|
|
391
530
|
skip_depth = node.depth
|
|
@@ -408,7 +547,7 @@ class Prompt[OutputT = Any]:
|
|
|
408
547
|
raise PromptValidationError(
|
|
409
548
|
"Section tools() must return Tool instances.",
|
|
410
549
|
section_path=path,
|
|
411
|
-
dataclass_type=section.
|
|
550
|
+
dataclass_type=section.param_type,
|
|
412
551
|
)
|
|
413
552
|
tool: Tool[SupportsDataclass, SupportsDataclass] = cast(
|
|
414
553
|
Tool[SupportsDataclass, SupportsDataclass], tool_candidate
|
|
@@ -437,4 +576,15 @@ class Prompt[OutputT = Any]:
|
|
|
437
576
|
self._tool_name_registry[tool.name] = path
|
|
438
577
|
|
|
439
578
|
|
|
440
|
-
__all__ = ["Prompt", "
|
|
579
|
+
__all__ = ["Prompt", "RenderedPrompt", "SectionNode"]
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def _freeze_tool_param_descriptions(
|
|
583
|
+
descriptions: Mapping[str, dict[str, str]],
|
|
584
|
+
) -> Mapping[str, Mapping[str, str]]:
|
|
585
|
+
if not descriptions:
|
|
586
|
+
return MappingProxyType({})
|
|
587
|
+
frozen: dict[str, Mapping[str, str]] = {}
|
|
588
|
+
for name, field_mapping in descriptions.items():
|
|
589
|
+
frozen[name] = MappingProxyType(dict(field_mapping))
|
|
590
|
+
return MappingProxyType(frozen)
|
|
@@ -14,9 +14,9 @@ from __future__ import annotations
|
|
|
14
14
|
|
|
15
15
|
from collections.abc import Callable
|
|
16
16
|
from dataclasses import dataclass
|
|
17
|
-
from typing import Literal
|
|
17
|
+
from typing import Final, Literal
|
|
18
18
|
|
|
19
|
-
from .
|
|
19
|
+
from .markdown import MarkdownSection
|
|
20
20
|
|
|
21
21
|
__all__ = ["ResponseFormatParams", "ResponseFormatSection"]
|
|
22
22
|
|
|
@@ -30,14 +30,16 @@ class ResponseFormatParams:
|
|
|
30
30
|
extra_clause: str
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
_RESPONSE_FORMAT_BODY
|
|
33
|
+
_RESPONSE_FORMAT_BODY: Final[
|
|
34
|
+
str
|
|
35
|
+
] = """Return ONLY a single fenced JSON code block. Do not include any text
|
|
34
36
|
before or after the block.
|
|
35
37
|
|
|
36
38
|
The top-level JSON value MUST be ${article} ${container} that matches the fields
|
|
37
39
|
of the expected schema${extra_clause}"""
|
|
38
40
|
|
|
39
41
|
|
|
40
|
-
class ResponseFormatSection(
|
|
42
|
+
class ResponseFormatSection(MarkdownSection[ResponseFormatParams]):
|
|
41
43
|
"""Internal section that appends JSON-only response instructions."""
|
|
42
44
|
|
|
43
45
|
def __init__(
|
|
@@ -48,7 +50,8 @@ class ResponseFormatSection(TextSection[ResponseFormatParams]):
|
|
|
48
50
|
) -> None:
|
|
49
51
|
super().__init__(
|
|
50
52
|
title="Response Format",
|
|
51
|
-
|
|
52
|
-
|
|
53
|
+
key="response-format",
|
|
54
|
+
template=_RESPONSE_FORMAT_BODY,
|
|
55
|
+
default_params=params,
|
|
53
56
|
enabled=enabled,
|
|
54
57
|
)
|
|
@@ -12,15 +12,20 @@
|
|
|
12
12
|
|
|
13
13
|
from __future__ import annotations
|
|
14
14
|
|
|
15
|
+
import re
|
|
15
16
|
from abc import ABC, abstractmethod
|
|
16
17
|
from collections.abc import Callable, Sequence
|
|
17
|
-
from typing import TYPE_CHECKING, ClassVar, cast
|
|
18
|
+
from typing import TYPE_CHECKING, ClassVar, Final, cast
|
|
18
19
|
|
|
19
20
|
if TYPE_CHECKING:
|
|
20
21
|
from .tool import Tool
|
|
21
22
|
|
|
22
23
|
from ._types import SupportsDataclass
|
|
23
24
|
|
|
25
|
+
_SECTION_KEY_PATTERN: Final[re.Pattern[str]] = re.compile(
|
|
26
|
+
r"^[a-z0-9][a-z0-9._-]{0,63}$"
|
|
27
|
+
)
|
|
28
|
+
|
|
24
29
|
|
|
25
30
|
class Section[ParamsT: SupportsDataclass](ABC):
|
|
26
31
|
"""Abstract building block for prompt content."""
|
|
@@ -31,7 +36,8 @@ class Section[ParamsT: SupportsDataclass](ABC):
|
|
|
31
36
|
self,
|
|
32
37
|
*,
|
|
33
38
|
title: str,
|
|
34
|
-
|
|
39
|
+
key: str,
|
|
40
|
+
default_params: ParamsT | None = None,
|
|
35
41
|
children: Sequence[object] | None = None,
|
|
36
42
|
enabled: Callable[[ParamsT], bool] | None = None,
|
|
37
43
|
tools: Sequence[object] | None = None,
|
|
@@ -45,9 +51,10 @@ class Section[ParamsT: SupportsDataclass](ABC):
|
|
|
45
51
|
)
|
|
46
52
|
|
|
47
53
|
self.params_type: type[ParamsT] = params_type
|
|
48
|
-
self.
|
|
54
|
+
self.param_type: type[ParamsT] = params_type
|
|
49
55
|
self.title = title
|
|
50
|
-
self.
|
|
56
|
+
self.key = self._normalize_key(key)
|
|
57
|
+
self.default_params = default_params
|
|
51
58
|
|
|
52
59
|
normalized_children: list[Section[SupportsDataclass]] = []
|
|
53
60
|
for child in children or ():
|
|
@@ -81,6 +88,11 @@ class Section[ParamsT: SupportsDataclass](ABC):
|
|
|
81
88
|
|
|
82
89
|
return self._tools
|
|
83
90
|
|
|
91
|
+
def original_body_template(self) -> str | None:
|
|
92
|
+
"""Return the template text that participates in hashing, when available."""
|
|
93
|
+
|
|
94
|
+
return None
|
|
95
|
+
|
|
84
96
|
@classmethod
|
|
85
97
|
def __class_getitem__(cls, item: object) -> type[Section[SupportsDataclass]]:
|
|
86
98
|
params_type = cls._normalize_generic_argument(item)
|
|
@@ -94,6 +106,15 @@ class Section[ParamsT: SupportsDataclass](ABC):
|
|
|
94
106
|
_SpecializedSection._params_type = cast(type[SupportsDataclass], params_type)
|
|
95
107
|
return _SpecializedSection # type: ignore[return-value]
|
|
96
108
|
|
|
109
|
+
@staticmethod
|
|
110
|
+
def _normalize_key(key: str) -> str:
|
|
111
|
+
normalized = key.strip().lower()
|
|
112
|
+
if not normalized:
|
|
113
|
+
raise ValueError("Section key must be a non-empty string.")
|
|
114
|
+
if not _SECTION_KEY_PATTERN.match(normalized):
|
|
115
|
+
raise ValueError("Section key must match ^[a-z0-9][a-z0-9._-]{0,63}$.")
|
|
116
|
+
return normalized
|
|
117
|
+
|
|
97
118
|
@staticmethod
|
|
98
119
|
def _normalize_generic_argument(item: object) -> object:
|
|
99
120
|
if isinstance(item, tuple):
|
|
@@ -15,14 +15,18 @@ from __future__ import annotations
|
|
|
15
15
|
import json
|
|
16
16
|
import re
|
|
17
17
|
from collections.abc import Mapping
|
|
18
|
-
from typing import Any, Literal, cast
|
|
18
|
+
from typing import Any, Final, Literal, cast
|
|
19
19
|
|
|
20
20
|
from ..serde.dataclass_serde import parse as parse_dataclass
|
|
21
21
|
from .prompt import RenderedPrompt
|
|
22
22
|
|
|
23
|
-
__all__ = ["OutputParseError", "
|
|
23
|
+
__all__ = ["ARRAY_WRAPPER_KEY", "OutputParseError", "parse_structured_output"]
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
ARRAY_WRAPPER_KEY: Final[str] = "items"
|
|
26
|
+
|
|
27
|
+
_JSON_FENCE_PATTERN: Final[re.Pattern[str]] = re.compile(
|
|
28
|
+
r"```json\s*\n(.*?)```", re.IGNORECASE | re.DOTALL
|
|
29
|
+
)
|
|
26
30
|
|
|
27
31
|
|
|
28
32
|
class OutputParseError(Exception):
|
|
@@ -39,13 +43,13 @@ class OutputParseError(Exception):
|
|
|
39
43
|
self.dataclass_type = dataclass_type
|
|
40
44
|
|
|
41
45
|
|
|
42
|
-
def
|
|
46
|
+
def parse_structured_output[PayloadT](
|
|
43
47
|
output_text: str, rendered: RenderedPrompt[PayloadT]
|
|
44
48
|
) -> PayloadT:
|
|
45
49
|
"""Parse a model response into the structured output type declared by the prompt."""
|
|
46
50
|
|
|
47
51
|
dataclass_type = rendered.output_type
|
|
48
|
-
container = rendered.
|
|
52
|
+
container = rendered.container
|
|
49
53
|
allow_extra_keys = rendered.allow_extra_keys
|
|
50
54
|
|
|
51
55
|
if dataclass_type is None or container is None:
|
|
@@ -72,6 +76,13 @@ def parse_output[PayloadT](
|
|
|
72
76
|
return cast(PayloadT, parsed)
|
|
73
77
|
|
|
74
78
|
if container == "array":
|
|
79
|
+
if isinstance(payload, Mapping):
|
|
80
|
+
if ARRAY_WRAPPER_KEY not in payload:
|
|
81
|
+
raise OutputParseError(
|
|
82
|
+
"Expected top-level JSON array.",
|
|
83
|
+
dataclass_type=dataclass_type,
|
|
84
|
+
)
|
|
85
|
+
payload = cast(Mapping[str, object], payload)[ARRAY_WRAPPER_KEY]
|
|
75
86
|
if not isinstance(payload, list):
|
|
76
87
|
raise OutputParseError(
|
|
77
88
|
"Expected top-level JSON array.",
|