weakincentives 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of weakincentives might be problematic. Click here for more details.

@@ -0,0 +1,440 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from collections.abc import Callable, Iterator, Sequence
16
+ from dataclasses import dataclass, field, fields, is_dataclass, replace
17
+ from typing import Any, ClassVar, Literal, cast, get_args, get_origin
18
+
19
+ from ._types import SupportsDataclass
20
+ from .errors import (
21
+ PromptRenderError,
22
+ PromptValidationError,
23
+ SectionPath,
24
+ )
25
+ from .response_format import ResponseFormatParams, ResponseFormatSection
26
+ from .section import Section
27
+ from .tool import Tool
28
+
29
+
30
+ @dataclass(frozen=True, slots=True)
31
+ class RenderedPrompt[OutputT = Any]:
32
+ """Rendered prompt text paired with structured output metadata."""
33
+
34
+ text: str
35
+ output_type: type[Any] | None
36
+ output_container: Literal["object", "array"] | None
37
+ allow_extra_keys: bool | None
38
+ _tools: tuple[Tool[SupportsDataclass, SupportsDataclass], ...] = field(
39
+ default_factory=tuple
40
+ )
41
+
42
+ def __str__(self) -> str: # pragma: no cover - convenience for logging
43
+ return self.text
44
+
45
+ @property
46
+ def tools(self) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
47
+ """Tools contributed by enabled sections in traversal order."""
48
+
49
+ return self._tools
50
+
51
+
52
def _clone_dataclass(instance: SupportsDataclass) -> SupportsDataclass:
    """Build a new instance of ``instance``'s class with identical field values.

    Uses ``dataclasses.replace`` without overrides, so the copy is shallow:
    mutable field values are shared with the original.
    """
    untyped: Any = instance
    duplicate = replace(untyped)
    return cast(SupportsDataclass, duplicate)
56
+
57
+
58
+ def _format_specialization_argument(argument: object | None) -> str:
59
+ if argument is None: # pragma: no cover - defensive formatting
60
+ return "?"
61
+ if isinstance(argument, type):
62
+ return argument.__name__
63
+ return repr(argument) # pragma: no cover - fallback for debugging
64
+
65
+
66
+ @dataclass(frozen=True, slots=True)
67
+ class PromptSectionNode[ParamsT: SupportsDataclass]:
68
+ """Flattened view of a section within a prompt."""
69
+
70
+ section: Section[ParamsT]
71
+ depth: int
72
+ path: SectionPath
73
+
74
+
75
+ class Prompt[OutputT = Any]:
76
+ """Coordinate prompt sections and their parameter bindings."""
77
+
78
+ _output_container_spec: ClassVar[Literal["object", "array"] | None] = None
79
+ _output_dataclass_candidate: ClassVar[Any] = None
80
+
81
+ def __class_getitem__(cls, item: object) -> type[Prompt[Any]]:
82
+ origin = get_origin(item)
83
+ candidate = item
84
+ container: Literal["object", "array"] | None = "object"
85
+
86
+ if origin is list:
87
+ args = get_args(item)
88
+ candidate = args[0] if len(args) == 1 else None
89
+ container = "array"
90
+ label = f"list[{_format_specialization_argument(candidate)}]"
91
+ else:
92
+ container = "object"
93
+ label = _format_specialization_argument(candidate)
94
+
95
+ name = f"{cls.__name__}[{label}]"
96
+ namespace = {
97
+ "__module__": cls.__module__,
98
+ "_output_container_spec": container if candidate is not None else None,
99
+ "_output_dataclass_candidate": candidate,
100
+ }
101
+ return type(name, (cls,), namespace)
102
+
103
+ def __init__(
104
+ self,
105
+ *,
106
+ name: str | None = None,
107
+ sections: Sequence[Section[Any]] | None = None,
108
+ inject_output_instructions: bool = True,
109
+ allow_extra_keys: bool = False,
110
+ ) -> None:
111
+ self.name = name
112
+ base_sections: list[Section[SupportsDataclass]] = [
113
+ cast(Section[SupportsDataclass], section) for section in sections or ()
114
+ ]
115
+ self._sections: tuple[Section[SupportsDataclass], ...] = tuple(base_sections)
116
+ self._section_nodes: list[PromptSectionNode[SupportsDataclass]] = []
117
+ self._params_registry: dict[
118
+ type[SupportsDataclass], list[PromptSectionNode[SupportsDataclass]]
119
+ ] = {}
120
+ self._defaults_by_path: dict[SectionPath, SupportsDataclass] = {}
121
+ self._defaults_by_type: dict[type[SupportsDataclass], SupportsDataclass] = {}
122
+ self.placeholders: dict[SectionPath, set[str]] = {}
123
+ self._tool_name_registry: dict[str, SectionPath] = {}
124
+
125
+ self._output_type: type[Any] | None
126
+ self._output_container: Literal["object", "array"] | None
127
+ self._allow_extra_keys: bool | None
128
+ (
129
+ self._output_type,
130
+ self._output_container,
131
+ self._allow_extra_keys,
132
+ ) = self._resolve_output_spec(allow_extra_keys)
133
+
134
+ self.inject_output_instructions = inject_output_instructions
135
+
136
+ for section in base_sections:
137
+ self._register_section(section, path=(section.title,), depth=0)
138
+
139
+ self._response_section: ResponseFormatSection | None = None
140
+ if self._output_type is not None and self._output_container is not None:
141
+ response_params = self._build_response_format_params()
142
+ response_section = ResponseFormatSection(
143
+ params=response_params,
144
+ enabled=lambda _params, prompt=self: prompt.inject_output_instructions,
145
+ )
146
+ self._response_section = response_section
147
+ section_for_registry = cast(Section[SupportsDataclass], response_section)
148
+ self._sections += (section_for_registry,)
149
+ self._register_section(
150
+ section_for_registry,
151
+ path=(response_section.title,),
152
+ depth=0,
153
+ )
154
+
155
+ def render(self, *params: SupportsDataclass) -> RenderedPrompt[OutputT]:
156
+ """Render the prompt using provided parameter dataclass instances."""
157
+
158
+ param_lookup = self._collect_param_lookup(params)
159
+ rendered_sections: list[str] = []
160
+ collected_tools: list[Tool[SupportsDataclass, SupportsDataclass]] = []
161
+
162
+ for node, section_params in self._iter_enabled_sections(param_lookup):
163
+ params_type = node.section.params
164
+ try:
165
+ rendered = node.section.render(section_params, node.depth)
166
+ except PromptRenderError as error:
167
+ if error.section_path and error.dataclass_type:
168
+ raise
169
+ raise PromptRenderError(
170
+ error.message,
171
+ section_path=node.path,
172
+ dataclass_type=params_type,
173
+ placeholder=error.placeholder,
174
+ ) from error
175
+ except Exception as error: # pragma: no cover - defensive guard
176
+ raise PromptRenderError(
177
+ "Section rendering failed.",
178
+ section_path=node.path,
179
+ dataclass_type=params_type,
180
+ ) from error
181
+
182
+ section_tools = node.section.tools()
183
+ if section_tools:
184
+ collected_tools.extend(section_tools)
185
+
186
+ if rendered:
187
+ rendered_sections.append(rendered)
188
+
189
+ text = "\n\n".join(rendered_sections)
190
+
191
+ return RenderedPrompt(
192
+ text=text,
193
+ output_type=self._output_type,
194
+ output_container=self._output_container,
195
+ allow_extra_keys=self._allow_extra_keys,
196
+ _tools=tuple(collected_tools),
197
+ )
198
+
199
+ def _register_section(
200
+ self,
201
+ section: Section[SupportsDataclass],
202
+ *,
203
+ path: SectionPath,
204
+ depth: int,
205
+ ) -> None:
206
+ params_type = section.params
207
+ if not is_dataclass(params_type):
208
+ raise PromptValidationError(
209
+ "Section params must be a dataclass.",
210
+ section_path=path,
211
+ dataclass_type=params_type,
212
+ )
213
+
214
+ node: PromptSectionNode[SupportsDataclass] = PromptSectionNode(
215
+ section=section, depth=depth, path=path
216
+ )
217
+ self._section_nodes.append(node)
218
+ self._params_registry.setdefault(params_type, []).append(node)
219
+
220
+ if section.defaults is not None:
221
+ default_value = section.defaults
222
+ if isinstance(default_value, type) or not is_dataclass(default_value):
223
+ raise PromptValidationError(
224
+ "Section defaults must be dataclass instances.",
225
+ section_path=path,
226
+ dataclass_type=params_type,
227
+ )
228
+ if type(default_value) is not params_type:
229
+ raise PromptValidationError(
230
+ "Section defaults must match section params type.",
231
+ section_path=path,
232
+ dataclass_type=params_type,
233
+ )
234
+ self._defaults_by_path[path] = default_value
235
+ self._defaults_by_type.setdefault(params_type, default_value)
236
+
237
+ section_placeholders = section.placeholder_names()
238
+ self.placeholders[path] = set(section_placeholders)
239
+ param_fields = {field.name for field in fields(params_type)}
240
+ unknown_placeholders = section_placeholders - param_fields
241
+ if unknown_placeholders:
242
+ placeholder = sorted(unknown_placeholders)[0]
243
+ raise PromptValidationError(
244
+ "Template references unknown placeholder.",
245
+ section_path=path,
246
+ dataclass_type=params_type,
247
+ placeholder=placeholder,
248
+ )
249
+
250
+ self._register_section_tools(section, path)
251
+
252
+ for child in section.children:
253
+ child_path = path + (child.title,)
254
+ self._register_section(child, path=child_path, depth=depth + 1)
255
+
256
+ @property
257
+ def sections(self) -> tuple[PromptSectionNode[SupportsDataclass], ...]:
258
+ return tuple(self._section_nodes)
259
+
260
+ @property
261
+ def params_types(self) -> set[type[SupportsDataclass]]:
262
+ return set(self._params_registry.keys())
263
+
264
+ def _resolve_output_spec(
265
+ self, allow_extra_keys: bool
266
+ ) -> tuple[type[Any] | None, Literal["object", "array"] | None, bool | None]:
267
+ candidate = getattr(type(self), "_output_dataclass_candidate", None)
268
+ container = cast(
269
+ Literal["object", "array"] | None,
270
+ getattr(type(self), "_output_container_spec", None),
271
+ )
272
+
273
+ if candidate is None or container is None:
274
+ return None, None, None
275
+
276
+ if not isinstance(candidate, type): # pragma: no cover - defensive guard
277
+ candidate_type = cast(type[Any], type(candidate))
278
+ raise PromptValidationError(
279
+ "Prompt output type must be a dataclass.",
280
+ dataclass_type=candidate_type,
281
+ )
282
+
283
+ if not is_dataclass(candidate):
284
+ bad_dataclass = cast(type[Any], candidate)
285
+ raise PromptValidationError(
286
+ "Prompt output type must be a dataclass.",
287
+ dataclass_type=bad_dataclass,
288
+ )
289
+
290
+ dataclass_type = cast(type[Any], candidate)
291
+ return dataclass_type, container, allow_extra_keys
292
+
293
+ def _build_response_format_params(self) -> ResponseFormatParams:
294
+ container = self._output_container
295
+ if container is None: # pragma: no cover - defensive guard
296
+ raise RuntimeError(
297
+ "Output container missing during response format construction."
298
+ )
299
+
300
+ article: Literal["a", "an"] = (
301
+ "an" if container.startswith(("a", "e", "i", "o", "u")) else "a"
302
+ )
303
+ extra_clause = ". Do not add extra keys." if not self._allow_extra_keys else "."
304
+ return ResponseFormatParams(
305
+ article=article,
306
+ container=container,
307
+ extra_clause=extra_clause,
308
+ )
309
+
310
+ def _collect_param_lookup(
311
+ self, params: tuple[SupportsDataclass, ...]
312
+ ) -> dict[type[SupportsDataclass], SupportsDataclass]:
313
+ lookup: dict[type[SupportsDataclass], SupportsDataclass] = {}
314
+ for value in params:
315
+ if isinstance(value, type):
316
+ provided_type: type[Any] = value
317
+ else:
318
+ provided_type = type(value)
319
+ if isinstance(value, type) or not is_dataclass(value):
320
+ raise PromptValidationError(
321
+ "Prompt expects dataclass instances.",
322
+ dataclass_type=provided_type,
323
+ )
324
+ params_type = cast(type[SupportsDataclass], provided_type)
325
+ if params_type in lookup:
326
+ raise PromptValidationError(
327
+ "Duplicate params type supplied to prompt.",
328
+ dataclass_type=params_type,
329
+ )
330
+ if params_type not in self._params_registry:
331
+ raise PromptValidationError(
332
+ "Unexpected params type supplied to prompt.",
333
+ dataclass_type=params_type,
334
+ )
335
+ lookup[params_type] = value
336
+ return lookup
337
+
338
+ def _resolve_section_params(
339
+ self,
340
+ node: PromptSectionNode[SupportsDataclass],
341
+ param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
342
+ ) -> SupportsDataclass:
343
+ params_type = node.section.params
344
+ section_params: SupportsDataclass | None = param_lookup.get(params_type)
345
+
346
+ if section_params is None:
347
+ default_value = self._defaults_by_path.get(node.path)
348
+ if default_value is not None:
349
+ section_params = _clone_dataclass(default_value)
350
+ else:
351
+ type_default = self._defaults_by_type.get(params_type)
352
+ if type_default is not None:
353
+ section_params = _clone_dataclass(type_default)
354
+ else:
355
+ try:
356
+ constructor = cast(Callable[[], SupportsDataclass], params_type)
357
+ section_params = constructor()
358
+ except TypeError as error:
359
+ raise PromptRenderError(
360
+ "Missing parameters for section.",
361
+ section_path=node.path,
362
+ dataclass_type=params_type,
363
+ ) from error
364
+
365
+ return section_params
366
+
367
+ def _iter_enabled_sections(
368
+ self,
369
+ param_lookup: dict[type[SupportsDataclass], SupportsDataclass],
370
+ ) -> Iterator[tuple[PromptSectionNode[SupportsDataclass], SupportsDataclass]]:
371
+ skip_depth: int | None = None
372
+
373
+ for node in self._section_nodes:
374
+ if skip_depth is not None:
375
+ if node.depth > skip_depth:
376
+ continue
377
+ skip_depth = None
378
+
379
+ section_params = self._resolve_section_params(node, param_lookup)
380
+
381
+ try:
382
+ enabled = node.section.is_enabled(section_params)
383
+ except Exception as error: # pragma: no cover - defensive guard
384
+ raise PromptRenderError(
385
+ "Section enabled predicate failed.",
386
+ section_path=node.path,
387
+ dataclass_type=node.section.params,
388
+ ) from error
389
+
390
+ if not enabled:
391
+ skip_depth = node.depth
392
+ continue
393
+
394
+ yield node, section_params
395
+
396
+ def _register_section_tools(
397
+ self,
398
+ section: Section[SupportsDataclass],
399
+ path: SectionPath,
400
+ ) -> None:
401
+ section_tools = section.tools()
402
+ if not section_tools:
403
+ return
404
+
405
+ tools_iterable = cast(Sequence[object], section_tools)
406
+ for tool_candidate in tools_iterable:
407
+ if not isinstance(tool_candidate, Tool):
408
+ raise PromptValidationError(
409
+ "Section tools() must return Tool instances.",
410
+ section_path=path,
411
+ dataclass_type=section.params,
412
+ )
413
+ tool: Tool[SupportsDataclass, SupportsDataclass] = cast(
414
+ Tool[SupportsDataclass, SupportsDataclass], tool_candidate
415
+ )
416
+ params_type = cast(
417
+ type[SupportsDataclass] | None, getattr(tool, "params_type", None)
418
+ )
419
+ if not isinstance(params_type, type) or not is_dataclass(params_type):
420
+ raise PromptValidationError(
421
+ "Tool params_type must be a dataclass type.",
422
+ section_path=path,
423
+ dataclass_type=(
424
+ params_type
425
+ if isinstance(params_type, type)
426
+ else type(params_type)
427
+ ),
428
+ )
429
+ existing_path = self._tool_name_registry.get(tool.name)
430
+ if existing_path is not None:
431
+ raise PromptValidationError(
432
+ "Duplicate tool name registered for prompt.",
433
+ section_path=path,
434
+ dataclass_type=tool.params_type,
435
+ )
436
+
437
+ self._tool_name_registry[tool.name] = path
438
+
439
+
440
# Public names exported by ``from ... import *`` for this module.
__all__ = ["Prompt", "PromptSectionNode", "RenderedPrompt"]
@@ -0,0 +1,54 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass
17
+ from typing import Literal
18
+
19
+ from .text import TextSection
20
+
21
# Public names exported by ``from ... import *`` for this module.
__all__ = ["ResponseFormatParams", "ResponseFormatSection"]
22
+
23
+
24
+ @dataclass(slots=True)
25
+ class ResponseFormatParams:
26
+ """Parameter payload for the auto-generated response format section."""
27
+
28
+ article: Literal["a", "an"]
29
+ container: Literal["object", "array"]
30
+ extra_clause: str
31
+
32
+
33
# Template body for the response format section. The ``${article}``,
# ``${container}`` and ``${extra_clause}`` placeholders match the fields of
# ResponseFormatParams. NOTE(review): substitution is performed by TextSection
# (not visible here); ``$``-style ``string.Template`` syntax is assumed — confirm.
_RESPONSE_FORMAT_BODY = """Return ONLY a single fenced JSON code block. Do not include any text
before or after the block.

The top-level JSON value MUST be ${article} ${container} that matches the fields
of the expected schema${extra_clause}"""
38
+
39
+
40
class ResponseFormatSection(TextSection[ResponseFormatParams]):
    """Internal "Response Format" section requesting a JSON-only reply.

    The supplied params become the section defaults, so the section renders
    without caller-provided params. ``enabled`` is an optional predicate
    that gates rendering.
    """

    def __init__(
        self,
        *,
        params: ResponseFormatParams,
        enabled: Callable[[ResponseFormatParams], bool] | None = None,
    ) -> None:
        super().__init__(
            enabled=enabled,
            defaults=params,
            body=_RESPONSE_FORMAT_BODY,
            title="Response Format",
        )
@@ -0,0 +1,120 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from abc import ABC, abstractmethod
16
+ from collections.abc import Callable, Sequence
17
+ from typing import TYPE_CHECKING, ClassVar, cast
18
+
19
+ if TYPE_CHECKING:
20
+ from .tool import Tool
21
+
22
+ from ._types import SupportsDataclass
23
+
24
+
25
+ class Section[ParamsT: SupportsDataclass](ABC):
26
+ """Abstract building block for prompt content."""
27
+
28
+ _params_type: ClassVar[type[SupportsDataclass] | None] = None
29
+
30
+ def __init__(
31
+ self,
32
+ *,
33
+ title: str,
34
+ defaults: ParamsT | None = None,
35
+ children: Sequence[object] | None = None,
36
+ enabled: Callable[[ParamsT], bool] | None = None,
37
+ tools: Sequence[object] | None = None,
38
+ ) -> None:
39
+ params_type = cast(
40
+ type[ParamsT] | None, getattr(self.__class__, "_params_type", None)
41
+ )
42
+ if params_type is None:
43
+ raise TypeError(
44
+ "Section must be instantiated with a concrete ParamsT type."
45
+ )
46
+
47
+ self.params_type: type[ParamsT] = params_type
48
+ self.params: type[ParamsT] = params_type
49
+ self.title = title
50
+ self.defaults = defaults
51
+
52
+ normalized_children: list[Section[SupportsDataclass]] = []
53
+ for child in children or ():
54
+ if not isinstance(child, Section):
55
+ raise TypeError("Section children must be Section instances.")
56
+ normalized_children.append(cast(Section[SupportsDataclass], child))
57
+ self.children: tuple[Section[SupportsDataclass], ...] = tuple(
58
+ normalized_children
59
+ )
60
+ self._enabled = enabled
61
+ self._tools = self._normalize_tools(tools)
62
+
63
+ def is_enabled(self, params: ParamsT) -> bool:
64
+ """Return True when the section should render for the given params."""
65
+
66
+ if self._enabled is None:
67
+ return True
68
+ return bool(self._enabled(params))
69
+
70
+ @abstractmethod
71
+ def render(self, params: ParamsT, depth: int) -> str:
72
+ """Produce markdown output for the section at the supplied depth."""
73
+
74
+ def placeholder_names(self) -> set[str]:
75
+ """Return placeholder identifiers used by the section template."""
76
+
77
+ return set()
78
+
79
+ def tools(self) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
80
+ """Return the tools exposed by this section."""
81
+
82
+ return self._tools
83
+
84
+ @classmethod
85
+ def __class_getitem__(cls, item: object) -> type[Section[SupportsDataclass]]:
86
+ params_type = cls._normalize_generic_argument(item)
87
+
88
+ class _SpecializedSection(cls): # type: ignore[misc]
89
+ pass
90
+
91
+ _SpecializedSection.__name__ = cls.__name__
92
+ _SpecializedSection.__qualname__ = cls.__qualname__
93
+ _SpecializedSection.__module__ = cls.__module__
94
+ _SpecializedSection._params_type = cast(type[SupportsDataclass], params_type)
95
+ return _SpecializedSection # type: ignore[return-value]
96
+
97
+ @staticmethod
98
+ def _normalize_generic_argument(item: object) -> object:
99
+ if isinstance(item, tuple):
100
+ raise TypeError("Section[...] expects a single type argument.")
101
+ return item
102
+
103
+ @staticmethod
104
+ def _normalize_tools(
105
+ tools: Sequence[object] | None,
106
+ ) -> tuple[Tool[SupportsDataclass, SupportsDataclass], ...]:
107
+ if not tools:
108
+ return ()
109
+
110
+ from .tool import Tool
111
+
112
+ normalized: list[Tool[SupportsDataclass, SupportsDataclass]] = []
113
+ for tool in tools:
114
+ if not isinstance(tool, Tool):
115
+ raise TypeError("Section tools must be Tool instances.")
116
+ normalized.append(cast(Tool[SupportsDataclass, SupportsDataclass], tool))
117
+ return tuple(normalized)
118
+
119
+
120
# Public names exported by ``from ... import *`` for this module.
__all__ = ["Section"]