weakincentives 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of weakincentives might be problematic. Click here for more details.
- weakincentives/__init__.py +15 -2
- weakincentives/adapters/__init__.py +30 -0
- weakincentives/adapters/core.py +85 -0
- weakincentives/adapters/openai.py +361 -0
- weakincentives/prompts/__init__.py +45 -0
- weakincentives/prompts/_types.py +27 -0
- weakincentives/prompts/errors.py +57 -0
- weakincentives/prompts/prompt.py +440 -0
- weakincentives/prompts/response_format.py +54 -0
- weakincentives/prompts/section.py +120 -0
- weakincentives/prompts/structured.py +140 -0
- weakincentives/prompts/text.py +89 -0
- weakincentives/prompts/tool.py +236 -0
- weakincentives/serde/__init__.py +31 -0
- weakincentives/serde/dataclass_serde.py +1016 -0
- weakincentives-0.2.0.dist-info/METADATA +173 -0
- weakincentives-0.2.0.dist-info/RECORD +20 -0
- weakincentives-0.1.0.dist-info/METADATA +0 -21
- weakincentives-0.1.0.dist-info/RECORD +0 -6
- {weakincentives-0.1.0.dist-info → weakincentives-0.2.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.1.0.dist-info → weakincentives-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
16
|
+
from dataclasses import dataclass, field, fields, is_dataclass, replace
|
|
17
|
+
from typing import Any, ClassVar, Literal, cast, get_args, get_origin
|
|
18
|
+
|
|
19
|
+
from ._types import SupportsDataclass
|
|
20
|
+
from .errors import (
|
|
21
|
+
PromptRenderError,
|
|
22
|
+
PromptValidationError,
|
|
23
|
+
SectionPath,
|
|
24
|
+
)
|
|
25
|
+
from .response_format import ResponseFormatParams, ResponseFormatSection
|
|
26
|
+
from .section import Section
|
|
27
|
+
from .tool import Tool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, slots=True)
|
|
31
|
+
class RenderedPrompt[OutputT = Any]:
|
|
32
|
+
"""Rendered prompt text paired with structured output metadata."""
|
|
33
|
+
|
|
34
|
+
text: str
|
|
35
|
+
output_type: type[Any] | None
|
|
36
|
+
output_container: Literal["object", "array"] | None
|
|
37
|
+
allow_extra_keys: bool | None
|
|
38
|
+
_tools: tuple[Tool[SupportsDataclass, SupportsDataclass], ...] = field(
|
|
39
|
+
default_factory=tuple
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def __str__(self) -> str: # pragma: no cover - convenience for logging
|
|
43
|
+
return self.text
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def tools(self) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
|
|
47
|
+
"""Tools contributed by enabled sections in traversal order."""
|
|
48
|
+
|
|
49
|
+
return self._tools
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _clone_dataclass(instance: SupportsDataclass) -> SupportsDataclass:
|
|
53
|
+
"""Return a shallow copy of the provided dataclass instance."""
|
|
54
|
+
|
|
55
|
+
return cast(SupportsDataclass, replace(cast(Any, instance)))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _format_specialization_argument(argument: object | None) -> str:
|
|
59
|
+
if argument is None: # pragma: no cover - defensive formatting
|
|
60
|
+
return "?"
|
|
61
|
+
if isinstance(argument, type):
|
|
62
|
+
return argument.__name__
|
|
63
|
+
return repr(argument) # pragma: no cover - fallback for debugging
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(frozen=True, slots=True)
|
|
67
|
+
class PromptSectionNode[ParamsT: SupportsDataclass]:
|
|
68
|
+
"""Flattened view of a section within a prompt."""
|
|
69
|
+
|
|
70
|
+
section: Section[ParamsT]
|
|
71
|
+
depth: int
|
|
72
|
+
path: SectionPath
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Prompt[OutputT = Any]:
|
|
76
|
+
"""Coordinate prompt sections and their parameter bindings."""
|
|
77
|
+
|
|
78
|
+
_output_container_spec: ClassVar[Literal["object", "array"] | None] = None
|
|
79
|
+
_output_dataclass_candidate: ClassVar[Any] = None
|
|
80
|
+
|
|
81
|
+
def __class_getitem__(cls, item: object) -> type[Prompt[Any]]:
|
|
82
|
+
origin = get_origin(item)
|
|
83
|
+
candidate = item
|
|
84
|
+
container: Literal["object", "array"] | None = "object"
|
|
85
|
+
|
|
86
|
+
if origin is list:
|
|
87
|
+
args = get_args(item)
|
|
88
|
+
candidate = args[0] if len(args) == 1 else None
|
|
89
|
+
container = "array"
|
|
90
|
+
label = f"list[{_format_specialization_argument(candidate)}]"
|
|
91
|
+
else:
|
|
92
|
+
container = "object"
|
|
93
|
+
label = _format_specialization_argument(candidate)
|
|
94
|
+
|
|
95
|
+
name = f"{cls.__name__}[{label}]"
|
|
96
|
+
namespace = {
|
|
97
|
+
"__module__": cls.__module__,
|
|
98
|
+
"_output_container_spec": container if candidate is not None else None,
|
|
99
|
+
"_output_dataclass_candidate": candidate,
|
|
100
|
+
}
|
|
101
|
+
return type(name, (cls,), namespace)
|
|
102
|
+
|
|
103
|
+
def __init__(
|
|
104
|
+
self,
|
|
105
|
+
*,
|
|
106
|
+
name: str | None = None,
|
|
107
|
+
sections: Sequence[Section[Any]] | None = None,
|
|
108
|
+
inject_output_instructions: bool = True,
|
|
109
|
+
allow_extra_keys: bool = False,
|
|
110
|
+
) -> None:
|
|
111
|
+
self.name = name
|
|
112
|
+
base_sections: list[Section[SupportsDataclass]] = [
|
|
113
|
+
cast(Section[SupportsDataclass], section) for section in sections or ()
|
|
114
|
+
]
|
|
115
|
+
self._sections: tuple[Section[SupportsDataclass], ...] = tuple(base_sections)
|
|
116
|
+
self._section_nodes: list[PromptSectionNode[SupportsDataclass]] = []
|
|
117
|
+
self._params_registry: dict[
|
|
118
|
+
type[SupportsDataclass], list[PromptSectionNode[SupportsDataclass]]
|
|
119
|
+
] = {}
|
|
120
|
+
self._defaults_by_path: dict[SectionPath, SupportsDataclass] = {}
|
|
121
|
+
self._defaults_by_type: dict[type[SupportsDataclass], SupportsDataclass] = {}
|
|
122
|
+
self.placeholders: dict[SectionPath, set[str]] = {}
|
|
123
|
+
self._tool_name_registry: dict[str, SectionPath] = {}
|
|
124
|
+
|
|
125
|
+
self._output_type: type[Any] | None
|
|
126
|
+
self._output_container: Literal["object", "array"] | None
|
|
127
|
+
self._allow_extra_keys: bool | None
|
|
128
|
+
(
|
|
129
|
+
self._output_type,
|
|
130
|
+
self._output_container,
|
|
131
|
+
self._allow_extra_keys,
|
|
132
|
+
) = self._resolve_output_spec(allow_extra_keys)
|
|
133
|
+
|
|
134
|
+
self.inject_output_instructions = inject_output_instructions
|
|
135
|
+
|
|
136
|
+
for section in base_sections:
|
|
137
|
+
self._register_section(section, path=(section.title,), depth=0)
|
|
138
|
+
|
|
139
|
+
self._response_section: ResponseFormatSection | None = None
|
|
140
|
+
if self._output_type is not None and self._output_container is not None:
|
|
141
|
+
response_params = self._build_response_format_params()
|
|
142
|
+
response_section = ResponseFormatSection(
|
|
143
|
+
params=response_params,
|
|
144
|
+
enabled=lambda _params, prompt=self: prompt.inject_output_instructions,
|
|
145
|
+
)
|
|
146
|
+
self._response_section = response_section
|
|
147
|
+
section_for_registry = cast(Section[SupportsDataclass], response_section)
|
|
148
|
+
self._sections += (section_for_registry,)
|
|
149
|
+
self._register_section(
|
|
150
|
+
section_for_registry,
|
|
151
|
+
path=(response_section.title,),
|
|
152
|
+
depth=0,
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
def render(self, *params: SupportsDataclass) -> RenderedPrompt[OutputT]:
|
|
156
|
+
"""Render the prompt using provided parameter dataclass instances."""
|
|
157
|
+
|
|
158
|
+
param_lookup = self._collect_param_lookup(params)
|
|
159
|
+
rendered_sections: list[str] = []
|
|
160
|
+
collected_tools: list[Tool[SupportsDataclass, SupportsDataclass]] = []
|
|
161
|
+
|
|
162
|
+
for node, section_params in self._iter_enabled_sections(param_lookup):
|
|
163
|
+
params_type = node.section.params
|
|
164
|
+
try:
|
|
165
|
+
rendered = node.section.render(section_params, node.depth)
|
|
166
|
+
except PromptRenderError as error:
|
|
167
|
+
if error.section_path and error.dataclass_type:
|
|
168
|
+
raise
|
|
169
|
+
raise PromptRenderError(
|
|
170
|
+
error.message,
|
|
171
|
+
section_path=node.path,
|
|
172
|
+
dataclass_type=params_type,
|
|
173
|
+
placeholder=error.placeholder,
|
|
174
|
+
) from error
|
|
175
|
+
except Exception as error: # pragma: no cover - defensive guard
|
|
176
|
+
raise PromptRenderError(
|
|
177
|
+
"Section rendering failed.",
|
|
178
|
+
section_path=node.path,
|
|
179
|
+
dataclass_type=params_type,
|
|
180
|
+
) from error
|
|
181
|
+
|
|
182
|
+
section_tools = node.section.tools()
|
|
183
|
+
if section_tools:
|
|
184
|
+
collected_tools.extend(section_tools)
|
|
185
|
+
|
|
186
|
+
if rendered:
|
|
187
|
+
rendered_sections.append(rendered)
|
|
188
|
+
|
|
189
|
+
text = "\n\n".join(rendered_sections)
|
|
190
|
+
|
|
191
|
+
return RenderedPrompt(
|
|
192
|
+
text=text,
|
|
193
|
+
output_type=self._output_type,
|
|
194
|
+
output_container=self._output_container,
|
|
195
|
+
allow_extra_keys=self._allow_extra_keys,
|
|
196
|
+
_tools=tuple(collected_tools),
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
def _register_section(
|
|
200
|
+
self,
|
|
201
|
+
section: Section[SupportsDataclass],
|
|
202
|
+
*,
|
|
203
|
+
path: SectionPath,
|
|
204
|
+
depth: int,
|
|
205
|
+
) -> None:
|
|
206
|
+
params_type = section.params
|
|
207
|
+
if not is_dataclass(params_type):
|
|
208
|
+
raise PromptValidationError(
|
|
209
|
+
"Section params must be a dataclass.",
|
|
210
|
+
section_path=path,
|
|
211
|
+
dataclass_type=params_type,
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
node: PromptSectionNode[SupportsDataclass] = PromptSectionNode(
|
|
215
|
+
section=section, depth=depth, path=path
|
|
216
|
+
)
|
|
217
|
+
self._section_nodes.append(node)
|
|
218
|
+
self._params_registry.setdefault(params_type, []).append(node)
|
|
219
|
+
|
|
220
|
+
if section.defaults is not None:
|
|
221
|
+
default_value = section.defaults
|
|
222
|
+
if isinstance(default_value, type) or not is_dataclass(default_value):
|
|
223
|
+
raise PromptValidationError(
|
|
224
|
+
"Section defaults must be dataclass instances.",
|
|
225
|
+
section_path=path,
|
|
226
|
+
dataclass_type=params_type,
|
|
227
|
+
)
|
|
228
|
+
if type(default_value) is not params_type:
|
|
229
|
+
raise PromptValidationError(
|
|
230
|
+
"Section defaults must match section params type.",
|
|
231
|
+
section_path=path,
|
|
232
|
+
dataclass_type=params_type,
|
|
233
|
+
)
|
|
234
|
+
self._defaults_by_path[path] = default_value
|
|
235
|
+
self._defaults_by_type.setdefault(params_type, default_value)
|
|
236
|
+
|
|
237
|
+
section_placeholders = section.placeholder_names()
|
|
238
|
+
self.placeholders[path] = set(section_placeholders)
|
|
239
|
+
param_fields = {field.name for field in fields(params_type)}
|
|
240
|
+
unknown_placeholders = section_placeholders - param_fields
|
|
241
|
+
if unknown_placeholders:
|
|
242
|
+
placeholder = sorted(unknown_placeholders)[0]
|
|
243
|
+
raise PromptValidationError(
|
|
244
|
+
"Template references unknown placeholder.",
|
|
245
|
+
section_path=path,
|
|
246
|
+
dataclass_type=params_type,
|
|
247
|
+
placeholder=placeholder,
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
self._register_section_tools(section, path)
|
|
251
|
+
|
|
252
|
+
for child in section.children:
|
|
253
|
+
child_path = path + (child.title,)
|
|
254
|
+
self._register_section(child, path=child_path, depth=depth + 1)
|
|
255
|
+
|
|
256
|
+
@property
|
|
257
|
+
def sections(self) -> tuple[PromptSectionNode[SupportsDataclass], ...]:
|
|
258
|
+
return tuple(self._section_nodes)
|
|
259
|
+
|
|
260
|
+
@property
|
|
261
|
+
def params_types(self) -> set[type[SupportsDataclass]]:
|
|
262
|
+
return set(self._params_registry.keys())
|
|
263
|
+
|
|
264
|
+
def _resolve_output_spec(
|
|
265
|
+
self, allow_extra_keys: bool
|
|
266
|
+
) -> tuple[type[Any] | None, Literal["object", "array"] | None, bool | None]:
|
|
267
|
+
candidate = getattr(type(self), "_output_dataclass_candidate", None)
|
|
268
|
+
container = cast(
|
|
269
|
+
Literal["object", "array"] | None,
|
|
270
|
+
getattr(type(self), "_output_container_spec", None),
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
if candidate is None or container is None:
|
|
274
|
+
return None, None, None
|
|
275
|
+
|
|
276
|
+
if not isinstance(candidate, type): # pragma: no cover - defensive guard
|
|
277
|
+
candidate_type = cast(type[Any], type(candidate))
|
|
278
|
+
raise PromptValidationError(
|
|
279
|
+
"Prompt output type must be a dataclass.",
|
|
280
|
+
dataclass_type=candidate_type,
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
if not is_dataclass(candidate):
|
|
284
|
+
bad_dataclass = cast(type[Any], candidate)
|
|
285
|
+
raise PromptValidationError(
|
|
286
|
+
"Prompt output type must be a dataclass.",
|
|
287
|
+
dataclass_type=bad_dataclass,
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
dataclass_type = cast(type[Any], candidate)
|
|
291
|
+
return dataclass_type, container, allow_extra_keys
|
|
292
|
+
|
|
293
|
+
def _build_response_format_params(self) -> ResponseFormatParams:
|
|
294
|
+
container = self._output_container
|
|
295
|
+
if container is None: # pragma: no cover - defensive guard
|
|
296
|
+
raise RuntimeError(
|
|
297
|
+
"Output container missing during response format construction."
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
article: Literal["a", "an"] = (
|
|
301
|
+
"an" if container.startswith(("a", "e", "i", "o", "u")) else "a"
|
|
302
|
+
)
|
|
303
|
+
extra_clause = ". Do not add extra keys." if not self._allow_extra_keys else "."
|
|
304
|
+
return ResponseFormatParams(
|
|
305
|
+
article=article,
|
|
306
|
+
container=container,
|
|
307
|
+
extra_clause=extra_clause,
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
def _collect_param_lookup(
|
|
311
|
+
self, params: tuple[SupportsDataclass, ...]
|
|
312
|
+
) -> dict[type[SupportsDataclass], SupportsDataclass]:
|
|
313
|
+
lookup: dict[type[SupportsDataclass], SupportsDataclass] = {}
|
|
314
|
+
for value in params:
|
|
315
|
+
if isinstance(value, type):
|
|
316
|
+
provided_type: type[Any] = value
|
|
317
|
+
else:
|
|
318
|
+
provided_type = type(value)
|
|
319
|
+
if isinstance(value, type) or not is_dataclass(value):
|
|
320
|
+
raise PromptValidationError(
|
|
321
|
+
"Prompt expects dataclass instances.",
|
|
322
|
+
dataclass_type=provided_type,
|
|
323
|
+
)
|
|
324
|
+
params_type = cast(type[SupportsDataclass], provided_type)
|
|
325
|
+
if params_type in lookup:
|
|
326
|
+
raise PromptValidationError(
|
|
327
|
+
"Duplicate params type supplied to prompt.",
|
|
328
|
+
dataclass_type=params_type,
|
|
329
|
+
)
|
|
330
|
+
if params_type not in self._params_registry:
|
|
331
|
+
raise PromptValidationError(
|
|
332
|
+
"Unexpected params type supplied to prompt.",
|
|
333
|
+
dataclass_type=params_type,
|
|
334
|
+
)
|
|
335
|
+
lookup[params_type] = value
|
|
336
|
+
return lookup
|
|
337
|
+
|
|
338
|
+
def _resolve_section_params(
|
|
339
|
+
self,
|
|
340
|
+
node: PromptSectionNode[SupportsDataclass],
|
|
341
|
+
param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
|
|
342
|
+
) -> SupportsDataclass:
|
|
343
|
+
params_type = node.section.params
|
|
344
|
+
section_params: SupportsDataclass | None = param_lookup.get(params_type)
|
|
345
|
+
|
|
346
|
+
if section_params is None:
|
|
347
|
+
default_value = self._defaults_by_path.get(node.path)
|
|
348
|
+
if default_value is not None:
|
|
349
|
+
section_params = _clone_dataclass(default_value)
|
|
350
|
+
else:
|
|
351
|
+
type_default = self._defaults_by_type.get(params_type)
|
|
352
|
+
if type_default is not None:
|
|
353
|
+
section_params = _clone_dataclass(type_default)
|
|
354
|
+
else:
|
|
355
|
+
try:
|
|
356
|
+
constructor = cast(Callable[[], SupportsDataclass], params_type)
|
|
357
|
+
section_params = constructor()
|
|
358
|
+
except TypeError as error:
|
|
359
|
+
raise PromptRenderError(
|
|
360
|
+
"Missing parameters for section.",
|
|
361
|
+
section_path=node.path,
|
|
362
|
+
dataclass_type=params_type,
|
|
363
|
+
) from error
|
|
364
|
+
|
|
365
|
+
return section_params
|
|
366
|
+
|
|
367
|
+
def _iter_enabled_sections(
|
|
368
|
+
self,
|
|
369
|
+
param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
|
|
370
|
+
) -> Iterator[tuple[PromptSectionNode[SupportsDataclass], SupportsDataclass]]:
|
|
371
|
+
skip_depth: int | None = None
|
|
372
|
+
|
|
373
|
+
for node in self._section_nodes:
|
|
374
|
+
if skip_depth is not None:
|
|
375
|
+
if node.depth > skip_depth:
|
|
376
|
+
continue
|
|
377
|
+
skip_depth = None
|
|
378
|
+
|
|
379
|
+
section_params = self._resolve_section_params(node, param_lookup)
|
|
380
|
+
|
|
381
|
+
try:
|
|
382
|
+
enabled = node.section.is_enabled(section_params)
|
|
383
|
+
except Exception as error: # pragma: no cover - defensive guard
|
|
384
|
+
raise PromptRenderError(
|
|
385
|
+
"Section enabled predicate failed.",
|
|
386
|
+
section_path=node.path,
|
|
387
|
+
dataclass_type=node.section.params,
|
|
388
|
+
) from error
|
|
389
|
+
|
|
390
|
+
if not enabled:
|
|
391
|
+
skip_depth = node.depth
|
|
392
|
+
continue
|
|
393
|
+
|
|
394
|
+
yield node, section_params
|
|
395
|
+
|
|
396
|
+
def _register_section_tools(
|
|
397
|
+
self,
|
|
398
|
+
section: Section[SupportsDataclass],
|
|
399
|
+
path: SectionPath,
|
|
400
|
+
) -> None:
|
|
401
|
+
section_tools = section.tools()
|
|
402
|
+
if not section_tools:
|
|
403
|
+
return
|
|
404
|
+
|
|
405
|
+
tools_iterable = cast(Sequence[object], section_tools)
|
|
406
|
+
for tool_candidate in tools_iterable:
|
|
407
|
+
if not isinstance(tool_candidate, Tool):
|
|
408
|
+
raise PromptValidationError(
|
|
409
|
+
"Section tools() must return Tool instances.",
|
|
410
|
+
section_path=path,
|
|
411
|
+
dataclass_type=section.params,
|
|
412
|
+
)
|
|
413
|
+
tool: Tool[SupportsDataclass, SupportsDataclass] = cast(
|
|
414
|
+
Tool[SupportsDataclass, SupportsDataclass], tool_candidate
|
|
415
|
+
)
|
|
416
|
+
params_type = cast(
|
|
417
|
+
type[SupportsDataclass] | None, getattr(tool, "params_type", None)
|
|
418
|
+
)
|
|
419
|
+
if not isinstance(params_type, type) or not is_dataclass(params_type):
|
|
420
|
+
raise PromptValidationError(
|
|
421
|
+
"Tool params_type must be a dataclass type.",
|
|
422
|
+
section_path=path,
|
|
423
|
+
dataclass_type=(
|
|
424
|
+
params_type
|
|
425
|
+
if isinstance(params_type, type)
|
|
426
|
+
else type(params_type)
|
|
427
|
+
),
|
|
428
|
+
)
|
|
429
|
+
existing_path = self._tool_name_registry.get(tool.name)
|
|
430
|
+
if existing_path is not None:
|
|
431
|
+
raise PromptValidationError(
|
|
432
|
+
"Duplicate tool name registered for prompt.",
|
|
433
|
+
section_path=path,
|
|
434
|
+
dataclass_type=tool.params_type,
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
self._tool_name_registry[tool.name] = path
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
__all__ = [
    "Prompt",
    "PromptSectionNode",
    "RenderedPrompt",
]
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from .text import TextSection
|
|
20
|
+
|
|
21
|
+
__all__ = [
    "ResponseFormatParams",
    "ResponseFormatSection",
]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(slots=True)
|
|
25
|
+
class ResponseFormatParams:
|
|
26
|
+
"""Parameter payload for the auto-generated response format section."""
|
|
27
|
+
|
|
28
|
+
article: Literal["a", "an"]
|
|
29
|
+
container: Literal["object", "array"]
|
|
30
|
+
extra_clause: str
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
_RESPONSE_FORMAT_BODY = """Return ONLY a single fenced JSON code block. Do not include any text
|
|
34
|
+
before or after the block.
|
|
35
|
+
|
|
36
|
+
The top-level JSON value MUST be ${article} ${container} that matches the fields
|
|
37
|
+
of the expected schema${extra_clause}"""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ResponseFormatSection(TextSection[ResponseFormatParams]):
    """Internal section telling the model to answer with a single JSON block.

    ``params`` is installed as the section's defaults, so the template is
    filled from it unless a caller supplies other ``ResponseFormatParams``.
    """

    def __init__(
        self,
        *,
        params: ResponseFormatParams,
        enabled: Callable[[ResponseFormatParams], bool] | None = None,
    ) -> None:
        """Create the section.

        Args:
            params: Template values, stored as the section defaults.
            enabled: Optional predicate deciding whether the section renders.
        """
        super().__init__(
            defaults=params,
            enabled=enabled,
            title="Response Format",
            body=_RESPONSE_FORMAT_BODY,
        )
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from collections.abc import Callable, Sequence
|
|
17
|
+
from typing import TYPE_CHECKING, ClassVar, cast
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from .tool import Tool
|
|
21
|
+
|
|
22
|
+
from ._types import SupportsDataclass
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Section[ParamsT: SupportsDataclass](ABC):
|
|
26
|
+
"""Abstract building block for prompt content."""
|
|
27
|
+
|
|
28
|
+
_params_type: ClassVar[type[SupportsDataclass] | None] = None
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
*,
|
|
33
|
+
title: str,
|
|
34
|
+
defaults: ParamsT | None = None,
|
|
35
|
+
children: Sequence[object] | None = None,
|
|
36
|
+
enabled: Callable[[ParamsT], bool] | None = None,
|
|
37
|
+
tools: Sequence[object] | None = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
params_type = cast(
|
|
40
|
+
type[ParamsT] | None, getattr(self.__class__, "_params_type", None)
|
|
41
|
+
)
|
|
42
|
+
if params_type is None:
|
|
43
|
+
raise TypeError(
|
|
44
|
+
"Section must be instantiated with a concrete ParamsT type."
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
self.params_type: type[ParamsT] = params_type
|
|
48
|
+
self.params: type[ParamsT] = params_type
|
|
49
|
+
self.title = title
|
|
50
|
+
self.defaults = defaults
|
|
51
|
+
|
|
52
|
+
normalized_children: list[Section[SupportsDataclass]] = []
|
|
53
|
+
for child in children or ():
|
|
54
|
+
if not isinstance(child, Section):
|
|
55
|
+
raise TypeError("Section children must be Section instances.")
|
|
56
|
+
normalized_children.append(cast(Section[SupportsDataclass], child))
|
|
57
|
+
self.children: tuple[Section[SupportsDataclass], ...] = tuple(
|
|
58
|
+
normalized_children
|
|
59
|
+
)
|
|
60
|
+
self._enabled = enabled
|
|
61
|
+
self._tools = self._normalize_tools(tools)
|
|
62
|
+
|
|
63
|
+
def is_enabled(self, params: ParamsT) -> bool:
|
|
64
|
+
"""Return True when the section should render for the given params."""
|
|
65
|
+
|
|
66
|
+
if self._enabled is None:
|
|
67
|
+
return True
|
|
68
|
+
return bool(self._enabled(params))
|
|
69
|
+
|
|
70
|
+
@abstractmethod
|
|
71
|
+
def render(self, params: ParamsT, depth: int) -> str:
|
|
72
|
+
"""Produce markdown output for the section at the supplied depth."""
|
|
73
|
+
|
|
74
|
+
def placeholder_names(self) -> set[str]:
|
|
75
|
+
"""Return placeholder identifiers used by the section template."""
|
|
76
|
+
|
|
77
|
+
return set()
|
|
78
|
+
|
|
79
|
+
def tools(self) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
|
|
80
|
+
"""Return the tools exposed by this section."""
|
|
81
|
+
|
|
82
|
+
return self._tools
|
|
83
|
+
|
|
84
|
+
@classmethod
|
|
85
|
+
def __class_getitem__(cls, item: object) -> type[Section[SupportsDataclass]]:
|
|
86
|
+
params_type = cls._normalize_generic_argument(item)
|
|
87
|
+
|
|
88
|
+
class _SpecializedSection(cls): # type: ignore[misc]
|
|
89
|
+
pass
|
|
90
|
+
|
|
91
|
+
_SpecializedSection.__name__ = cls.__name__
|
|
92
|
+
_SpecializedSection.__qualname__ = cls.__qualname__
|
|
93
|
+
_SpecializedSection.__module__ = cls.__module__
|
|
94
|
+
_SpecializedSection._params_type = cast(type[SupportsDataclass], params_type)
|
|
95
|
+
return _SpecializedSection # type: ignore[return-value]
|
|
96
|
+
|
|
97
|
+
@staticmethod
|
|
98
|
+
def _normalize_generic_argument(item: object) -> object:
|
|
99
|
+
if isinstance(item, tuple):
|
|
100
|
+
raise TypeError("Section[...] expects a single type argument.")
|
|
101
|
+
return item
|
|
102
|
+
|
|
103
|
+
@staticmethod
|
|
104
|
+
def _normalize_tools(
|
|
105
|
+
tools: Sequence[object] | None,
|
|
106
|
+
) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
|
|
107
|
+
if not tools:
|
|
108
|
+
return ()
|
|
109
|
+
|
|
110
|
+
from .tool import Tool
|
|
111
|
+
|
|
112
|
+
normalized: list[Tool[SupportsDataclass, SupportsDataclass]] = []
|
|
113
|
+
for tool in tools:
|
|
114
|
+
if not isinstance(tool, Tool):
|
|
115
|
+
raise TypeError("Section tools must be Tool instances.")
|
|
116
|
+
normalized.append(cast(Tool[SupportsDataclass, SupportsDataclass], tool))
|
|
117
|
+
return tuple(normalized)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
__all__ = ["Section"]
|