bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1806 @@
1
+ """Filling strategies for template population."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import random
7
+ import re
8
+ import time
9
+ from abc import ABC, abstractmethod
10
+ from collections.abc import Iterator
11
+ from typing import Literal, cast
12
+ from uuid import UUID
13
+
14
+ from bead.data.language_codes import LanguageCode, validate_iso639_code
15
+ from bead.dsl.evaluator import DSLEvaluator
16
+ from bead.items.item import Item
17
+ from bead.resources.constraints import ContextValue
18
+ from bead.resources.lexical_item import LexicalItem
19
+ from bead.resources.lexicon import Lexicon
20
+ from bead.resources.template import Slot, Template
21
+ from bead.templates.adapters import HuggingFaceMLMAdapter, ModelOutputCache
22
+ from bead.templates.combinatorics import cartesian_product
23
+ from bead.templates.filler import FilledTemplate, TemplateFiller
24
+ from bead.templates.resolver import ConstraintResolver
25
+
26
logger = logging.getLogger(name=__name__)

# Type aliases describing the values a strategy configuration may hold.
ConfigValue = (
    # Scalar options (counts, names, flags) or "not set".
    int | str | bool | None
    # An explicit list of slot indices (e.g. a custom fill order).
    | list[int]
    # Collaborator objects handed to model-based strategies.
    | ConstraintResolver | HuggingFaceMLMAdapter | ModelOutputCache
    # Per-slot settings keyed by slot name.
    | dict[str, int] | dict[str, bool]
)
# A strategy configuration maps option names to such values.
StrategyConfig = dict[str, ConfigValue]
42
+
43
+
44
class FillingStrategy(ABC):
    """Interface shared by all template filling strategies.

    A strategy decides which combinations of lexical items are used to fill
    a template's slots. Concrete strategies vary along:

    - selection (every combination vs. a sample),
    - ordering (deterministic vs. random),
    - grouping (balanced vs. unbalanced across item groups).

    Subclasses must implement ``generate_combinations`` and the ``name``
    property.

    Examples
    --------
    >>> strategy = ExhaustiveStrategy()
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(list(combinations))
    12
    """

    @abstractmethod
    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Produce slot-to-item assignments for a template.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            For each slot name, the items that may fill that slot.

        Returns
        -------
        list[dict[str, LexicalItem]]
            One mapping per filled template, from slot name to chosen item.

        Examples
        --------
        >>> slot_items = {
        ...     "subject": [item1, item2],
        ...     "verb": [item3, item4],
        ... }
        >>> combinations = strategy.generate_combinations(slot_items)
        >>> len(combinations)
        4
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the strategy, recorded in metadata.

        Returns
        -------
        str
            Strategy name.
        """
101
+
102
+
103
class ExhaustiveStrategy(FillingStrategy):
    """Enumerate every possible assignment of items to slots.

    The result is the full Cartesian product of each slot's candidate list,
    so this strategy is only practical for small combinatorial spaces.

    **Warning**: the output grows multiplicatively. With N slots of M
    candidates each, M^N combinations are produced.

    Examples
    --------
    >>> strategy = ExhaustiveStrategy()
    >>> slot_items = {"a": [1, 2], "b": [3, 4]}
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(combinations)
    4
    >>> combinations[0]
    {'a': 1, 'b': 3}
    """

    @property
    def name(self) -> str:
        """Return ``"exhaustive"``, the identifier of this strategy."""
        return "exhaustive"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Build every slot-to-item combination.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            For each slot name, the items that may fill that slot.

        Returns
        -------
        list[dict[str, LexicalItem]]
            All combinations, with slots in the key order of ``slot_items``.
            An empty input yields an empty list.
        """
        if not slot_items:
            return []

        ordered_names = list(slot_items)
        candidate_lists = (slot_items[name] for name in ordered_names)
        return [
            dict(zip(ordered_names, combo, strict=True))
            for combo in cartesian_product(*candidate_lists)
        ]
158
+
159
+
160
class RandomStrategy(FillingStrategy):
    """Generate a random sample of combinations.

    Each combination is drawn independently by picking one item uniformly
    at random for every slot, so the same combination may appear more than
    once. Sampling can be seeded for reproducibility. Use for large
    combinatorial spaces where exhaustive enumeration is impractical.

    Parameters
    ----------
    n_samples : int
        Number of combinations to generate.
    seed : int | None
        Random seed for reproducibility. Default: None (draw from the shared
        module-level ``random`` generator).

    Examples
    --------
    >>> strategy = RandomStrategy(n_samples=10, seed=42)
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(combinations)
    10
    """

    def __init__(self, n_samples: int, seed: int | None = None) -> None:
        """Initialize random strategy.

        Parameters
        ----------
        n_samples : int
            Number of combinations to generate.
        seed : int | None
            Random seed for reproducibility.
        """
        self.n_samples = n_samples
        self.seed = seed

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "random"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate random combinations.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            ``n_samples`` randomly sampled combinations, or an empty list
            when ``slot_items`` is empty or any slot has no valid items.
        """
        if not slot_items:
            return []

        slot_names = list(slot_items.keys())
        item_lists = [slot_items[name] for name in slot_names]

        # A slot without candidates admits no complete combination (the same
        # outcome the exhaustive strategy produces); random.choice would
        # raise IndexError on it.
        if not all(item_lists):
            return []

        # With a seed, draw from a private generator: Random(seed) yields the
        # same sequence random.seed(seed) would, but without resetting the
        # process-wide RNG that unrelated code may rely on.
        choose = (
            random.Random(self.seed).choice
            if self.seed is not None
            else random.choice
        )

        # Slots are drawn in key order for every sample, matching the
        # original draw sequence for a given seed.
        return [
            {
                name: choose(items)
                for name, items in zip(slot_names, item_lists, strict=True)
            }
            for _ in range(self.n_samples)
        ]
234
+
235
+
236
class StratifiedStrategy(FillingStrategy):
    """Generate a sample balanced across item groups.

    The items of each slot are partitioned by ``grouping_property`` (e.g. by
    POS or a feature value). For every combination and every slot, a group
    is picked uniformly at random and then an item uniformly from within
    that group. Small groups are therefore sampled as often as large ones,
    rather than in proportion to their size.

    Parameters
    ----------
    n_samples : int
        Total number of combinations to generate.
    grouping_property : str
        Dotted path of the property to group items by (e.g. ``"pos"`` or
        ``"features.transitivity"``).
    seed : int | None
        Random seed for reproducibility. Default: None.

    Examples
    --------
    >>> strategy = StratifiedStrategy(
    ...     n_samples=20,
    ...     grouping_property="pos",
    ...     seed=42
    ... )
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> # Each POS group of a slot is equally likely to be drawn
    """

    def __init__(
        self,
        n_samples: int,
        grouping_property: str,
        seed: int | None = None,
    ) -> None:
        """Initialize stratified strategy.

        Parameters
        ----------
        n_samples : int
            Total number of combinations to generate.
        grouping_property : str
            Property to group items by.
        seed : int | None
            Random seed for reproducibility.
        """
        self.n_samples = n_samples
        self.grouping_property = grouping_property
        self.seed = seed

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "stratified"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate stratified combinations.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            ``n_samples`` combinations balanced across groups, or an empty
            list when ``slot_items`` is empty or any slot has no items.
        """
        if not slot_items:
            return []

        # Partition each slot's items by the grouping property; groups keep
        # first-seen order.
        grouped_items: dict[str, dict[str, list[LexicalItem]]] = {}
        for slot_name, items in slot_items.items():
            slot_groups: dict[str, list[LexicalItem]] = {}
            for item in items:
                group_key = self._get_property_value(item, self.grouping_property)
                slot_groups.setdefault(group_key, []).append(item)
            grouped_items[slot_name] = slot_groups

        # A slot with no items admits no complete combination. Previously
        # such a slot was silently left out of every combination.
        if not all(grouped_items.values()):
            return []

        # With a seed, draw from a private generator: Random(seed) yields the
        # same sequence random.seed(seed) would, without resetting the
        # process-wide RNG.
        choose = (
            random.Random(self.seed).choice
            if self.seed is not None
            else random.choice
        )

        slot_names = list(slot_items.keys())
        combinations: list[dict[str, LexicalItem]] = []
        for _ in range(self.n_samples):
            combo_dict: dict[str, LexicalItem] = {}
            for slot_name in slot_names:
                slot_groups = grouped_items[slot_name]
                # Choose a group uniformly, then an item from that group.
                group_key = choose(list(slot_groups.keys()))
                combo_dict[slot_name] = choose(slot_groups[group_key])
            combinations.append(combo_dict)

        return combinations

    def _get_property_value(self, item: LexicalItem, property_path: str) -> str:
        """Get property value from item, handling nested properties.

        Each dotted segment is resolved as a mapping key when the current
        value is a mapping containing that key (e.g. a ``features`` dict).
        Otherwise it is resolved as an attribute.

        Parameters
        ----------
        item : LexicalItem
            Item to get property from.
        property_path : str
            Property path (e.g., "pos" or "features.transitivity").

        Returns
        -------
        str
            Property value as string. Returns "unknown" if a segment cannot be
            resolved, and "none" if the resolved value is None.
        """
        from collections.abc import Mapping

        value: object = item
        for part in property_path.split("."):
            if isinstance(value, Mapping) and part in value:
                value = value[part]
            elif hasattr(value, part):
                value = getattr(value, part)
            else:
                return "unknown"

        # Convert to string for grouping
        if value is None:
            return "none"
        return str(value)
369
+
370
+
371
+ class MLMFillingStrategy(FillingStrategy):
372
+ """Fill templates using masked language models with beam search.
373
+
374
+ Uses pre-trained MLMs (BERT, RoBERTa, etc.) to propose linguistically
375
+ plausible slot fillers. Supports beam search for multiple slots with
376
+ configurable fill directions.
377
+
378
+ Parameters
379
+ ----------
380
+ resolver : ConstraintResolver
381
+ Constraint resolver for filtering candidates
382
+ model_adapter : HuggingFaceMLMAdapter
383
+ Loaded MLM adapter
384
+ beam_size : int
385
+ Beam search width (K best hypotheses)
386
+ fill_direction : Literal
387
+ Direction for filling slots. One of: "left_to_right", "right_to_left",
388
+ "inside_out", "outside_in", "custom"
389
+ custom_order : list[int] | None
390
+ Custom slot fill order (slot indices)
391
+ top_k : int
392
+ Top-K candidates per slot from MLM
393
+ cache : ModelOutputCache | None
394
+ Cache for model predictions
395
+ budget : int | None
396
+ Maximum combinations to generate
397
+
398
+ Examples
399
+ --------
400
+ >>> from bead.templates.adapters import HuggingFaceMLMAdapter, ModelOutputCache
401
+ >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased")
402
+ >>> adapter.load_model()
403
+ >>> cache = ModelOutputCache(Path("/tmp/cache"))
404
+ >>> strategy = MLMFillingStrategy(
405
+ ... resolver=resolver,
406
+ ... model_adapter=adapter,
407
+ ... beam_size=5,
408
+ ... fill_direction="left_to_right",
409
+ ... cache=cache
410
+ ... )
411
+ >>> combinations = strategy.generate_combinations(slot_items)
412
+ """
413
+
414
+ def __init__(
415
+ self,
416
+ resolver: ConstraintResolver,
417
+ model_adapter: HuggingFaceMLMAdapter,
418
+ beam_size: int = 5,
419
+ fill_direction: Literal[
420
+ "left_to_right", "right_to_left", "inside_out", "outside_in", "custom"
421
+ ] = "left_to_right",
422
+ custom_order: list[int] | None = None,
423
+ top_k: int = 20,
424
+ cache: ModelOutputCache | None = None,
425
+ budget: int | None = None,
426
+ per_slot_max_fills: dict[str, int] | None = None,
427
+ per_slot_enforce_unique: dict[str, bool] | None = None,
428
+ ) -> None:
429
+ """Initialize MLM strategy.
430
+
431
+ Parameters
432
+ ----------
433
+ resolver : ConstraintResolver
434
+ Constraint resolver
435
+ model_adapter : HuggingFaceMLMAdapter
436
+ MLM adapter (must be loaded)
437
+ beam_size : int
438
+ Beam width
439
+ fill_direction : str
440
+ Fill direction
441
+ custom_order : list[int] | None
442
+ Custom fill order
443
+ top_k : int
444
+ Top-K from MLM
445
+ cache : ModelOutputCache | None
446
+ Prediction cache
447
+ budget : int | None
448
+ Max combinations
449
+ per_slot_max_fills : dict[str, int] | None
450
+ Maximum number of unique fills per slot (after constraint filtering)
451
+ per_slot_enforce_unique : dict[str, bool] | None
452
+ Whether to enforce uniqueness for each slot across beam hypotheses
453
+ """
454
+ self.resolver = resolver
455
+ self.model_adapter = model_adapter
456
+ self.beam_size = beam_size
457
+ self.fill_direction = fill_direction
458
+ self.custom_order = custom_order
459
+ self.top_k = top_k
460
+ self.cache = cache
461
+ self.budget = budget
462
+ self.per_slot_max_fills = per_slot_max_fills or {}
463
+ self.per_slot_enforce_unique = per_slot_enforce_unique or {}
464
+
465
+ if not model_adapter.is_loaded():
466
+ raise ValueError("Model adapter must be loaded before use")
467
+
468
+ if fill_direction == "custom" and custom_order is None:
469
+ raise ValueError("custom_order required when fill_direction is 'custom'")
470
+
471
+ @property
472
+ def name(self) -> str:
473
+ """Get strategy name."""
474
+ return "mlm"
475
+
476
+ def generate_combinations(
477
+ self,
478
+ slot_items: dict[str, list[LexicalItem]],
479
+ ) -> list[dict[str, LexicalItem]]:
480
+ """Generate combinations using MLM beam search.
481
+
482
+ Note: This method adapts slot_items to template-based generation.
483
+ The actual beam search is implemented in generate_from_template.
484
+
485
+ Parameters
486
+ ----------
487
+ slot_items : dict[str, list[LexicalItem]]
488
+ Mapping of slot names to valid items (for constraint filtering)
489
+
490
+ Returns
491
+ -------
492
+ list[dict[str, LexicalItem]]
493
+ Combinations generated via beam search
494
+
495
+ Raises
496
+ ------
497
+ NotImplementedError
498
+ This method requires template context. Use generate_from_template instead.
499
+ """
500
+ raise NotImplementedError(
501
+ "MLMFillingStrategy requires template context. "
502
+ "Use TemplateFiller with MLMFillingStrategy, which calls "
503
+ "generate_from_template internally."
504
+ )
505
+
506
def generate_from_template(
    self,
    template: Template,
    lexicons: list[Lexicon],
    language_code: LanguageCode | None = None,
) -> Iterator[dict[str, LexicalItem]]:
    """Generate combinations from a template using MLM beam search.

    Slots are filled one per step, in the order returned by
    ``_get_fill_order``. At each step, the following happens:

    - Every hypothesis in the beam is rendered as masked text.
    - All texts are scored with a single batched model call.
    - Predictions are filtered into candidate items.
    - The expanded hypotheses are pruned back to ``beam_size``.

    Parameters
    ----------
    template : Template
        Template to fill.
    lexicons : list[Lexicon]
        Lexicons for constraint resolution.
    language_code : LanguageCode | None
        Language filter.

    Yields
    ------
    dict[str, LexicalItem]
        Slot-to-item mappings from the final beam, in ranked order. When
        ``budget`` is not None, at most ``budget`` mappings are yielded
        (``budget=0`` yields none). Nothing is yielded if the template has
        no slots or a step produces no valid candidates.
    """
    logger.info(
        "[MLMFillingStrategy] Starting beam search for template: %s", template.name
    )

    slot_names = list(template.slots.keys())
    if not slot_names:
        return

    fill_order = self._get_fill_order(len(slot_names))
    logger.info(
        "[MLMFillingStrategy] Slots to fill (%d): %s", len(slot_names), slot_names
    )
    logger.info(
        "[MLMFillingStrategy] Fill order: %s", [slot_names[i] for i in fill_order]
    )
    logger.info("[MLMFillingStrategy] Beam size: %s", self.beam_size)

    # Each hypothesis is (filled_slots, cumulative_log_prob); start from the
    # single empty hypothesis.
    beam: list[tuple[dict[str, LexicalItem], float]] = [({}, 0.0)]

    # Items already used per slot; consulted only for slots whose
    # per_slot_enforce_unique flag is set.
    seen_items_per_slot: dict[str, set[UUID]] = {
        slot_name: set() for slot_name in slot_names
    }

    # perf_counter is monotonic, so elapsed times are immune to clock changes.
    beam_start = time.perf_counter()
    for step_num, slot_idx in enumerate(fill_order, 1):
        step_start = time.perf_counter()
        slot_name = slot_names[slot_idx]
        slot = template.slots[slot_name]
        logger.info(
            "[MLMFillingStrategy] Step %d/%d: Filling slot '%s', "
            "current beam size: %d",
            step_num,
            len(fill_order),
            slot_name,
            len(beam),
        )

        enforce_unique = self.per_slot_enforce_unique.get(slot_name, False)
        max_fills = self.per_slot_max_fills.get(slot_name)
        logger.info(
            "[MLMFillingStrategy] enforce_unique=%s, max_fills=%s",
            enforce_unique,
            max_fills,
        )
        # One set shared by every hypothesis of this step, so an item used for
        # this slot in one hypothesis is excluded from the others.
        seen = seen_items_per_slot[slot_name] if enforce_unique else None

        new_beam: list[tuple[dict[str, LexicalItem], float]] = []
        if beam:
            # Render all hypotheses first so the model is called once per step.
            masked_start = time.perf_counter()
            masked_texts = [
                self._create_masked_text(template, slot_names, filled_slots, slot_idx)
                for filled_slots, _ in beam
            ]
            masked_elapsed = time.perf_counter() - masked_start

            logger.info(
                "[MLMFillingStrategy] Batch predicting for %d hypotheses...",
                len(masked_texts),
            )
            batch_start = time.perf_counter()
            predictions_batch = self._get_mlm_predictions_batch(masked_texts)
            batch_elapsed = time.perf_counter() - batch_start
            logger.info(
                "[MLMFillingStrategy] Batch prediction took %.2fs "
                "(masking took %.3fs)",
                batch_elapsed,
                masked_elapsed,
            )

            # Expand each hypothesis with its filtered candidates.
            expand_start = time.perf_counter()
            total_candidates = 0
            for (filled_slots, cum_log_prob), predictions in zip(
                beam, predictions_batch, strict=True
            ):
                candidates = self._filter_mlm_predictions(
                    predictions,
                    slot,
                    lexicons,
                    language_code,
                    seen_items=seen,
                    max_fills=max_fills,
                )
                total_candidates += len(candidates)

                for item, log_prob in candidates:
                    new_filled = filled_slots.copy()
                    new_filled[slot_name] = item
                    new_beam.append((new_filled, cum_log_prob + log_prob))
                    # NOTE(review): items are marked before pruning, so an item
                    # can be withheld from later hypotheses even if the one that
                    # used it is pruned below.
                    if seen is not None:
                        seen.add(item.id)
            expand_elapsed = time.perf_counter() - expand_start
            logger.info(
                "[MLMFillingStrategy] Expanded beam with %d total candidates "
                "in %.3fs",
                total_candidates,
                expand_elapsed,
            )

        if not new_beam:
            logger.warning(
                "[MLMFillingStrategy] No valid candidates found! Beam is empty."
            )
            beam = []
            break

        # Prune to the best beam_size hypotheses. Every hypothesis at this step
        # has the same number of filled slots, so the length normalization does
        # not change the ranking; it only puts scores on a per-slot scale.
        prune_start = time.perf_counter()
        num_filled = len(new_beam[0][0])
        scored_beam = [
            (filled, log_prob / num_filled, log_prob)
            for filled, log_prob in new_beam
        ]
        scored_beam.sort(key=lambda entry: entry[1], reverse=True)
        beam = [
            (filled, cum_log_prob)
            for filled, _, cum_log_prob in scored_beam[: self.beam_size]
        ]
        prune_elapsed = time.perf_counter() - prune_start
        logger.info(
            "[MLMFillingStrategy] Pruned %d hypotheses to %d in %.3fs",
            len(new_beam),
            len(beam),
            prune_elapsed,
        )

        step_elapsed = time.perf_counter() - step_start
        logger.info(
            "[MLMFillingStrategy] Step %d completed in %.2fs\n",
            step_num,
            step_elapsed,
        )

    beam_elapsed = time.perf_counter() - beam_start
    logger.info(
        "[MLMFillingStrategy] Beam search complete in %.2fs, yielding %d hypotheses",
        beam_elapsed,
        len(beam),
    )

    # budget=None means "no limit". Otherwise the budget caps the output, and
    # 0 (or a negative value) yields nothing. Previously 0 meant unlimited
    # because of a truthiness test.
    final_beam = beam if self.budget is None else beam[: max(self.budget, 0)]
    for filled_slots, _ in final_beam:
        yield filled_slots
676
+
677
+ def _get_fill_order(self, num_slots: int) -> list[int]:
678
+ """Get slot fill order based on fill_direction.
679
+
680
+ Parameters
681
+ ----------
682
+ num_slots : int
683
+ Number of slots
684
+
685
+ Returns
686
+ -------
687
+ list[int]
688
+ Slot indices in fill order
689
+ """
690
+ if self.fill_direction == "custom":
691
+ if self.custom_order is None:
692
+ raise ValueError("custom_order not set")
693
+ return self.custom_order
694
+
695
+ indices = list(range(num_slots))
696
+
697
+ if self.fill_direction == "left_to_right":
698
+ return indices
699
+ elif self.fill_direction == "right_to_left":
700
+ return list(reversed(indices))
701
+ elif self.fill_direction == "inside_out":
702
+ # Alternate from center outward
703
+ mid = num_slots // 2
704
+ order: list[int] = []
705
+ for i in range(num_slots):
706
+ if i % 2 == 0:
707
+ order.append(mid + i // 2)
708
+ else:
709
+ order.append(mid - (i + 1) // 2)
710
+ return [idx for idx in order if 0 <= idx < num_slots]
711
+ elif self.fill_direction == "outside_in":
712
+ # Alternate from edges inward
713
+ order: list[int] = []
714
+ left, right = 0, num_slots - 1
715
+ while left <= right:
716
+ order.append(left)
717
+ if left != right:
718
+ order.append(right)
719
+ left += 1
720
+ right -= 1
721
+ return order
722
+ else:
723
+ raise ValueError(f"Unknown fill_direction: {self.fill_direction}")
724
+
725
+ def _get_mlm_candidates(
726
+ self,
727
+ template: Template,
728
+ slot_names: list[str],
729
+ slot_idx: int,
730
+ filled_slots: dict[str, LexicalItem],
731
+ slot: Slot,
732
+ lexicons: list[Lexicon],
733
+ language_code: LanguageCode | None,
734
+ seen_items: set[UUID] | None = None,
735
+ max_fills: int | None = None,
736
+ ) -> list[tuple[LexicalItem, float]]:
737
+ """Get MLM candidates for a slot.
738
+
739
+ Parameters
740
+ ----------
741
+ template : Template
742
+ Template being filled
743
+ slot_names : list[str]
744
+ Ordered slot names
745
+ slot_idx : int
746
+ Index of slot to fill
747
+ filled_slots : dict[str, LexicalItem]
748
+ Already-filled slots
749
+ slot : Slot
750
+ Slot object
751
+ lexicons : list[Lexicon]
752
+ Lexicons for lookup
753
+ language_code : LanguageCode | None
754
+ Language filter
755
+ seen_items : set | None
756
+ Set of item IDs already used for this slot (for uniqueness enforcement)
757
+ max_fills : int | None
758
+ Maximum number of candidates to return (applied after filtering)
759
+
760
+ Returns
761
+ -------
762
+ list[tuple[LexicalItem, float]]
763
+ (item, log_prob) pairs, limited by max_fills and uniqueness
764
+ """
765
+ # Normalize language code to ISO 639-3
766
+ if language_code is not None:
767
+ language_code = validate_iso639_code(language_code)
768
+
769
+ # Create masked text
770
+ masked_text = self._create_masked_text(
771
+ template, slot_names, filled_slots, slot_idx
772
+ )
773
+
774
+ # Get predictions from MLM (with cache)
775
+ if self.cache:
776
+ predictions = self.cache.get(
777
+ self.model_adapter.model_name,
778
+ masked_text,
779
+ 0, # First mask position
780
+ self.top_k,
781
+ )
782
+ else:
783
+ predictions = None
784
+
785
+ if predictions is None:
786
+ predictions = self.model_adapter.predict_masked_token(
787
+ masked_text,
788
+ mask_position=0,
789
+ top_k=self.top_k,
790
+ )
791
+ if self.cache:
792
+ self.cache.set(
793
+ self.model_adapter.model_name,
794
+ masked_text,
795
+ 0,
796
+ self.top_k,
797
+ predictions,
798
+ )
799
+
800
+ # Filter by constraints and find matching lexical items
801
+ candidates: list[tuple[LexicalItem, float]] = []
802
+ for token, log_prob in predictions:
803
+ # Find matching items in lexicons
804
+ for lexicon in lexicons:
805
+ for item in lexicon.items.values():
806
+ # Skip if already seen (uniqueness enforcement)
807
+ if seen_items is not None and item.id in seen_items:
808
+ continue
809
+
810
+ # Match lemma and language
811
+ if item.lemma.lower() == token.lower():
812
+ if language_code is None or item.language_code == language_code:
813
+ # Check slot constraints
814
+ if slot.constraints:
815
+ # Evaluate constraints using resolver
816
+ if self.resolver.evaluate_slot_constraints(
817
+ item, slot.constraints
818
+ ):
819
+ candidates.append((item, log_prob))
820
+ else:
821
+ candidates.append((item, log_prob))
822
+
823
+ # Apply max_fills limit (take top-N by log probability)
824
+ if max_fills is not None and len(candidates) > max_fills:
825
+ # Already sorted by log_prob (descending) from MLM predictions
826
+ # But need to ensure we take highest scoring ones
827
+ candidates.sort(key=lambda x: x[1], reverse=True)
828
+ candidates = candidates[:max_fills]
829
+
830
+ return candidates
831
+
832
+ def _get_mlm_predictions_batch(
833
+ self, masked_texts: list[str]
834
+ ) -> list[list[tuple[str, float]]]:
835
+ """Get MLM predictions for a batch of masked texts.
836
+
837
+ Parameters
838
+ ----------
839
+ masked_texts : list[str]
840
+ List of texts with mask tokens
841
+
842
+ Returns
843
+ -------
844
+ list[list[tuple[str, float]]]
845
+ Predictions for each text: list of (token, log_prob) tuples
846
+ """
847
+ cache_start = time.time()
848
+
849
+ # Check cache for each text first
850
+ predictions_batch: list[list[tuple[str, float]] | None] = []
851
+ texts_to_predict: list[int] = [] # Indices needing prediction
852
+
853
+ for i, masked_text in enumerate(masked_texts):
854
+ if self.cache:
855
+ predictions = self.cache.get(
856
+ self.model_adapter.model_name,
857
+ masked_text,
858
+ 0, # First mask position
859
+ self.top_k,
860
+ )
861
+ else:
862
+ predictions = None
863
+
864
+ predictions_batch.append(predictions)
865
+ if predictions is None:
866
+ texts_to_predict.append(i)
867
+
868
+ cache_elapsed = time.time() - cache_start
869
+ cache_hits = len(masked_texts) - len(texts_to_predict)
870
+ logger.info(
871
+ f"[MLMFillingStrategy] Cache: {cache_hits}/"
872
+ f"{len(masked_texts)} hits in {cache_elapsed:.3f}s"
873
+ )
874
+
875
+ # Batch predict uncached texts
876
+ if texts_to_predict:
877
+ logger.info(
878
+ f"[MLMFillingStrategy] Calling model for "
879
+ f"{len(texts_to_predict)} uncached texts..."
880
+ )
881
+ model_start = time.time()
882
+ uncached_texts = [masked_texts[i] for i in texts_to_predict]
883
+ new_predictions = self.model_adapter.predict_masked_token_batch(
884
+ uncached_texts,
885
+ mask_position=0,
886
+ top_k=self.top_k,
887
+ )
888
+ model_elapsed = time.time() - model_start
889
+ per_text = model_elapsed / len(texts_to_predict)
890
+ logger.info(
891
+ f"[MLMFillingStrategy] Model inference took "
892
+ f"{model_elapsed:.2f}s ({per_text:.3f}s per text)"
893
+ )
894
+
895
+ # Fill in predictions and cache them
896
+ cache_write_start = time.time()
897
+ for idx, predictions in zip(texts_to_predict, new_predictions, strict=True):
898
+ predictions_batch[idx] = predictions
899
+ if self.cache:
900
+ self.cache.set(
901
+ self.model_adapter.model_name,
902
+ masked_texts[idx],
903
+ 0,
904
+ self.top_k,
905
+ predictions,
906
+ )
907
+ cache_write_elapsed = time.time() - cache_write_start
908
+ logger.info(
909
+ f"[MLMFillingStrategy] Cache writes took {cache_write_elapsed:.3f}s"
910
+ )
911
+
912
+ # Convert None to empty list (shouldn't happen but for type safety)
913
+ return [p if p is not None else [] for p in predictions_batch]
914
+
915
+ def _filter_mlm_predictions(
916
+ self,
917
+ predictions: list[tuple[str, float]],
918
+ slot: Slot,
919
+ lexicons: list[Lexicon],
920
+ language_code: LanguageCode | None,
921
+ seen_items: set[UUID] | None = None,
922
+ max_fills: int | None = None,
923
+ ) -> list[tuple[LexicalItem, float]]:
924
+ """Filter MLM predictions to valid lexical items.
925
+
926
+ Parameters
927
+ ----------
928
+ predictions : list[tuple[str, float]]
929
+ Raw (token, log_prob) predictions from MLM
930
+ slot : Slot
931
+ Slot object with constraints
932
+ lexicons : list[Lexicon]
933
+ Lexicons for lookup
934
+ language_code : LanguageCode | None
935
+ Language filter
936
+ seen_items : set[UUID] | None
937
+ Set of item IDs already used (for uniqueness enforcement)
938
+ max_fills : int | None
939
+ Maximum number of candidates to return
940
+
941
+ Returns
942
+ -------
943
+ list[tuple[LexicalItem, float]]
944
+ Filtered (item, log_prob) pairs
945
+ """
946
+ # Normalize language code
947
+ if language_code is not None:
948
+ language_code = validate_iso639_code(language_code)
949
+
950
+ # Filter by constraints and find matching lexical items
951
+ candidates: list[tuple[LexicalItem, float]] = []
952
+ for token, log_prob in predictions:
953
+ # Find matching items in lexicons
954
+ for lexicon in lexicons:
955
+ for item in lexicon.items.values():
956
+ # Skip if already seen (uniqueness enforcement)
957
+ if seen_items is not None and item.id in seen_items:
958
+ continue
959
+
960
+ # Match lemma and language
961
+ if item.lemma.lower() == token.lower():
962
+ if language_code is None or item.language_code == language_code:
963
+ # Check slot constraints
964
+ if slot.constraints:
965
+ # Evaluate constraints using resolver
966
+ if self.resolver.evaluate_slot_constraints(
967
+ item, slot.constraints
968
+ ):
969
+ candidates.append((item, log_prob))
970
+ else:
971
+ candidates.append((item, log_prob))
972
+
973
+ # Apply max_fills limit (take top-N by log probability)
974
+ if max_fills is not None and len(candidates) > max_fills:
975
+ # Already sorted by log_prob (descending) from MLM predictions
976
+ # But need to ensure we take highest scoring ones
977
+ candidates.sort(key=lambda x: x[1], reverse=True)
978
+ candidates = candidates[:max_fills]
979
+
980
+ return candidates
981
+
982
+ def _create_masked_text(
983
+ self,
984
+ template: Template,
985
+ slot_names: list[str],
986
+ filled_slots: dict[str, LexicalItem],
987
+ current_slot_idx: int,
988
+ ) -> str:
989
+ """Create text with mask token for current slot.
990
+
991
+ Parameters
992
+ ----------
993
+ template : Template
994
+ Template
995
+ slot_names : list[str]
996
+ Slot names
997
+ filled_slots : dict[str, LexicalItem]
998
+ Filled slots
999
+ current_slot_idx : int
1000
+ Current slot index
1001
+
1002
+ Returns
1003
+ -------
1004
+ str
1005
+ Text with [MASK] token
1006
+ """
1007
+ mask_token = self.model_adapter.get_mask_token()
1008
+ text = template.template_string
1009
+
1010
+ # Replace filled slots with lemmas
1011
+ for slot_name, item in filled_slots.items():
1012
+ placeholder = f"{{{slot_name}}}"
1013
+ text = text.replace(placeholder, item.lemma)
1014
+
1015
+ # Replace current slot with mask
1016
+ current_slot_name = slot_names[current_slot_idx]
1017
+ current_placeholder = f"{{{current_slot_name}}}"
1018
+ text = text.replace(current_placeholder, mask_token)
1019
+
1020
+ # Replace remaining unfilled slots with mask for context
1021
+ for slot_name in slot_names:
1022
+ placeholder = f"{{{slot_name}}}"
1023
+ if placeholder in text:
1024
+ text = text.replace(placeholder, mask_token)
1025
+
1026
+ return text
1027
+
1028
+
1029
class StrategyFiller(TemplateFiller):
    """Strategy-based template filling for simple templates.

    Fast filling using enumeration strategies (exhaustive, random, stratified).
    Only per-slot constraints (``Slot.constraints``) are applied; template-level
    multi-slot constraints (``Template.constraints``) are NOT handled.

    For templates with multi-slot constraints requiring agreement or
    relational checks, use CSPFiller instead.

    Parameters
    ----------
    lexicon : Lexicon
        Lexicon containing candidate items.
    strategy : FillingStrategy
        Strategy for generating combinations.

    Examples
    --------
    >>> from bead.templates.strategies import StrategyFiller, ExhaustiveStrategy
    >>> filler = StrategyFiller(lexicon, ExhaustiveStrategy())
    >>> filled = filler.fill(template)
    >>> len(filled)
    12
    """

    def __init__(self, lexicon: Lexicon, strategy: FillingStrategy) -> None:
        self.lexicon = lexicon
        self.strategy = strategy
        # NOTE(review): the methods below evaluate slot constraints with
        # DSLEvaluator directly and never use this resolver; it is kept for
        # backward compatibility with code that may read the attribute.
        self.resolver = ConstraintResolver()

    def fill(
        self,
        template: Template,
        language_code: LanguageCode | None = None,
    ) -> list[FilledTemplate]:
        """Fill template with lexical items using strategy.

        Parameters
        ----------
        template : Template
            Template to fill.
        language_code : LanguageCode | None
            Optional language code to filter items.

        Returns
        -------
        list[FilledTemplate]
            One filled template per combination produced by the strategy.

        Raises
        ------
        ValueError
            If any slot has no valid items.
        """
        # 1. Resolve slot constraints
        slot_items = self._resolve_slot_constraints(template, language_code)

        # 2. Refuse to fill when some slot has no candidates at all.
        empty_slots = [name for name, items in slot_items.items() if not items]
        if empty_slots:
            raise ValueError(f"No valid items for slots: {empty_slots}")

        # 3. Generate combinations using strategy
        combinations = self.strategy.generate_combinations(slot_items)

        # 4. Create FilledTemplate instances
        return [
            FilledTemplate(
                template_id=str(template.id),
                template_name=template.name,
                slot_fillers=combo,
                rendered_text=self._render_template(template, combo),
                strategy_name=self.strategy.name,
            )
            for combo in combinations
        ]

    def _resolve_slot_constraints(
        self,
        template: Template,
        language_code: LanguageCode | None,
    ) -> dict[str, list[LexicalItem]]:
        """Resolve constraints for each slot.

        Parameters
        ----------
        template : Template
            Template with slots and constraints.
        language_code : LanguageCode | None
            Optional language filter (normalized with
            ``validate_iso639_code``).

        Returns
        -------
        dict[str, list[LexicalItem]]
            Mapping of slot names to valid items. Each slot gets its own
            list object.
        """
        normalized_lang = validate_iso639_code(language_code) if language_code else None

        # The language filter is slot-independent, so apply it only once.
        pool = list(self.lexicon.items.values())
        if normalized_lang:
            pool = [item for item in pool if item.language_code == normalized_lang]

        slot_items: dict[str, list[LexicalItem]] = {}
        for slot_name, slot in template.slots.items():
            if slot.constraints:
                # One evaluator per slot rather than one per item/constraint.
                # NOTE(review): assumes DSLEvaluator.evaluate keeps no state
                # between calls; the original built a fresh one every time.
                evaluator = DSLEvaluator()
                slot_items[slot_name] = [
                    item
                    for item in pool
                    if self._passes_slot_constraints(item, slot, evaluator)
                ]
            else:
                # Copy so slots never share (and co-mutate) one list.
                slot_items[slot_name] = list(pool)

        return slot_items

    @staticmethod
    def _passes_slot_constraints(
        item: LexicalItem, slot: Slot, evaluator: DSLEvaluator
    ) -> bool:
        """Return whether ``item`` satisfies every constraint on ``slot``."""
        for constraint in slot.constraints:
            # Fresh context per constraint. In a shared dict, variables from
            # one constraint's context would leak into later constraints, and
            # a "self" key could replace the item under test.
            eval_context: dict[
                str, ContextValue | LexicalItem | FilledTemplate | Item
            ] = {"self": item}
            if constraint.context:
                eval_context.update(constraint.context)
            if not evaluator.evaluate(constraint.expression, eval_context):
                return False
        return True

    def _render_template(
        self, template: Template, slot_fillers: dict[str, LexicalItem]
    ) -> str:
        """Render template string with slot fillers.

        Parameters
        ----------
        template : Template
            Template with template_string.
        slot_fillers : dict[str, LexicalItem]
            Items filling each slot; each ``{slot_name}`` placeholder is
            replaced by the item's lemma.

        Returns
        -------
        str
            Rendered template string.
        """
        rendered = template.template_string
        for slot_name, item in slot_fillers.items():
            rendered = rendered.replace(f"{{{slot_name}}}", item.lemma)
        return rendered

    def count_combinations(
        self,
        template: Template,
        language_code: LanguageCode | None = None,
    ) -> int:
        """Count total possible combinations for template.

        Parameters
        ----------
        template : Template
            Template to count combinations for.
        language_code : LanguageCode | None
            Optional language filter, applied exactly as in ``fill`` so the
            count matches a language-filtered fill. Defaults to no filter.

        Returns
        -------
        int
            Product of the per-slot candidate counts (0 when the template
            has no slots).
        """
        slot_items = self._resolve_slot_constraints(template, language_code)

        if not slot_items:
            return 0

        count = 1
        for items in slot_items.values():
            count *= len(items)

        return count
1217
+
1218
+
1219
+ class MixedFillingStrategy(FillingStrategy):
1220
+ """Fill different template slots using different strategies.
1221
+
1222
+ Allows per-slot strategy specification, enabling workflows like:
1223
+ - Fill nouns/verbs exhaustively
1224
+ - Fill adjectives via MLM based on noun context
1225
+
1226
+ This strategy operates in two steps:
1227
+ 1. First pass: Fill slots assigned to non-MLM strategies (exhaustive, random, etc.)
1228
+ 2. Second pass: For each first pass combination, fill remaining slots via MLM
1229
+
1230
+ Parameters
1231
+ ----------
1232
+ slot_strategies : dict[str, tuple[str | FillingStrategy, StrategyConfig]]
1233
+ Mapping of slot names to (strategy, config) tuples.
1234
+ Config is strategy-specific kwargs.
1235
+ default_strategy : FillingStrategy | None
1236
+ Default strategy for slots not explicitly specified.
1237
+
1238
+ Examples
1239
+ --------
1240
+ >>> exhaustive = ExhaustiveStrategy()
1241
+ >>> mlm_config = {
1242
+ ... "resolver": resolver,
1243
+ ... "model_adapter": mlm_adapter,
1244
+ ... "top_k": 5
1245
+ ... }
1246
+ >>> strategy = MixedFillingStrategy(
1247
+ ... slot_strategies={
1248
+ ... "noun": (exhaustive, {}),
1249
+ ... "verb": (exhaustive, {}),
1250
+ ... "adjective": ("mlm", mlm_config)
1251
+ ... }
1252
+ ... )
1253
+ """
1254
+
1255
+ def __init__(
1256
+ self,
1257
+ slot_strategies: dict[str, tuple[str | FillingStrategy, StrategyConfig]],
1258
+ default_strategy: FillingStrategy | None = None,
1259
+ ) -> None:
1260
+ """Initialize mixed strategy.
1261
+
1262
+ Parameters
1263
+ ----------
1264
+ slot_strategies : dict[str, tuple[str | FillingStrategy, StrategyConfig]]
1265
+ Mapping slot names to (strategy_name, config) or
1266
+ (strategy_instance, config). strategy_name can be:
1267
+ "exhaustive", "random", "stratified", "mlm"
1268
+ default_strategy : FillingStrategy | None
1269
+ Default strategy for unspecified slots.
1270
+ """
1271
+ self.slot_strategies = slot_strategies
1272
+ self.default_strategy = default_strategy or ExhaustiveStrategy()
1273
+
1274
+ # Separate slots by strategy type
1275
+ self.non_mlm_slots: list[str] = [] # Non-MLM slots
1276
+ self.mlm_slots: list[str] = [] # MLM slots
1277
+ self.non_mlm_strategies: dict[str, FillingStrategy] = {}
1278
+ self.mlm_configs: dict[str, StrategyConfig] = {}
1279
+
1280
+ for slot_name, (strategy, config) in slot_strategies.items():
1281
+ strategy_name = strategy if isinstance(strategy, str) else strategy.name
1282
+
1283
+ if strategy_name == "mlm":
1284
+ self.mlm_slots.append(slot_name)
1285
+ self.mlm_configs[slot_name] = config
1286
+ else:
1287
+ self.non_mlm_slots.append(slot_name)
1288
+ # Instantiate strategy if needed
1289
+ if isinstance(strategy, str):
1290
+ self.non_mlm_strategies[slot_name] = self._instantiate_strategy(
1291
+ strategy, config
1292
+ )
1293
+ else:
1294
+ self.non_mlm_strategies[slot_name] = strategy
1295
+
1296
+ def _instantiate_strategy(
1297
+ self, strategy_name: str, config: StrategyConfig
1298
+ ) -> FillingStrategy:
1299
+ """Instantiate strategy from name and config.
1300
+
1301
+ Parameters
1302
+ ----------
1303
+ strategy_name : str
1304
+ Strategy name: "exhaustive", "random", "stratified"
1305
+ config : dict
1306
+ Strategy-specific configuration
1307
+
1308
+ Returns
1309
+ -------
1310
+ FillingStrategy
1311
+ Instantiated strategy
1312
+
1313
+ Raises
1314
+ ------
1315
+ ValueError
1316
+ If strategy name is unknown
1317
+ """
1318
+ if strategy_name == "exhaustive":
1319
+ return ExhaustiveStrategy()
1320
+ elif strategy_name == "random":
1321
+ return RandomStrategy(
1322
+ n_samples=cast(int, config.get("n_samples", 100)),
1323
+ seed=cast(int | None, config.get("seed")),
1324
+ )
1325
+ elif strategy_name == "stratified":
1326
+ return StratifiedStrategy(
1327
+ n_samples=cast(int, config.get("n_samples", 100)),
1328
+ grouping_property=cast(str, config.get("grouping_property", "pos")),
1329
+ seed=cast(int | None, config.get("seed")),
1330
+ )
1331
+ else:
1332
+ raise ValueError(f"Unknown strategy: {strategy_name}")
1333
+
1334
@property
def name(self) -> str:
    """Identifier reported for this strategy; always ``"mixed"``."""
    return "mixed"
1338
+
1339
+ def generate_combinations(
1340
+ self,
1341
+ slot_items: dict[str, list[LexicalItem]],
1342
+ ) -> list[dict[str, LexicalItem]]:
1343
+ """Generate combinations using mixed strategies.
1344
+
1345
+ Note: This method signature is required by FillingStrategy,
1346
+ but MixedFillingStrategy with MLM requires template context.
1347
+ Use generate_from_template instead.
1348
+
1349
+ Parameters
1350
+ ----------
1351
+ slot_items : dict[str, list[LexicalItem]]
1352
+ Mapping of slot names to valid items
1353
+
1354
+ Returns
1355
+ -------
1356
+ list[dict[str, LexicalItem]]
1357
+ Generated combinations
1358
+
1359
+ Raises
1360
+ ------
1361
+ NotImplementedError
1362
+ If any slot uses MLM strategy (requires template context)
1363
+ """
1364
+ if self.mlm_slots:
1365
+ raise NotImplementedError(
1366
+ "MixedFillingStrategy with MLM slots requires template context. "
1367
+ "Use StrategyFiller or CSPFiller, which call generate_from_template."
1368
+ )
1369
+
1370
+ # If no MLM slots, just use non-MLM strategies
1371
+ # This is a simplified case: all slots use non-MLM strategies
1372
+
1373
+ # For each slot, generate its combinations independently
1374
+ slot_combinations: dict[str, list[LexicalItem]] = {}
1375
+
1376
+ for slot_name, items in slot_items.items():
1377
+ if slot_name in self.non_mlm_strategies:
1378
+ strategy = self.non_mlm_strategies[slot_name]
1379
+ # Generate combinations for just this slot
1380
+ combos = strategy.generate_combinations({slot_name: items})
1381
+ slot_combinations[slot_name] = [c[slot_name] for c in combos]
1382
+ else:
1383
+ # Use default strategy
1384
+ combos = self.default_strategy.generate_combinations({slot_name: items})
1385
+ slot_combinations[slot_name] = [c[slot_name] for c in combos]
1386
+
1387
+ # Generate cartesian product of all slot combinations
1388
+ slot_names = list(slot_items.keys())
1389
+ item_lists = [slot_combinations[name] for name in slot_names]
1390
+
1391
+ combinations: list[dict[str, LexicalItem]] = []
1392
+ for combo_tuple in cartesian_product(*item_lists):
1393
+ combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
1394
+ combinations.append(combo_dict)
1395
+
1396
+ return combinations
1397
+
1398
def generate_from_template(
    self,
    template: Template,
    lexicons: list[Lexicon],
    language_code: LanguageCode | None = None,
) -> Iterator[dict[str, LexicalItem]]:
    """Yield complete slot fillings for ``template`` in two passes.

    1. The non-MLM slots are filled with their assigned strategies via
       ``_generate_non_mlm_combinations``. If there are no non-MLM slots,
       this pass produces a single empty partial filling.
    2. Without MLM slots, the first-pass fillings are yielded as they are.
       Otherwise each partial filling is completed by ``_fill_mlm_slots``,
       and only completions accepted by ``_check_template_constraints`` are
       yielded.

    Progress and timings are logged at INFO level.

    Parameters
    ----------
    template : Template
        Template to fill.
    lexicons : list[Lexicon]
        Lexicons for constraint resolution.
    language_code : LanguageCode | None
        Language filter.

    Yields
    ------
    dict[str, LexicalItem]
        Complete slot-to-item mappings.
    """
    logger.info(f"[MixedFillingStrategy] Starting template: {template.name}")
    logger.info(f"[MixedFillingStrategy] Non-MLM slots: {self.non_mlm_slots}")
    logger.info(f"[MixedFillingStrategy] MLM slots: {self.mlm_slots}")

    # Pass 1: non-MLM slots (a lone empty partial when there are none).
    pass_one_started = time.time()
    partials: list[dict[str, LexicalItem]] = (
        self._generate_non_mlm_combinations(template, lexicons, language_code)
        if self.non_mlm_slots
        else [{}]
    )
    pass_one_elapsed = time.time() - pass_one_started
    logger.info(
        f"[MixedFillingStrategy] First pass generated "
        f"{len(partials)} combinations in {pass_one_elapsed:.2f}s"
    )

    if not self.mlm_slots:
        logger.info(
            "[MixedFillingStrategy] No MLM slots, yielding first pass combinations"
        )
        yield from partials
        return

    # Pass 2: complete every partial filling with the MLM slots.
    n_partials = len(partials)
    logger.info(
        f"[MixedFillingStrategy] Starting second pass for "
        f"{n_partials} combinations..."
    )
    pass_two_started = time.time()
    grand_total = 0

    for position, partial in enumerate(partials, start=1):
        started = time.time()
        if position == 1:
            # Show the shape of the first partial filling as a debugging aid.
            first_slots = list(partial.keys())
            first_values = {key: filler.lemma for key, filler in partial.items()}
            logger.info(
                f"[MixedFillingStrategy] First combination has "
                f"slots: {first_slots}"
            )
            logger.info(
                f"[MixedFillingStrategy] First combination "
                f"values: {first_values}"
            )
        logger.info(
            f"[MixedFillingStrategy] Processing combination "
            f"{position}/{n_partials}"
        )

        accepted = 0
        for completed in self._fill_mlm_slots(
            template, partial, lexicons, language_code
        ):
            # Template-level constraints can only be judged on full fillings.
            if not self._check_template_constraints(template, completed):
                continue
            accepted += 1
            grand_total += 1
            yield completed

        elapsed = time.time() - started
        logger.info(
            f"[MixedFillingStrategy] Combination {position} yielded "
            f"{accepted} complete fillings in "
            f"{elapsed:.2f}s"
        )

    pass_two_elapsed = time.time() - pass_two_started
    logger.info(
        f"[MixedFillingStrategy] Second pass complete: {grand_total} "
        f"total fillings in {pass_two_elapsed:.2f}s"
    )
1495
+
1496
+ def _generate_non_mlm_combinations(
1497
+ self,
1498
+ template: Template,
1499
+ lexicons: list[Lexicon],
1500
+ language_code: LanguageCode | None,
1501
+ ) -> list[dict[str, LexicalItem]]:
1502
+ """Generate combinations for non-MLM slots.
1503
+
1504
+ Parameters
1505
+ ----------
1506
+ template : Template
1507
+ Template being filled
1508
+ lexicons : list[Lexicon]
1509
+ Lexicons for items
1510
+ language_code : LanguageCode | None
1511
+ Language filter
1512
+
1513
+ Returns
1514
+ -------
1515
+ list[dict[str, LexicalItem]]
1516
+ Partial combinations (only non-MLM slots filled)
1517
+ """
1518
+ # Get valid items for each non-MLM slot
1519
+ slot_items: dict[str, list[LexicalItem]] = {}
1520
+ normalized_lang = validate_iso639_code(language_code) if language_code else None
1521
+
1522
+ for slot_name in self.non_mlm_slots:
1523
+ if slot_name not in template.slots:
1524
+ continue
1525
+
1526
+ slot = template.slots[slot_name]
1527
+ candidates: list[LexicalItem] = []
1528
+
1529
+ # Collect items from all lexicons
1530
+ for lexicon in lexicons:
1531
+ for item in lexicon.items.values():
1532
+ # Filter by language
1533
+ if normalized_lang and item.language_code != normalized_lang:
1534
+ continue
1535
+ # Check slot constraints
1536
+ if slot.constraints:
1537
+ eval_context: dict[str, ContextValue | LexicalItem] = {
1538
+ "self": item
1539
+ }
1540
+ # Check ALL constraints - item must pass every one
1541
+ passes_all_constraints = True
1542
+ for constraint in slot.constraints:
1543
+ if constraint.context:
1544
+ eval_context.update(constraint.context)
1545
+ # Evaluate
1546
+ evaluator = DSLEvaluator()
1547
+ # Cast to expected context type
1548
+ typed_context = cast(
1549
+ dict[
1550
+ str,
1551
+ ContextValue | LexicalItem | FilledTemplate | Item,
1552
+ ],
1553
+ eval_context,
1554
+ )
1555
+ if not evaluator.evaluate(
1556
+ constraint.expression, typed_context
1557
+ ):
1558
+ passes_all_constraints = False
1559
+ break
1560
+
1561
+ # Only add item if it passed ALL constraints
1562
+ if not passes_all_constraints:
1563
+ continue
1564
+
1565
+ candidates.append(item)
1566
+
1567
+ slot_items[slot_name] = candidates
1568
+
1569
+ # Generate combinations using per-slot strategies
1570
+ # For each slot, we need to apply its strategy independently,
1571
+ # then take cartesian product
1572
+
1573
+ # Collect combinations per slot
1574
+ slot_combos: dict[str, list[LexicalItem]] = {}
1575
+
1576
+ for slot_name in self.non_mlm_slots:
1577
+ if slot_name not in slot_items:
1578
+ continue
1579
+
1580
+ items = slot_items[slot_name]
1581
+ strategy = self.non_mlm_strategies.get(slot_name, self.default_strategy)
1582
+
1583
+ # Generate combinations for this slot
1584
+ combos = strategy.generate_combinations({slot_name: items})
1585
+ slot_combos[slot_name] = [c[slot_name] for c in combos]
1586
+
1587
+ # Cartesian product of all non-MLM slots
1588
+ if not slot_combos:
1589
+ return [{}]
1590
+
1591
+ slot_names = list(slot_combos.keys())
1592
+ item_lists = [slot_combos[name] for name in slot_names]
1593
+
1594
+ combinations: list[dict[str, LexicalItem]] = []
1595
+ for combo_tuple in cartesian_product(*item_lists):
1596
+ combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
1597
+ # Filter by template-level constraints
1598
+ if self._check_template_constraints(template, combo_dict):
1599
+ combinations.append(combo_dict)
1600
+
1601
+ return combinations
1602
+
1603
def _fill_mlm_slots(
    self,
    template: Template,
    partial_filling: dict[str, LexicalItem],
    lexicons: list[Lexicon],
    language_code: LanguageCode | None,
) -> Iterator[dict[str, LexicalItem]]:
    """Fill MLM slots given a partial filling from the first pass.

    Strategy-wide settings are read from the config of the first MLM slot.
    These are the resolver, model adapter, ``beam_size`` (default 5),
    ``fill_direction`` (default ``"left_to_right"``), ``custom_order``,
    ``top_k`` (default 20), ``cache`` and ``budget``. ``max_fills`` and
    ``enforce_unique`` are collected per slot from every MLM slot's config.

    Parameters
    ----------
    template : Template
        Template being filled.
    partial_filling : dict[str, LexicalItem]
        Already-filled (non-MLM) slots from the first pass.
    lexicons : list[Lexicon]
        Lexicons passed through to the MLM strategy.
    language_code : LanguageCode | None
        Language filter passed through to the MLM strategy.

    Yields
    ------
    dict[str, LexicalItem]
        Complete fillings: a copy of ``partial_filling`` updated with each
        MLM filling. If there are no MLM slots or no MLM configs,
        ``partial_filling`` itself is yielded once, unchanged.
    """
    if not self.mlm_slots or not self.mlm_configs:
        yield partial_filling
        return

    # Strategy-wide settings come from the first MLM slot's config. The
    # mapping is only read below, so no defensive copy is needed.
    base_config = self.mlm_configs[self.mlm_slots[0]]

    # Per-slot limits come from every MLM slot's config. Slots that omit a
    # key are simply absent from the corresponding dict.
    per_slot_max_fills: dict[str, int] = {}
    per_slot_enforce_unique: dict[str, bool] = {}
    for slot_name in self.mlm_slots:
        slot_config = self.mlm_configs[slot_name]
        if "max_fills" in slot_config:
            per_slot_max_fills[slot_name] = cast(int, slot_config["max_fills"])
        if "enforce_unique" in slot_config:
            per_slot_enforce_unique[slot_name] = cast(
                bool, slot_config["enforce_unique"]
            )

    # Create MLM strategy with properly typed config
    mlm_strategy = MLMFillingStrategy(
        resolver=cast(ConstraintResolver, base_config["resolver"]),
        model_adapter=cast(HuggingFaceMLMAdapter, base_config["model_adapter"]),
        beam_size=cast(int, base_config.get("beam_size", 5)),
        fill_direction=cast(
            Literal[
                "left_to_right",
                "right_to_left",
                "inside_out",
                "outside_in",
                "custom",
            ],
            base_config.get("fill_direction", "left_to_right"),
        ),
        custom_order=cast(list[int] | None, base_config.get("custom_order")),
        top_k=cast(int, base_config.get("top_k", 20)),
        cache=cast(ModelOutputCache | None, base_config.get("cache")),
        budget=cast(int | None, base_config.get("budget")),
        per_slot_max_fills=per_slot_max_fills,
        per_slot_enforce_unique=per_slot_enforce_unique,
    )

    # Template whose non-MLM placeholders are already replaced by text.
    mlm_template = self._create_mlm_template(template, partial_filling)

    # NOTE(review): constraints that the first pass deferred (because they
    # mention MLM slots) are not re-checked here on the combined filling.
    # Confirm that the MLM strategy (which receives the template's
    # constraints) or the caller enforces them.
    for mlm_filling in mlm_strategy.generate_from_template(
        mlm_template, lexicons, language_code
    ):
        # Combine partial + MLM fillings without mutating partial_filling.
        complete = partial_filling.copy()
        complete.update(mlm_filling)
        yield complete
1692
+
1693
def _check_template_constraints(
    self,
    template: Template,
    slot_fillers: dict[str, LexicalItem],
) -> bool:
    """Check if slot fillers satisfy template-level constraints.

    Only constraints whose referenced slots are all present in
    ``slot_fillers`` are evaluated. Constraints that mention a missing slot
    are skipped (deferred). A slot reference is any identifier immediately
    followed by ``.`` (e.g. ``noun`` in ``noun.lemma``) that is not itself
    preceded by ``.``. References are searched for after quoted string
    literals have been blanked out.

    Callers may invoke this for every candidate combination, so diagnostics
    are logged at DEBUG level with lazy %-formatting.

    Parameters
    ----------
    template : Template
        Template with multi-slot constraints.
    slot_fillers : dict[str, LexicalItem]
        Complete or partial slot fillings.

    Returns
    -------
    bool
        True when the template has no constraints or none is evaluable yet.
        Otherwise, the result of ``ConstraintResolver`` evaluating the
        evaluable constraints against ``slot_fillers``.
    """
    logger.debug(
        "[TemplateConstraints] Called with template '%s', "
        "%d constraints, %d fillers",
        template.name,
        len(template.constraints),
        len(slot_fillers),
    )
    if not template.constraints:
        logger.debug("[TemplateConstraints] No constraints, returning True")
        return True

    # Pattern matches "slot_name." but NOT "something.property." (no dot
    # before the identifier).
    slot_pattern = re.compile(r"(?<![.])\b([a-zA-Z_][a-zA-Z0-9_]*)\.")
    filled_slots = set(slot_fillers)

    evaluable_constraints = []
    for constraint in template.constraints:
        # Blank out string literals first so e.g. 'V.PTCP' is not mistaken
        # for a reference to slot 'V'.
        expr_no_strings = re.sub(r"'[^']*'|\"[^\"]*\"", '""', constraint.expression)
        missing = set(slot_pattern.findall(expr_no_strings)) - filled_slots
        if missing:
            logger.debug(
                "[TemplateConstraints] Deferring (missing %s): %s",
                missing,
                constraint.description,
            )
        else:
            evaluable_constraints.append(constraint)
            logger.debug(
                "[TemplateConstraints] Will evaluate: %s", constraint.description
            )

    if not evaluable_constraints:
        return True  # No constraints can be evaluated yet

    logger.debug(
        "[TemplateConstraints] Evaluating %d constraints with %d filled slots",
        len(evaluable_constraints),
        len(filled_slots),
    )
    resolver = ConstraintResolver()
    result = resolver.evaluate_template_constraints(
        slot_fillers, evaluable_constraints
    )
    if not result:
        logger.debug("[TemplateConstraints] Combination REJECTED by constraints")
    return result
1764
+
1765
def _create_mlm_template(
    self, template: Template, partial_filling: dict[str, LexicalItem]
) -> Template:
    """Derive a template whose first-pass slots are already written out.

    Each ``{slot}`` placeholder named in ``partial_filling`` is replaced in
    the template string by that filler's surface form: ``form`` when it is
    set, otherwise ``lemma``. The derived template has the following
    properties:

    - it keeps only the slots listed in ``self.mlm_slots``, in the original
      template's slot order;
    - it reuses the original constraints and language code;
    - it is named ``"<original name>_mlm"``.

    Parameters
    ----------
    template : Template
        Original template.
    partial_filling : dict[str, LexicalItem]
        Items filling the non-MLM slots.

    Returns
    -------
    Template
        New template with the non-MLM placeholders replaced by text.
    """
    text = template.template_string
    for filled_name, filler in partial_filling.items():
        # Prefer the inflected form (e.g. "is") over the lemma ("be").
        surface = filler.lemma if filler.form is None else filler.form
        text = text.replace(f"{{{filled_name}}}", surface)

    # Keep only the slots still to be filled by the MLM pass, preserving
    # the order in which the original template declares them.
    remaining_slots = {}
    for slot_name, slot in template.slots.items():
        if slot_name in self.mlm_slots:
            remaining_slots[slot_name] = slot

    return Template(
        name=f"{template.name}_mlm",
        template_string=text,
        slots=remaining_slots,
        constraints=template.constraints,
        language_code=template.language_code,
    )