weakincentives 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. weakincentives/__init__.py +67 -0
  2. weakincentives/adapters/__init__.py +37 -0
  3. weakincentives/adapters/_names.py +32 -0
  4. weakincentives/adapters/_provider_protocols.py +69 -0
  5. weakincentives/adapters/_tool_messages.py +80 -0
  6. weakincentives/adapters/core.py +102 -0
  7. weakincentives/adapters/litellm.py +254 -0
  8. weakincentives/adapters/openai.py +254 -0
  9. weakincentives/adapters/shared.py +1021 -0
  10. weakincentives/cli/__init__.py +23 -0
  11. weakincentives/cli/wink.py +58 -0
  12. weakincentives/dbc/__init__.py +412 -0
  13. weakincentives/deadlines.py +58 -0
  14. weakincentives/prompt/__init__.py +105 -0
  15. weakincentives/prompt/_generic_params_specializer.py +64 -0
  16. weakincentives/prompt/_normalization.py +48 -0
  17. weakincentives/prompt/_overrides_protocols.py +33 -0
  18. weakincentives/prompt/_types.py +34 -0
  19. weakincentives/prompt/chapter.py +146 -0
  20. weakincentives/prompt/composition.py +281 -0
  21. weakincentives/prompt/errors.py +57 -0
  22. weakincentives/prompt/markdown.py +108 -0
  23. weakincentives/prompt/overrides/__init__.py +59 -0
  24. weakincentives/prompt/overrides/_fs.py +164 -0
  25. weakincentives/prompt/overrides/inspection.py +141 -0
  26. weakincentives/prompt/overrides/local_store.py +275 -0
  27. weakincentives/prompt/overrides/validation.py +534 -0
  28. weakincentives/prompt/overrides/versioning.py +269 -0
  29. weakincentives/prompt/prompt.py +353 -0
  30. weakincentives/prompt/protocols.py +103 -0
  31. weakincentives/prompt/registry.py +375 -0
  32. weakincentives/prompt/rendering.py +288 -0
  33. weakincentives/prompt/response_format.py +60 -0
  34. weakincentives/prompt/section.py +166 -0
  35. weakincentives/prompt/structured_output.py +179 -0
  36. weakincentives/prompt/tool.py +397 -0
  37. weakincentives/prompt/tool_result.py +30 -0
  38. weakincentives/py.typed +0 -0
  39. weakincentives/runtime/__init__.py +82 -0
  40. weakincentives/runtime/events/__init__.py +126 -0
  41. weakincentives/runtime/events/_types.py +110 -0
  42. weakincentives/runtime/logging.py +284 -0
  43. weakincentives/runtime/session/__init__.py +46 -0
  44. weakincentives/runtime/session/_slice_types.py +24 -0
  45. weakincentives/runtime/session/_types.py +55 -0
  46. weakincentives/runtime/session/dataclasses.py +29 -0
  47. weakincentives/runtime/session/protocols.py +34 -0
  48. weakincentives/runtime/session/reducer_context.py +40 -0
  49. weakincentives/runtime/session/reducers.py +82 -0
  50. weakincentives/runtime/session/selectors.py +56 -0
  51. weakincentives/runtime/session/session.py +387 -0
  52. weakincentives/runtime/session/snapshots.py +310 -0
  53. weakincentives/serde/__init__.py +19 -0
  54. weakincentives/serde/_utils.py +240 -0
  55. weakincentives/serde/dataclass_serde.py +55 -0
  56. weakincentives/serde/dump.py +189 -0
  57. weakincentives/serde/parse.py +417 -0
  58. weakincentives/serde/schema.py +260 -0
  59. weakincentives/tools/__init__.py +154 -0
  60. weakincentives/tools/_context.py +38 -0
  61. weakincentives/tools/asteval.py +853 -0
  62. weakincentives/tools/errors.py +26 -0
  63. weakincentives/tools/planning.py +831 -0
  64. weakincentives/tools/podman.py +1655 -0
  65. weakincentives/tools/subagents.py +346 -0
  66. weakincentives/tools/vfs.py +1390 -0
  67. weakincentives/types/__init__.py +35 -0
  68. weakincentives/types/json.py +45 -0
  69. weakincentives-0.9.0.dist-info/METADATA +775 -0
  70. weakincentives-0.9.0.dist-info/RECORD +73 -0
  71. weakincentives-0.9.0.dist-info/WHEEL +4 -0
  72. weakincentives-0.9.0.dist-info/entry_points.txt +2 -0
  73. weakincentives-0.9.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,288 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Rendering helpers for :mod:`weakincentives.prompt`."""
14
+
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Callable, Iterator, Mapping, MutableMapping
18
+ from dataclasses import dataclass, field, is_dataclass, replace
19
+ from types import MappingProxyType
20
+ from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, override
21
+
22
+ from ..deadlines import Deadline
23
+ from ._types import SupportsDataclass, SupportsToolResult
24
+ from .errors import PromptRenderError, PromptValidationError, SectionPath
25
+ from .registry import RegistrySnapshot, SectionNode
26
+ from .response_format import ResponseFormatSection
27
+ from .structured_output import StructuredOutputConfig
28
+ from .tool import Tool
29
+
30
+ if TYPE_CHECKING: # pragma: no cover - typing only
31
+ from .overrides import ToolOverride
32
+
33
+
34
+ _EMPTY_TOOL_PARAM_DESCRIPTIONS: Mapping[str, Mapping[str, str]] = MappingProxyType({})
35
+
36
+
37
+ OutputT_co = TypeVar("OutputT_co", covariant=True)
38
+
39
+
40
+ @dataclass(frozen=True, slots=True)
41
+ class RenderedPrompt[OutputT_co]:
42
+ """Rendered prompt text paired with structured output metadata."""
43
+
44
+ text: str
45
+ structured_output: StructuredOutputConfig[SupportsDataclass] | None = None
46
+ deadline: Deadline | None = None
47
+ _tools: tuple[Tool[SupportsDataclass, SupportsToolResult], ...] = field(
48
+ default_factory=tuple
49
+ )
50
+ _tool_param_descriptions: Mapping[str, Mapping[str, str]] = field(
51
+ default=_EMPTY_TOOL_PARAM_DESCRIPTIONS
52
+ )
53
+
54
+ @override
55
+ def __str__(self) -> str: # pragma: no cover - delegated behaviour
56
+ return self.text
57
+
58
+ @property
59
+ def tools(self) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
60
+ """Tools contributed by enabled sections in traversal order."""
61
+
62
+ return self._tools
63
+
64
+ @property
65
+ def tool_param_descriptions(
66
+ self,
67
+ ) -> Mapping[str, Mapping[str, str]]:
68
+ """Description patches keyed by tool name."""
69
+
70
+ return self._tool_param_descriptions
71
+
72
+ @property
73
+ def output_type(self) -> type[SupportsDataclass] | None:
74
+ """Return the declared dataclass type for structured output."""
75
+
76
+ if self.structured_output is None:
77
+ return None
78
+ return self.structured_output.dataclass_type
79
+
80
+ @property
81
+ def container(self) -> Literal["object", "array"] | None:
82
+ """Return the declared container for structured output."""
83
+
84
+ if self.structured_output is None:
85
+ return None
86
+ return self.structured_output.container
87
+
88
+ @property
89
+ def allow_extra_keys(self) -> bool | None:
90
+ """Return whether extra keys are allowed in structured output."""
91
+
92
+ if self.structured_output is None:
93
+ return None
94
+ return self.structured_output.allow_extra_keys
95
+
96
+
97
def _freeze_tool_param_descriptions(
    descriptions: Mapping[str, dict[str, str]],
) -> Mapping[str, Mapping[str, str]]:
    """Return a read-only snapshot of per-tool parameter descriptions.

    Each inner mapping is copied before being wrapped, so later changes to
    the caller's dictionaries do not show through the returned proxies.

    Args:
        descriptions: Description patches keyed by tool name, then by
            parameter name.

    Returns:
        A ``MappingProxyType`` whose values are themselves
        ``MappingProxyType`` instances; an empty proxy for falsy input.
    """
    if not descriptions:
        return MappingProxyType({})
    return MappingProxyType(
        {
            tool_name: MappingProxyType(dict(patches))
            for tool_name, patches in descriptions.items()
        }
    )
106
+
107
+
108
+ OutputT = TypeVar("OutputT")
109
+
110
+
111
+ class PromptRenderer[OutputT]:
112
+ """Render prompts using a registry snapshot."""
113
+
114
+ def __init__(
115
+ self,
116
+ *,
117
+ registry: RegistrySnapshot,
118
+ structured_output: StructuredOutputConfig[SupportsDataclass] | None,
119
+ response_section: ResponseFormatSection | None,
120
+ ) -> None:
121
+ super().__init__()
122
+ self._registry = registry
123
+ self._structured_output = structured_output
124
+ self._response_section: ResponseFormatSection | None = response_section
125
+
126
+ def build_param_lookup(
127
+ self, params: tuple[SupportsDataclass, ...]
128
+ ) -> dict[type[SupportsDataclass], SupportsDataclass]:
129
+ lookup: dict[type[SupportsDataclass], SupportsDataclass] = {}
130
+ for value in params:
131
+ if isinstance(value, type):
132
+ provided_type: type[Any] = value
133
+ else:
134
+ provided_type = type(value)
135
+ if isinstance(value, type) or not is_dataclass(value):
136
+ raise PromptValidationError(
137
+ "Prompt expects dataclass instances.",
138
+ dataclass_type=provided_type,
139
+ )
140
+ params_type = cast(type[SupportsDataclass], provided_type)
141
+ if params_type in lookup:
142
+ raise PromptValidationError(
143
+ "Duplicate params type supplied to prompt.",
144
+ dataclass_type=params_type,
145
+ )
146
+ if params_type not in self._registry.param_types:
147
+ raise PromptValidationError(
148
+ "Unexpected params type supplied to prompt.",
149
+ dataclass_type=params_type,
150
+ )
151
+ lookup[params_type] = value
152
+ return lookup
153
+
154
+ def render(
155
+ self,
156
+ param_lookup: Mapping[type[SupportsDataclass], SupportsDataclass],
157
+ overrides: Mapping[SectionPath, str] | None = None,
158
+ tool_overrides: Mapping[str, ToolOverride] | None = None,
159
+ *,
160
+ inject_output_instructions: bool | None = None,
161
+ ) -> RenderedPrompt[OutputT]:
162
+ rendered_sections: list[str] = []
163
+ collected_tools: list[Tool[SupportsDataclass, SupportsToolResult]] = []
164
+ override_lookup = dict(overrides or {})
165
+ tool_override_lookup = dict(tool_overrides or {})
166
+ field_description_patches: dict[str, dict[str, str]] = {}
167
+
168
+ for node, section_params in self._iter_enabled_sections(
169
+ dict(param_lookup),
170
+ inject_output_instructions=inject_output_instructions,
171
+ ):
172
+ override_body = (
173
+ override_lookup.get(node.path)
174
+ if getattr(node.section, "accepts_overrides", True)
175
+ else None
176
+ )
177
+ rendered = self._render_section(node, section_params, override_body)
178
+
179
+ section_tools = node.section.tools()
180
+ if section_tools:
181
+ for tool in section_tools:
182
+ override = (
183
+ tool_override_lookup.get(tool.name)
184
+ if tool.accepts_overrides
185
+ else None
186
+ )
187
+ patched_tool = tool
188
+ if override is not None:
189
+ if (
190
+ override.description is not None
191
+ and override.description != tool.description
192
+ ):
193
+ patched_tool = replace(
194
+ tool, description=override.description
195
+ )
196
+ if override.param_descriptions:
197
+ field_description_patches[tool.name] = dict(
198
+ override.param_descriptions
199
+ )
200
+ collected_tools.append(patched_tool)
201
+
202
+ if rendered:
203
+ rendered_sections.append(rendered)
204
+
205
+ text = "\n\n".join(rendered_sections)
206
+
207
+ return RenderedPrompt[OutputT](
208
+ text=text,
209
+ structured_output=self._structured_output,
210
+ _tools=tuple(collected_tools),
211
+ _tool_param_descriptions=_freeze_tool_param_descriptions(
212
+ field_description_patches
213
+ ),
214
+ )
215
+
216
+ def _iter_enabled_sections(
217
+ self,
218
+ param_lookup: MutableMapping[type[SupportsDataclass], SupportsDataclass],
219
+ *,
220
+ inject_output_instructions: bool | None = None,
221
+ ) -> Iterator[tuple[SectionNode[SupportsDataclass], SupportsDataclass | None]]:
222
+ skip_depth: int | None = None
223
+
224
+ for node in self._registry.sections:
225
+ if skip_depth is not None:
226
+ if node.depth > skip_depth:
227
+ continue
228
+ skip_depth = None
229
+
230
+ section_params = self._registry.resolve_section_params(node, param_lookup)
231
+
232
+ if node.section is self._response_section and (
233
+ inject_output_instructions is not None
234
+ ):
235
+ enabled = inject_output_instructions
236
+ else:
237
+ try:
238
+ enabled = node.section.is_enabled(section_params)
239
+ except Exception as error: # pragma: no cover - defensive
240
+ raise PromptRenderError(
241
+ "Section enabled predicate failed.",
242
+ section_path=node.path,
243
+ dataclass_type=node.section.param_type,
244
+ ) from error
245
+
246
+ if not enabled:
247
+ skip_depth = node.depth
248
+ continue
249
+
250
+ yield node, section_params
251
+
252
+ def _render_section(
253
+ self,
254
+ node: SectionNode[SupportsDataclass],
255
+ section_params: SupportsDataclass | None,
256
+ override_body: str | None,
257
+ ) -> str:
258
+ params_type = node.section.param_type
259
+ try:
260
+ render_override = getattr(node.section, "render_with_template", None)
261
+ if override_body is not None and callable(render_override):
262
+ override_renderer = cast(
263
+ Callable[[str, SupportsDataclass | None, int], str],
264
+ render_override,
265
+ )
266
+ rendered = override_renderer(override_body, section_params, node.depth)
267
+ else:
268
+ rendered = node.section.render(section_params, node.depth)
269
+ except PromptRenderError as error:
270
+ if error.section_path and error.dataclass_type:
271
+ raise
272
+ raise PromptRenderError(
273
+ error.message,
274
+ section_path=node.path,
275
+ dataclass_type=params_type,
276
+ placeholder=error.placeholder,
277
+ ) from error
278
+ except Exception as error: # pragma: no cover - defensive
279
+ raise PromptRenderError(
280
+ "Section rendering failed.",
281
+ section_path=node.path,
282
+ dataclass_type=params_type,
283
+ ) from error
284
+
285
+ return rendered
286
+
287
+
288
+ __all__ = ["PromptRenderer", "RenderedPrompt"]
@@ -0,0 +1,60 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass
17
+ from typing import Final, Literal
18
+
19
+ from ._types import SupportsDataclass
20
+ from .markdown import MarkdownSection
21
+
22
+ __all__ = ["ResponseFormatParams", "ResponseFormatSection"]
23
+
24
+
25
+ @dataclass(slots=True)
26
+ class ResponseFormatParams:
27
+ """Parameter payload for the auto-generated response format section."""
28
+
29
+ article: Literal["a", "an"]
30
+ container: Literal["object", "array"]
31
+ extra_clause: str
32
+
33
+
34
+ _RESPONSE_FORMAT_BODY: Final[
35
+ str
36
+ ] = """Return ONLY a single fenced JSON code block. Do not include any text
37
+ before or after the block.
38
+
39
+ The top-level JSON value MUST be ${article} ${container} that matches the fields
40
+ of the expected schema${extra_clause}"""
41
+
42
+
43
+ class ResponseFormatSection(MarkdownSection[ResponseFormatParams]):
44
+ """Internal section that appends JSON-only response instructions."""
45
+
46
+ def __init__(
47
+ self,
48
+ *,
49
+ params: ResponseFormatParams,
50
+ enabled: Callable[[SupportsDataclass], bool] | None = None,
51
+ accepts_overrides: bool = False,
52
+ ) -> None:
53
+ super().__init__(
54
+ title="Response Format",
55
+ key="response-format",
56
+ template=_RESPONSE_FORMAT_BODY,
57
+ default_params=params,
58
+ enabled=enabled,
59
+ accepts_overrides=accepts_overrides,
60
+ )
@@ -0,0 +1,166 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ import inspect
16
+ from abc import ABC, abstractmethod
17
+ from collections.abc import Callable, Sequence
18
+ from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
19
+
20
+ if TYPE_CHECKING:
21
+ from .tool import Tool
22
+
23
+ from ._generic_params_specializer import GenericParamsSpecializer
24
+ from ._normalization import normalize_component_key
25
+ from ._types import SupportsDataclass, SupportsToolResult
26
+
27
SectionParamsT = TypeVar("SectionParamsT", bound=SupportsDataclass, covariant=True)

# An enabled predicate either receives the section params or takes nothing.
EnabledPredicate = Callable[[SupportsDataclass], bool] | Callable[[], bool]


class Section(GenericParamsSpecializer[SectionParamsT], ABC):
    """Abstract building block for prompt content.

    A section carries a title, a normalized key, optional default parameters,
    child sections, an optional enabled predicate and the tools it exposes.
    Subclasses provide :meth:`render`.
    """

    _generic_owner_name: ClassVar[str | None] = "Section"

    def __init__(
        self,
        *,
        title: str,
        key: str,
        default_params: SectionParamsT | None = None,
        children: Sequence[Section[SupportsDataclass]] | None = None,
        enabled: EnabledPredicate | None = None,
        tools: Sequence[object] | None = None,
        accepts_overrides: bool = True,
    ) -> None:
        """Validate and store the section configuration.

        Args:
            title: Section title.
            key: Section key; stored after normalization.
            default_params: Parameters used by default; only allowed when the
                class declares a parameter type.
            children: Nested sections; each must be a ``Section``.
            enabled: Predicate deciding whether the section renders. It may
                take the params or no argument at all.
            tools: Tools exposed by the section; each must be a ``Tool``.
            accepts_overrides: Whether the section may be overridden.

        Raises:
            TypeError: If ``default_params`` is given without a parameter
                type, or a child or tool has the wrong type.
        """
        super().__init__()
        # NOTE(review): ``_params_type`` is presumably set on the class by
        # GenericParamsSpecializer when the section is specialized - confirm.
        declared = getattr(self.__class__, "_params_type", None)
        params_type = cast(
            "type[SupportsDataclass] | None",
            declared if isinstance(declared, type) else None,
        )

        self.params_type: type[SectionParamsT] | None = cast(
            "type[SectionParamsT] | None", params_type
        )
        # Same value exposed under both spellings.
        self.param_type: type[SectionParamsT] | None = self.params_type
        self.title = title
        self.key = self._normalize_key(key)
        self.default_params = default_params
        self.accepts_overrides = accepts_overrides

        if self.default_params is not None and self.params_type is None:
            raise TypeError("Section without parameters cannot define default_params.")

        validated_children: list[Section[SupportsDataclass]] = []
        for candidate in cast("Sequence[object]", children or ()):
            if isinstance(candidate, Section):
                validated_children.append(
                    cast("Section[SupportsDataclass]", candidate)
                )
                continue
            raise TypeError("Section children must be Section instances.")
        self.children = tuple(validated_children)

        self._enabled: Callable[[SupportsDataclass | None], bool] | None = (
            self._normalize_enabled(enabled, params_type)
        )
        self._tools = self._normalize_tools(tools)

    def is_enabled(self, params: SupportsDataclass | None) -> bool:
        """Return True when the section should render for the given params."""

        predicate = self._enabled
        return True if predicate is None else bool(predicate(params))

    @abstractmethod
    def render(self, params: SupportsDataclass | None, depth: int) -> str:
        """Produce markdown output for the section at the supplied depth."""

    def placeholder_names(self) -> set[str]:
        """Return placeholder identifiers used by the section template.

        The base implementation reports none.
        """

        return set()

    def tools(self) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
        """Return the tools exposed by this section."""

        return self._tools

    def original_body_template(self) -> str | None:
        """Return the template text that participates in hashing, when available.

        The base implementation has no template and returns ``None``.
        """

        return None

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Normalize ``key`` through ``normalize_component_key``."""
        return normalize_component_key(key, owner="Section")

    @staticmethod
    def _normalize_tools(
        tools: Sequence[object] | None,
    ) -> tuple[Tool[SupportsDataclass, SupportsToolResult], ...]:
        """Check that every entry is a ``Tool`` and return them as a tuple.

        Raises:
            TypeError: If any entry is not a ``Tool`` instance.
        """
        if not tools:
            return ()

        # Imported here rather than at module level, where ``Tool`` is only
        # imported for type checking.
        from .tool import Tool

        validated: list[Tool[SupportsDataclass, SupportsToolResult]] = []
        for candidate in tools:
            if isinstance(candidate, Tool):
                validated.append(
                    cast("Tool[SupportsDataclass, SupportsToolResult]", candidate)
                )
                continue
            raise TypeError("Section tools must be Tool instances.")
        return tuple(validated)

    @staticmethod
    def _normalize_enabled(
        enabled: EnabledPredicate | None,
        params_type: type[SupportsDataclass] | None,
    ) -> Callable[[SupportsDataclass | None], bool] | None:
        """Adapt ``enabled`` to a uniform one-argument predicate.

        A predicate is called without arguments only when the section has no
        parameter type and the callable requires no positional argument; in
        every other case the params value is passed through.

        Returns:
            The adapted predicate, or ``None`` when ``enabled`` is ``None``.
        """
        if enabled is None:
            return None

        if params_type is None and not _callable_requires_positional_argument(enabled):
            bare_predicate = cast("Callable[[], bool]", enabled)

            def _ignore_params(_: SupportsDataclass | None) -> bool:
                return bool(bare_predicate())

            return _ignore_params

        params_predicate = cast("Callable[[SupportsDataclass], bool]", enabled)

        def _forward_params(value: SupportsDataclass | None) -> bool:
            return bool(params_predicate(cast("SupportsDataclass", value)))

        return _forward_params
146
+
147
+
148
def _callable_requires_positional_argument(callback: EnabledPredicate) -> bool:
    """Report whether ``callback`` has a positional parameter without a default.

    Only positional-only and positional-or-keyword parameters count;
    ``*args``, ``**kwargs`` and keyword-only parameters are ignored.

    Returns:
        ``True`` when such a parameter exists, or when the signature cannot
        be inspected; ``False`` otherwise.
    """
    try:
        parameters = inspect.signature(callback).parameters.values()
    except (TypeError, ValueError):
        # Without a signature, assume an argument is required.
        return True

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return any(
        parameter.kind in positional_kinds
        and parameter.default is inspect.Signature.empty
        for parameter in parameters
    )
164
+
165
+
166
+ __all__ = ["Section"]
@@ -0,0 +1,179 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import re
17
+ from collections.abc import Mapping, Sequence
18
+ from dataclasses import dataclass
19
+ from typing import Final, Literal, Protocol, TypeVar, cast
20
+
21
+ from ..serde.parse import parse as parse_dataclass
22
+ from ..types import JSONValue, ParseableDataclassT
23
+ from ._types import SupportsDataclass
24
+
25
+ __all__ = [
26
+ "ARRAY_WRAPPER_KEY",
27
+ "OutputParseError",
28
+ "StructuredOutputConfig",
29
+ "parse_structured_output",
30
+ ]
31
+
32
# Key looked up when an array payload arrives wrapped in a JSON object.
ARRAY_WRAPPER_KEY: Final[str] = "items"

# Matches a fenced ``json`` code block (case-insensitive) and captures its
# contents, which may span several lines, as group 1.
_JSON_FENCE_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"```json\s*\n(.*?)```",
    flags=re.IGNORECASE | re.DOTALL,
)
37
+
38
+
39
+ DataclassT = TypeVar("DataclassT", bound=SupportsDataclass)
40
+
41
+
42
+ @dataclass(frozen=True, slots=True)
43
+ class StructuredOutputConfig[DataclassT]:
44
+ """Resolved structured output declaration for a prompt."""
45
+
46
+ dataclass_type: type[DataclassT]
47
+ container: Literal["object", "array"]
48
+ allow_extra_keys: bool
49
+
50
+
51
+ class StructuredRenderedPrompt[PayloadT](Protocol):
52
+ @property
53
+ def structured_output(self) -> StructuredOutputConfig[SupportsDataclass] | None:
54
+ """Structured output metadata declared by the prompt."""
55
+
56
+
57
+ class OutputParseError(Exception):
58
+ """Raised when structured output parsing fails."""
59
+
60
+ def __init__(
61
+ self,
62
+ message: str,
63
+ *,
64
+ dataclass_type: type[SupportsDataclass] | None = None,
65
+ ) -> None:
66
+ super().__init__(message)
67
+ self.message = message
68
+ self.dataclass_type = dataclass_type
69
+
70
+
71
+ def parse_structured_output[PayloadT](
72
+ output_text: str, rendered: StructuredRenderedPrompt[PayloadT]
73
+ ) -> PayloadT:
74
+ """Parse a model response into the structured output type declared by the prompt."""
75
+
76
+ config = rendered.structured_output
77
+ if config is None:
78
+ raise OutputParseError("Prompt does not declare structured output.")
79
+
80
+ dataclass_type = config.dataclass_type
81
+ container = config.container
82
+ allow_extra_keys = config.allow_extra_keys
83
+ payload = _extract_json_payload(output_text, dataclass_type)
84
+ try:
85
+ parsed = parse_dataclass_payload(
86
+ dataclass_type,
87
+ container,
88
+ payload,
89
+ allow_extra_keys=allow_extra_keys,
90
+ object_error="Expected top-level JSON object.",
91
+ array_error="Expected top-level JSON array.",
92
+ array_item_error="Array item at index {index} is not an object.",
93
+ )
94
+ except (TypeError, ValueError) as error:
95
+ raise OutputParseError(str(error), dataclass_type=dataclass_type) from error
96
+
97
+ return cast(PayloadT, parsed)
98
+
99
+
100
+ def _extract_json_payload(
101
+ text: str, dataclass_type: type[SupportsDataclass]
102
+ ) -> JSONValue:
103
+ fenced_match = _JSON_FENCE_PATTERN.search(text)
104
+ if fenced_match is not None:
105
+ block = fenced_match.group(1).strip()
106
+ try:
107
+ return json.loads(block)
108
+ except json.JSONDecodeError as error:
109
+ raise OutputParseError(
110
+ "Failed to decode JSON from fenced code block.",
111
+ dataclass_type=dataclass_type,
112
+ ) from error
113
+
114
+ stripped = text.strip()
115
+ if stripped:
116
+ try:
117
+ return json.loads(stripped)
118
+ except json.JSONDecodeError:
119
+ pass
120
+
121
+ decoder = json.JSONDecoder()
122
+ for index, character in enumerate(text):
123
+ if character not in "{[":
124
+ continue
125
+ try:
126
+ payload, _ = decoder.raw_decode(text, index)
127
+ except json.JSONDecodeError:
128
+ continue
129
+ return cast(JSONValue, payload)
130
+
131
+ raise OutputParseError(
132
+ "No JSON object or array found in assistant message.",
133
+ dataclass_type=dataclass_type,
134
+ )
135
+
136
+
137
+ def parse_dataclass_payload(
138
+ dataclass_type: type[ParseableDataclassT],
139
+ container: Literal["object", "array"],
140
+ payload: JSONValue,
141
+ *,
142
+ allow_extra_keys: bool,
143
+ object_error: str,
144
+ array_error: str,
145
+ array_item_error: str,
146
+ ) -> ParseableDataclassT | list[ParseableDataclassT]:
147
+ if container not in {"object", "array"}:
148
+ raise TypeError("Unknown output container declared.")
149
+
150
+ extra_mode: Literal["ignore", "forbid"] = "ignore" if allow_extra_keys else "forbid"
151
+
152
+ if container == "object":
153
+ if not isinstance(payload, Mapping):
154
+ raise TypeError(object_error)
155
+ mapping_payload = cast(Mapping[str, JSONValue], payload)
156
+ return parse_dataclass(dataclass_type, mapping_payload, extra=extra_mode)
157
+
158
+ if isinstance(payload, Mapping):
159
+ mapping_payload = cast(Mapping[str, JSONValue], payload)
160
+ if ARRAY_WRAPPER_KEY not in mapping_payload:
161
+ raise TypeError(array_error)
162
+ payload = mapping_payload[ARRAY_WRAPPER_KEY]
163
+ if not isinstance(payload, Sequence) or isinstance(
164
+ payload, (str, bytes, bytearray)
165
+ ):
166
+ raise TypeError(array_error)
167
+ sequence_payload = cast(Sequence[JSONValue], payload)
168
+ parsed_items: list[ParseableDataclassT] = []
169
+ for index, item in enumerate(sequence_payload):
170
+ if not isinstance(item, Mapping):
171
+ raise TypeError(array_item_error.format(index=index))
172
+ mapping_item = cast(Mapping[str, JSONValue], item)
173
+ parsed_item = parse_dataclass(
174
+ dataclass_type,
175
+ mapping_item,
176
+ extra=extra_mode,
177
+ )
178
+ parsed_items.append(parsed_item)
179
+ return parsed_items