bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/constructor.py
@@ -0,0 +1,784 @@
"""Item constructor for building experimental items from templates.

This module provides the ItemConstructor class which transforms filled templates
into experimental items by applying model-based constraints and collecting
model outputs for analysis.
"""

from __future__ import annotations

from collections.abc import Iterator
from datetime import UTC, datetime
from uuid import UUID

import numpy as np

from bead.dsl.ast import (
    ASTNode,
    AttributeAccess,
    BinaryOp,
    FunctionCall,
    Literal,
    UnaryOp,
    Variable,
)
from bead.dsl.context import EvaluationContext
from bead.dsl.evaluator import Evaluator
from bead.dsl.parser import parse
from bead.dsl.stdlib import register_stdlib
from bead.items.adapters.registry import ModelAdapterRegistry
from bead.items.cache import ModelOutputCache
from bead.items.item import Item, MetadataValue, ModelOutput
from bead.items.item_template import ItemTemplate
from bead.resources.constraints import Constraint
from bead.templates.filler import FilledTemplate
from bead.templates.resolver import ConstraintResolver


class ItemConstructor:
    """Construct experimental items from filled templates.

    Transforms filled templates into items by:
    1. Resolving element references to text
    2. Computing required model outputs (from constraints)
    3. Evaluating constraints with model outputs
    4. Creating Item instances with metadata

    Parameters
    ----------
    model_registry : ModelAdapterRegistry
        Registry of model adapters for constraint evaluation.
    cache : ModelOutputCache
        Cache for model outputs to avoid redundant computation.
    constraint_resolver : ConstraintResolver | None, optional
        Resolver for evaluating non-model constraints. If None, only
        model-based constraints can be evaluated.

    Attributes
    ----------
    model_registry : ModelAdapterRegistry
        Registry of model adapters for constraint evaluation.
    cache : ModelOutputCache
        Cache for model outputs to avoid redundant computation.
    constraint_resolver : ConstraintResolver | None
        Resolver for evaluating constraints (not used for model constraints).

    Examples
    --------
    >>> from bead.items.adapters.registry import default_registry
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(backend="memory")
    >>> constructor = ItemConstructor(default_registry, cache)
    >>> constraints = {constraint_id: constraint_obj}
    >>> items = list(constructor.construct_items(
    ...     template, filled_templates, constraints
    ... ))
    """

    def __init__(
        self,
        model_registry: ModelAdapterRegistry,
        cache: ModelOutputCache,
        constraint_resolver: ConstraintResolver | None = None,
    ) -> None:
        self.model_registry = model_registry
        self.cache = cache
        self.constraint_resolver = constraint_resolver
        self._dsl_evaluator = Evaluator(use_cache=True)

    def construct_items(
        self,
        item_template: ItemTemplate,
        filled_templates: dict[UUID, FilledTemplate],
        constraints: dict[UUID, Constraint],
    ) -> Iterator[Item]:
        """Construct items from template and filled templates.

        For each combination of filled templates:
        1. Render elements (resolve filled_template_ref → text)
        2. Compute required model outputs (from constraints)
        3. Check constraints using model outputs
        4. Yield item if all constraints satisfied

        Parameters
        ----------
        item_template : ItemTemplate
            Template defining item structure and constraints.
        filled_templates : dict[UUID, FilledTemplate]
            Map of filled template UUIDs to FilledTemplate instances.
        constraints : dict[UUID, Constraint]
            Map of constraint UUIDs to Constraint objects.

        Yields
        ------
        Item
            Constructed items that satisfy all constraints.

        Raises
        ------
        ValueError
            If template references missing filled templates or constraints.
        RuntimeError
            If constraint evaluation or model computation fails.

        Examples
        --------
        >>> template = ItemTemplate(...)
        >>> filled = {uuid1: filled1, uuid2: filled2}
        >>> constraints = {c_id: constraint_obj}
        >>> items = list(constructor.construct_items(
        ...     template, filled, constraints
        ... ))
        >>> len(items)
        2
        """
        # Render elements to text
        rendered_elements = self._render_elements(item_template, filled_templates)

        # Compute model outputs required by constraints
        model_outputs = self._compute_model_outputs(
            item_template, rendered_elements, constraints
        )

        # Check constraints
        constraint_satisfaction = self._check_constraints(
            item_template, rendered_elements, model_outputs, constraints
        )

        # Only yield item if all constraints satisfied
        if all(constraint_satisfaction.values()):
            # Create item
            item = Item(
                item_template_id=item_template.id,
                filled_template_refs=list(filled_templates.keys()),
                rendered_elements=rendered_elements,
                model_outputs=model_outputs,
                constraint_satisfaction=constraint_satisfaction,
            )
            yield item

    def _render_elements(
        self,
        item_template: ItemTemplate,
        filled_templates: dict[UUID, FilledTemplate],
    ) -> dict[str, str]:
        """Render ItemElements to text.

        Resolve element references: text elements use content directly,
        filled_template_ref elements use the rendered text from FilledTemplate.

        Parameters
        ----------
        item_template : ItemTemplate
            Template with elements to render.
        filled_templates : dict[UUID, FilledTemplate]
            Map of filled template UUIDs to instances.

        Returns
        -------
        dict[str, str]
            Map of element names to rendered text.

        Raises
        ------
        ValueError
            If element references missing filled template.
        """
        rendered: dict[str, str] = {}

        for element in item_template.elements:
            if element.is_text:
                # Static text element
                rendered[element.element_name] = element.content or ""
            elif element.is_template_ref:
                # Reference to filled template
                ref_id = element.filled_template_ref_id
                if ref_id is None:
                    raise ValueError(
                        f"Element {element.element_name} has no filled_template_ref_id"
                    )
                if ref_id not in filled_templates:
                    raise ValueError(
                        f"Element {element.element_name} references missing "
                        f"filled template {ref_id}"
                    )
                filled_template = filled_templates[ref_id]
                rendered[element.element_name] = filled_template.rendered_text

        return rendered

    def _compute_model_outputs(
        self,
        item_template: ItemTemplate,
        rendered_elements: dict[str, str],
        constraints: dict[UUID, Constraint],
    ) -> list[ModelOutput]:
        """Execute model calls required by constraints.

        Parse DSL constraints to find model function calls, then execute
        them via adapters with caching.

        Parameters
        ----------
        item_template : ItemTemplate
            Template with constraints.
        rendered_elements : dict[str, str]
            Rendered element text.
        constraints : dict[UUID, Constraint]
            Map of constraint UUIDs to Constraint objects.

        Returns
        -------
        list[ModelOutput]
            All model outputs computed for this item.

        Raises
        ------
        RuntimeError
            If model computation fails.
        ValueError
            If constraint UUID not found in constraints dict.
        """
        model_outputs: list[ModelOutput] = []

        # Extract model calls from all DSL constraints
        for constraint_id in item_template.constraints:
            if constraint_id not in constraints:
                raise ValueError(f"Constraint {constraint_id} not found")

            constraint = constraints[constraint_id]

            # Parse constraint expression to AST
            try:
                ast_node = parse(constraint.expression)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to parse constraint '{constraint.expression}': {e}"
                ) from e

            # Extract all model function calls from AST
            model_calls = self._extract_model_calls(ast_node, rendered_elements)

            # Execute each model call
            for call in model_calls:
                try:
                    output = self._execute_model_call(call)
                    if output:
                        model_outputs.append(output)
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to execute model call {call}: {e}"
                    ) from e

        return model_outputs

    def _extract_model_calls(
        self, ast_node: ASTNode, rendered_elements: dict[str, str]
    ) -> list[dict[str, str | int | float | bool | None]]:
        """Extract model function calls from AST.

        Recursively traverse AST to find calls to model functions
        (lm_prob, nli, similarity, etc.) and extract their arguments.

        Parameters
        ----------
        ast_node : ASTNode
            AST node to traverse.
        rendered_elements : dict[str, str]
            Rendered elements for variable resolution.

        Returns
        -------
        list[dict[str, str | int | float | bool | None]]
            List of model call specifications with function name and arguments.
        """
        calls: list[dict[str, str | int | float | bool | None]] = []

        if isinstance(ast_node, FunctionCall):
            # Check if this is a model function call
            # Function can be Variable (for functions) or AttributeAccess (for methods)
            if isinstance(ast_node.function, Variable):
                func_name: str = ast_node.function.name
            elif isinstance(ast_node.function, AttributeAccess):
                func_name = ast_node.function.attribute
            else:
                # Skip other function call types
                return calls

            model_functions = {
                "lm_prob",
                "lm_perplexity",
                "nli",
                "similarity",
                "embedding",
            }
            if func_name in model_functions:
                # Extract arguments
                call_spec = self._extract_call_args(
                    func_name, ast_node.arguments, rendered_elements
                )
                if call_spec:
                    calls.append(call_spec)

            # Also check arguments for nested calls
            for arg in ast_node.arguments:
                calls.extend(self._extract_model_calls(arg, rendered_elements))

        # Recursively check other node types
        elif isinstance(ast_node, BinaryOp):
            calls.extend(self._extract_model_calls(ast_node.left, rendered_elements))
            calls.extend(self._extract_model_calls(ast_node.right, rendered_elements))
        elif isinstance(ast_node, UnaryOp):
            calls.extend(self._extract_model_calls(ast_node.operand, rendered_elements))
        elif isinstance(ast_node, AttributeAccess):
            calls.extend(self._extract_model_calls(ast_node.object, rendered_elements))

        return calls

    def _extract_call_args(
        self,
        func_name: str,
        args: list[ASTNode],
        rendered_elements: dict[str, str],
    ) -> dict[str, str | int | float | bool | None] | None:
        """Extract arguments from a model function call.

        Parameters
        ----------
        func_name : str
            Name of the function.
        args : list[ASTNode]
            AST nodes representing function arguments.
        rendered_elements : dict[str, str]
            Rendered elements for variable resolution.

        Returns
        -------
        dict[str, Any] | None
            Call specification with function, args, and model name.
        """
        # Resolve literal values and variables
        resolved_args: list[str | int | float | bool | None] = []
        for arg in args:
            if isinstance(arg, Literal):
                resolved_args.append(arg.value)
            elif isinstance(arg, Variable):
                # Try to resolve from rendered elements
                if arg.name in rendered_elements:
                    resolved_args.append(rendered_elements[arg.name])
                else:
                    # Can't resolve, skip this call
                    return None
            else:
                # Complex expression, can't extract statically
                return None

        # Build call specification based on function type
        if func_name in {"lm_prob", "lm_perplexity"}:
            # lm_prob(text, model='gpt2')
            if len(resolved_args) == 0:
                return None
            text = str(resolved_args[0])
            model = str(resolved_args[1]) if len(resolved_args) > 1 else "gpt2"
            operation = "log_probability" if func_name == "lm_prob" else "perplexity"
            return {
                "function": func_name,
                "text": text,
                "model": model,
                "operation": operation,
            }

        elif func_name == "nli":
            # nli(premise, hypothesis, model='roberta-large-mnli')
            if len(resolved_args) < 2:
                return None
            premise = str(resolved_args[0])
            hypothesis = str(resolved_args[1])
            default_nli_model = "roberta-large-mnli"
            model = (
                str(resolved_args[2]) if len(resolved_args) > 2 else default_nli_model
            )
            return {
                "function": func_name,
                "premise": premise,
                "hypothesis": hypothesis,
                "model": model,
                "operation": "nli",
            }

        elif func_name == "similarity":
            # similarity(text1, text2, model='all-MiniLM-L6-v2')
            if len(resolved_args) < 2:
                return None
            text1 = str(resolved_args[0])
            text2 = str(resolved_args[1])
            model = (
                str(resolved_args[2]) if len(resolved_args) > 2 else "all-MiniLM-L6-v2"
            )
            return {
                "function": func_name,
                "text1": text1,
                "text2": text2,
                "model": model,
                "operation": "similarity",
            }

        elif func_name == "embedding":
            # embedding(text, model='all-MiniLM-L6-v2')
            if len(resolved_args) == 0:
                return None
            text = str(resolved_args[0])
            model = (
                str(resolved_args[1]) if len(resolved_args) > 1 else "all-MiniLM-L6-v2"
            )
            return {
                "function": func_name,
                "text": text,
                "model": model,
                "operation": "embedding",
            }

        return None

    def _execute_model_call(
        self, call_spec: dict[str, str | int | float | bool | None]
    ) -> ModelOutput | None:
        """Execute a single model call and return ModelOutput.

        Parameters
        ----------
        call_spec : dict[str, str | int | float | bool | None]
            Call specification with function, args, and model.

        Returns
        -------
        ModelOutput | None
            Model output if successful, None if already cached or failed.

        Raises
        ------
        RuntimeError
            If model execution fails.
        """
        operation = str(call_spec["operation"])
        model_name = str(call_spec["model"])

        # Determine adapter type based on operation
        if operation in {"log_probability", "perplexity"}:
            adapter_type = "huggingface_lm"
        elif operation == "nli":
            adapter_type = "huggingface_nli"
        elif operation in {"similarity", "embedding"}:
            adapter_type = "sentence_transformer"
        else:
            raise ValueError(f"Unknown operation: {operation}")

        # Check cache first
        cache_key_args: dict[str, str | int | float | bool | None] = {}
        if operation in {"log_probability", "perplexity"}:
            cache_key_args = {"text": call_spec["text"]}
        elif operation == "nli":
            cache_key_args = {
                "premise": call_spec["premise"],
                "hypothesis": call_spec["hypothesis"],
            }
        elif operation == "similarity":
            cache_key_args = {
                "text1": call_spec["text1"],
                "text2": call_spec["text2"],
            }
        elif operation == "embedding":
            cache_key_args = {"text": call_spec["text"]}

        cached_result = self.cache.get(model_name, operation, **cache_key_args)
        if cached_result is not None:
            # Already cached, create ModelOutput from cache
            cache_key = self.cache.generate_cache_key(
                model_name, operation, **cache_key_args
            )
            # Convert inputs to MetadataValue compatible dict
            metadata_inputs: dict[str, MetadataValue] = {
                k: str(v) for k, v in cache_key_args.items()
            }
            return ModelOutput(
                model_name=model_name,
                model_version="unknown",  # Could fetch from cache
                operation=operation,
                inputs=metadata_inputs,
                output=cached_result,
                cache_key=cache_key,
                computation_metadata={
                    "timestamp": datetime.now(UTC).isoformat(),
                    "from_cache": True,
                },
            )

        # Get adapter and execute
        adapter = self.model_registry.get_adapter(
            adapter_type=adapter_type,
            model_name=model_name,
            cache=self.cache,
        )

        # Execute the operation
        if operation == "log_probability":
            result = adapter.compute_log_probability(str(call_spec["text"]))
        elif operation == "perplexity":
            result = adapter.compute_perplexity(str(call_spec["text"]))
        elif operation == "nli":
            result = adapter.compute_nli(
                str(call_spec["premise"]), str(call_spec["hypothesis"])
            )
        elif operation == "similarity":
            result = adapter.compute_similarity(
                str(call_spec["text1"]), str(call_spec["text2"])
            )
        elif operation == "embedding":
            result = adapter.get_embedding(str(call_spec["text"]))
        else:
            raise ValueError(f"Unknown operation: {operation}")

        # Generate cache key
        cache_key = self.cache.generate_cache_key(
            model_name, operation, **cache_key_args
        )

        # Convert inputs to MetadataValue compatible dict
        metadata_inputs: dict[str, MetadataValue] = {
            k: str(v) for k, v in cache_key_args.items()
        }

        # Create ModelOutput
        model_version = (
            adapter.model_version if hasattr(adapter, "model_version") else "unknown"
        )
        return ModelOutput(
            model_name=model_name,
            model_version=model_version,
            operation=operation,
            inputs=metadata_inputs,
            output=result,  # type: ignore[arg-type]  # Output can be various types
            cache_key=cache_key,
            computation_metadata={
                "timestamp": datetime.now(UTC).isoformat(),
                "from_cache": False,
            },
        )

    def _check_constraints(
        self,
        item_template: ItemTemplate,
        rendered_elements: dict[str, str],
        model_outputs: list[ModelOutput],
        constraints: dict[UUID, Constraint],
    ) -> dict[UUID, bool]:
        """Evaluate constraints using model outputs.

        Check each constraint against rendered elements and model outputs.

        Parameters
        ----------
        item_template : ItemTemplate
            Template with constraints.
        rendered_elements : dict[str, str]
            Rendered element text.
        model_outputs : list[ModelOutput]
            Model outputs to use in constraint evaluation.
        constraints : dict[UUID, Constraint]
            Map of constraint UUIDs to Constraint objects.

        Returns
        -------
        dict[UUID, bool]
            Map of constraint UUIDs to satisfaction status.

        Raises
        ------
        RuntimeError
            If constraint evaluation fails.
        ValueError
            If constraint UUID not found.
        """
        constraint_satisfaction: dict[UUID, bool] = {}

        # Evaluate each constraint
        for constraint_id in item_template.constraints:
            if constraint_id not in constraints:
                raise ValueError(f"Constraint {constraint_id} not found")

            constraint = constraints[constraint_id]

            # Evaluate constraint
            satisfied = self._evaluate_dsl_constraint(
                constraint, rendered_elements, model_outputs
            )
            constraint_satisfaction[constraint_id] = satisfied

        return constraint_satisfaction

    def _evaluate_dsl_constraint(
        self,
        constraint: Constraint,
        rendered_elements: dict[str, str],
        model_outputs: list[ModelOutput],
    ) -> bool:
        """Evaluate a DSL constraint with model outputs.

        Parse and evaluate DSL expression with element variables and
        model output values in context.

        Parameters
        ----------
        constraint : Constraint
            Constraint to evaluate.
        rendered_elements : dict[str, str]
            Rendered element text for variable substitution.
        model_outputs : list[ModelOutput]
            Model outputs to include in context.

        Returns
        -------
        bool
            True if constraint is satisfied.

        Raises
        ------
        RuntimeError
            If DSL evaluation fails.
        """
        # Create evaluation context
        context = EvaluationContext()

        # Register standard library
        register_stdlib(context)

        # Register model functions that will use cached outputs
        self._register_model_functions(context, model_outputs)

        # Set element variables
        for name, text in rendered_elements.items():
            context.set_variable(name, text)

        # Parse and evaluate
        try:
            ast_node = parse(constraint.expression)
            result = self._dsl_evaluator.evaluate(ast_node, context)
            return bool(result)
        except Exception as e:
            raise RuntimeError(
                f"Failed to evaluate constraint '{constraint.expression}': {e}"
            ) from e

    def _register_model_functions(
        self,
        context: EvaluationContext,
        model_outputs: list[ModelOutput],
    ) -> None:
        """Register model functions in DSL context.

        Add functions like lm_prob(), nli(), similarity() that can access
        precomputed model outputs from cache.

        Parameters
        ----------
        context : EvaluationContext
            DSL evaluation context.
        model_outputs : list[ModelOutput]
            Precomputed model outputs.
        """
        # Create lookup for model outputs
        output_map: dict[tuple[str, str, str], ModelOutput] = {}
        for output in model_outputs:
            # Key includes model, operation, and stringified inputs
            inputs_str = str(sorted(output.inputs.items()))
            key = (output.model_name, output.operation, inputs_str)
            output_map[key] = output

        # Define model functions that use cached outputs
        def lm_prob(text: str, model: str = "gpt2") -> float:
            """Get log probability from cache or compute."""
            # Check cache first
            cached = self.cache.get(model, "log_probability", text=text)
            if cached is not None:
                return float(cached)

            # Compute if not cached
            adapter = self.model_registry.get_adapter(
                adapter_type="huggingface_lm",
                model_name=model,
                cache=self.cache,
            )
            result = adapter.compute_log_probability(text)
            return result

        def lm_perplexity(text: str, model: str = "gpt2") -> float:
            """Get perplexity from cache or compute."""
            cached = self.cache.get(model, "perplexity", text=text)
            if cached is not None:
                return float(cached)

            adapter = self.model_registry.get_adapter(
                adapter_type="huggingface_lm",
                model_name=model,
                cache=self.cache,
            )
            result = adapter.compute_perplexity(text)
            return result

        def nli(
            premise: str, hypothesis: str, model: str = "roberta-large-mnli"
        ) -> dict[str, float]:
            """Get NLI scores from cache or compute."""
            cached = self.cache.get(
                model, "nli", premise=premise, hypothesis=hypothesis
            )
            if cached is not None:
                return dict(cached)  # type: ignore[arg-type]

            adapter = self.model_registry.get_adapter(
                adapter_type="huggingface_nli",
                model_name=model,
                cache=self.cache,
            )
            result = adapter.compute_nli(premise, hypothesis)
            return result

        def similarity(
            text1: str, text2: str, model: str = "all-MiniLM-L6-v2"
        ) -> float:
            """Get similarity from cache or compute."""
            cached = self.cache.get(model, "similarity", text1=text1, text2=text2)
            if cached is not None:
                return float(cached)

            adapter = self.model_registry.get_adapter(
                adapter_type="sentence_transformer",
                model_name=model,
                cache=self.cache,
            )
            result = adapter.compute_similarity(text1, text2)
            return result

        def embedding(text: str, model: str = "all-MiniLM-L6-v2") -> list[float]:
            """Get embedding from cache or compute."""
            cached = self.cache.get(model, "embedding", text=text)
            if cached is not None:
                # Convert numpy array back to list
                if isinstance(cached, np.ndarray):
                    return cached.tolist()  # type: ignore[return-value]
                return list(cached)  # type: ignore[arg-type]

            adapter = self.model_registry.get_adapter(
                adapter_type="sentence_transformer",
                model_name=model,
                cache=self.cache,
            )
            result = adapter.get_embedding(text)
            return result.tolist()  # type: ignore[return-value]

        # Register functions in context
        context.set_function("lm_prob", lm_prob)
        context.set_function("lm_perplexity", lm_perplexity)
        context.set_function("nli", nli)
        context.set_function("similarity", similarity)
        context.set_function("embedding", embedding)