bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/templates/strategies.py
@@ -0,0 +1,1806 @@
"""Filling strategies for template population."""

from __future__ import annotations

import logging
import random
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Literal, cast
from uuid import UUID

from bead.data.language_codes import LanguageCode, validate_iso639_code
from bead.dsl.evaluator import DSLEvaluator
from bead.items.item import Item
from bead.resources.constraints import ContextValue
from bead.resources.lexical_item import LexicalItem
from bead.resources.lexicon import Lexicon
from bead.resources.template import Slot, Template
from bead.templates.adapters import HuggingFaceMLMAdapter, ModelOutputCache
from bead.templates.combinatorics import cartesian_product
from bead.templates.filler import FilledTemplate, TemplateFiller
from bead.templates.resolver import ConstraintResolver

logger = logging.getLogger(__name__)

# Type aliases for strategy configuration
ConfigValue = (
    int
    | str
    | bool
    | None
    | list[int]
    | ConstraintResolver
    | HuggingFaceMLMAdapter
    | ModelOutputCache
    | dict[str, int]
    | dict[str, bool]
)
StrategyConfig = dict[str, ConfigValue]


class FillingStrategy(ABC):
    """Abstract base class for template filling strategies.

    A filling strategy determines how to combine lexical items
    to fill template slots. Strategies differ in:
    - Selection criteria (all vs. sample)
    - Ordering (deterministic vs. random)
    - Grouping (balanced vs. unbalanced)

    Examples
    --------
    >>> strategy = ExhaustiveStrategy()
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(list(combinations))
    12
    """

    @abstractmethod
    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate combinations of items for template slots.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to lists of valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            List of slot-to-item mappings representing filled templates.

        Examples
        --------
        >>> slot_items = {
        ...     "subject": [item1, item2],
        ...     "verb": [item3, item4],
        ... }
        >>> combinations = strategy.generate_combinations(slot_items)
        >>> len(combinations)
        4
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Get strategy name for metadata.

        Returns
        -------
        str
            Strategy name.
        """
        pass

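Editorial aside (not part of the packaged file): the two abstract members above, `generate_combinations` and the `name` property, are the whole contract a new strategy must satisfy. A minimal sketch of a custom subclass, using only names defined in this module, might look like this:

class FirstItemStrategy(FillingStrategy):
    """Toy strategy: pair the first listed item of every slot."""

    @property
    def name(self) -> str:
        return "first_item"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        # One combination only; empty input or an empty slot yields nothing.
        if not slot_items or any(not items for items in slot_items.values()):
            return []
        return [{slot: items[0] for slot, items in slot_items.items()}]
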
class ExhaustiveStrategy(FillingStrategy):
    """Generate all possible combinations of slot fillers.

    This strategy produces the complete Cartesian product of all
    valid items for each slot. Use for small combinatorial spaces.

    **Warning**: Combinatorial explosion! With N slots and M items
    per slot, generates M^N combinations.

    Examples
    --------
    >>> strategy = ExhaustiveStrategy()
    >>> slot_items = {"a": [1, 2], "b": [3, 4]}
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(combinations)
    4
    >>> combinations[0]
    {"a": 1, "b": 3}
    """

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "exhaustive"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate all combinations.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            All possible slot-to-item combinations.
        """
        if not slot_items:
            return []

        # Get ordered slot names and item lists
        slot_names = list(slot_items.keys())
        item_lists = [slot_items[name] for name in slot_names]

        # Generate all combinations
        combinations: list[dict[str, LexicalItem]] = []
        for combo_tuple in cartesian_product(*item_lists):
            combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
            combinations.append(combo_dict)

        return combinations

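Editorial aside (not part of the packaged file): a minimal usage sketch of `ExhaustiveStrategy`. Plain strings stand in for `LexicalItem` objects, since `generate_combinations` never inspects item attributes; output ordering follows `cartesian_product`.

strategy = ExhaustiveStrategy()
combos = strategy.generate_combinations(
    {"subject": ["the cat", "a dog"], "verb": ["sleeps", "runs", "barks"]}
)
print(len(combos))  # 6 = 2 subjects x 3 verbs
print(combos[0])    # e.g. {"subject": "the cat", "verb": "sleeps"}
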
class RandomStrategy(FillingStrategy):
    """Generate random sample of combinations.

    Sample combinations randomly with optional seeding for
    reproducibility. Use for large combinatorial spaces.

    Parameters
    ----------
    n_samples : int
        Number of combinations to generate.
    seed : int | None
        Random seed for reproducibility. Default: None.

    Examples
    --------
    >>> strategy = RandomStrategy(n_samples=10, seed=42)
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> len(combinations)
    10
    """

    def __init__(self, n_samples: int, seed: int | None = None) -> None:
        """Initialize random strategy.

        Parameters
        ----------
        n_samples : int
            Number of combinations to generate.
        seed : int | None
            Random seed for reproducibility.
        """
        self.n_samples = n_samples
        self.seed = seed

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "random"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate random combinations.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            Randomly sampled combinations.
        """
        if not slot_items:
            return []

        # Set random seed if provided
        if self.seed is not None:
            random.seed(self.seed)

        # Get ordered slot names and item lists
        slot_names = list(slot_items.keys())
        item_lists = [slot_items[name] for name in slot_names]

        # Generate random combinations
        combinations: list[dict[str, LexicalItem]] = []
        for _ in range(self.n_samples):
            combo_tuple = tuple(random.choice(items) for items in item_lists)
            combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
            combinations.append(combo_dict)

        return combinations

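Editorial aside (not part of the packaged file): a seeded `RandomStrategy` sketch, again with strings standing in for `LexicalItem` objects. Note that seeding goes through the module-level `random.seed`, so a seeded call also resets the interpreter-wide `random` state; an instance-held `random.Random(seed)` would confine reproducibility to the strategy.

strategy = RandomStrategy(n_samples=4, seed=42)
combos = strategy.generate_combinations(
    {"subject": ["the cat", "a dog"], "verb": ["sleeps", "runs", "barks"]}
)
print(len(combos))  # 4 combinations, each slot sampled with replacement
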
class StratifiedStrategy(FillingStrategy):
    """Generate balanced sample across item groups.

    Ensure each group of items (e.g., by POS, features) is
    represented proportionally in the sample.

    Parameters
    ----------
    n_samples : int
        Total number of combinations to generate.
    grouping_property : str
        Property to group items by (e.g., "pos", "features.transitivity").
    seed : int | None
        Random seed for reproducibility. Default: None.

    Examples
    --------
    >>> strategy = StratifiedStrategy(
    ...     n_samples=20,
    ...     grouping_property="pos",
    ...     seed=42
    ... )
    >>> combinations = strategy.generate_combinations(slot_items)
    >>> # Ensures balanced representation of different POS values
    """

    def __init__(
        self,
        n_samples: int,
        grouping_property: str,
        seed: int | None = None,
    ) -> None:
        """Initialize stratified strategy.

        Parameters
        ----------
        n_samples : int
            Total number of combinations to generate.
        grouping_property : str
            Property to group items by.
        seed : int | None
            Random seed for reproducibility.
        """
        self.n_samples = n_samples
        self.grouping_property = grouping_property
        self.seed = seed

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "stratified"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate stratified combinations.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.

        Returns
        -------
        list[dict[str, LexicalItem]]
            Balanced combinations across groups.
        """
        if not slot_items:
            return []

        # Set random seed if provided
        if self.seed is not None:
            random.seed(self.seed)

        # Group items by the specified property
        grouped_items: dict[str, dict[str, list[LexicalItem]]] = {}
        for slot_name, items in slot_items.items():
            slot_groups: dict[str, list[LexicalItem]] = {}
            for item in items:
                # Get property value (handle nested properties)
                value = self._get_property_value(item, self.grouping_property)
                if value not in slot_groups:
                    slot_groups[value] = []
                slot_groups[value].append(item)
            grouped_items[slot_name] = slot_groups

        # Sample proportionally from each group
        combinations: list[dict[str, LexicalItem]] = []
        slot_names = list(slot_items.keys())

        # Calculate samples per group
        # For simplicity, sample equally from all groups
        for _ in range(self.n_samples):
            combo_dict: dict[str, LexicalItem] = {}
            for slot_name in slot_names:
                slot_groups = grouped_items[slot_name]
                # Choose a random group, then a random item from that group
                if slot_groups:
                    group_key = random.choice(list(slot_groups.keys()))
                    item = random.choice(slot_groups[group_key])
                    combo_dict[slot_name] = item
            combinations.append(combo_dict)

        return combinations

    def _get_property_value(self, item: LexicalItem, property_path: str) -> str:
        """Get property value from item, handling nested properties.

        Parameters
        ----------
        item : LexicalItem
            Item to get property from.
        property_path : str
            Property path (e.g., "pos" or "features.transitivity").

        Returns
        -------
        str
            Property value as string.
        """
        parts = property_path.split(".")
        value = item
        for part in parts:
            if hasattr(value, part):
                value = getattr(value, part)
            else:
                return "unknown"

        # Convert to string for grouping
        if value is None:
            return "none"
        return str(value)

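Editorial aside (not part of the packaged file): a `StratifiedStrategy` sketch using a tiny stand-in class instead of `LexicalItem`; the only requirement imposed by `_get_property_value` is that the grouping property is reachable via attribute access.

from dataclasses import dataclass

@dataclass
class FakeItem:  # stand-in: just needs the grouped attribute
    lemma: str
    transitivity: str

verbs = [
    FakeItem("devour", "transitive"),
    FakeItem("eat", "transitive"),
    FakeItem("sleep", "intransitive"),
]
nouns = [FakeItem("cat", "n/a"), FakeItem("dog", "n/a")]

strategy = StratifiedStrategy(n_samples=6, grouping_property="transitivity", seed=0)
combos = strategy.generate_combinations({"subject": nouns, "verb": verbs})
# Each draw picks a random group first, then a random member, so "sleep" is
# sampled roughly as often as the two transitive verbs combined.
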
class MLMFillingStrategy(FillingStrategy):
    """Fill templates using masked language models with beam search.

    Uses pre-trained MLMs (BERT, RoBERTa, etc.) to propose linguistically
    plausible slot fillers. Supports beam search for multiple slots with
    configurable fill directions.

    Parameters
    ----------
    resolver : ConstraintResolver
        Constraint resolver for filtering candidates
    model_adapter : HuggingFaceMLMAdapter
        Loaded MLM adapter
    beam_size : int
        Beam search width (K best hypotheses)
    fill_direction : Literal
        Direction for filling slots. One of: "left_to_right", "right_to_left",
        "inside_out", "outside_in", "custom"
    custom_order : list[int] | None
        Custom slot fill order (slot indices)
    top_k : int
        Top-K candidates per slot from MLM
    cache : ModelOutputCache | None
        Cache for model predictions
    budget : int | None
        Maximum combinations to generate

    Examples
    --------
    >>> from bead.templates.adapters import HuggingFaceMLMAdapter, ModelOutputCache
    >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased")
    >>> adapter.load_model()
    >>> cache = ModelOutputCache(Path("/tmp/cache"))
    >>> strategy = MLMFillingStrategy(
    ...     resolver=resolver,
    ...     model_adapter=adapter,
    ...     beam_size=5,
    ...     fill_direction="left_to_right",
    ...     cache=cache
    ... )
    >>> combinations = strategy.generate_combinations(slot_items)
    """

    def __init__(
        self,
        resolver: ConstraintResolver,
        model_adapter: HuggingFaceMLMAdapter,
        beam_size: int = 5,
        fill_direction: Literal[
            "left_to_right", "right_to_left", "inside_out", "outside_in", "custom"
        ] = "left_to_right",
        custom_order: list[int] | None = None,
        top_k: int = 20,
        cache: ModelOutputCache | None = None,
        budget: int | None = None,
        per_slot_max_fills: dict[str, int] | None = None,
        per_slot_enforce_unique: dict[str, bool] | None = None,
    ) -> None:
        """Initialize MLM strategy.

        Parameters
        ----------
        resolver : ConstraintResolver
            Constraint resolver
        model_adapter : HuggingFaceMLMAdapter
            MLM adapter (must be loaded)
        beam_size : int
            Beam width
        fill_direction : str
            Fill direction
        custom_order : list[int] | None
            Custom fill order
        top_k : int
            Top-K from MLM
        cache : ModelOutputCache | None
            Prediction cache
        budget : int | None
            Max combinations
        per_slot_max_fills : dict[str, int] | None
            Maximum number of unique fills per slot (after constraint filtering)
        per_slot_enforce_unique : dict[str, bool] | None
            Whether to enforce uniqueness for each slot across beam hypotheses
        """
        self.resolver = resolver
        self.model_adapter = model_adapter
        self.beam_size = beam_size
        self.fill_direction = fill_direction
        self.custom_order = custom_order
        self.top_k = top_k
        self.cache = cache
        self.budget = budget
        self.per_slot_max_fills = per_slot_max_fills or {}
        self.per_slot_enforce_unique = per_slot_enforce_unique or {}

        if not model_adapter.is_loaded():
            raise ValueError("Model adapter must be loaded before use")

        if fill_direction == "custom" and custom_order is None:
            raise ValueError("custom_order required when fill_direction is 'custom'")

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "mlm"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate combinations using MLM beam search.

        Note: This method adapts slot_items to template-based generation.
        The actual beam search is implemented in generate_from_template.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items (for constraint filtering)

        Returns
        -------
        list[dict[str, LexicalItem]]
            Combinations generated via beam search

        Raises
        ------
        NotImplementedError
            This method requires template context. Use generate_from_template instead.
        """
        raise NotImplementedError(
            "MLMFillingStrategy requires template context. "
            "Use TemplateFiller with MLMFillingStrategy, which calls "
            "generate_from_template internally."
        )

    def generate_from_template(
        self,
        template: Template,
        lexicons: list[Lexicon],
        language_code: LanguageCode | None = None,
    ) -> Iterator[dict[str, LexicalItem]]:
        """Generate combinations from template using beam search.

        Parameters
        ----------
        template : Template
            Template to fill
        lexicons : list[Lexicon]
            Lexicons for constraint resolution
        language_code : LanguageCode | None
            Language filter

        Yields
        ------
        dict[str, LexicalItem]
            Slot-to-item mappings
        """
        logger.info(
            f"[MLMFillingStrategy] Starting beam search for template: {template.name}"
        )

        # Get slot names and order
        slot_names = list(template.slots.keys())
        if not slot_names:
            return

        fill_order = self._get_fill_order(len(slot_names))
        logger.info(
            f"[MLMFillingStrategy] Slots to fill ({len(slot_names)}): {slot_names}"
        )
        logger.info(
            f"[MLMFillingStrategy] Fill order: {[slot_names[i] for i in fill_order]}"
        )
        logger.info(f"[MLMFillingStrategy] Beam size: {self.beam_size}")

        # Initialize beam with empty hypothesis
        # Each beam item: (filled_slots_dict, cumulative_log_prob)
        beam: list[tuple[dict[str, LexicalItem], float]] = [({}, 0.0)]

        # Track seen items per slot (for uniqueness enforcement)
        seen_items_per_slot: dict[str, set[UUID]] = {
            slot_name: set() for slot_name in slot_names
        }

        # Fill slots in order
        beam_start = time.time()
        for step_num, slot_idx in enumerate(fill_order, 1):
            step_start = time.time()
            slot_name = slot_names[slot_idx]
            slot = template.slots[slot_name]
            logger.info(
                f"[MLMFillingStrategy] Step {step_num}/{len(fill_order)}: Filling slot '{slot_name}', current beam size: {len(beam)}"  # noqa: E501
            )

            new_beam: list[tuple[dict[str, LexicalItem], float]] = []

            # Check if uniqueness is enforced for this slot
            enforce_unique = self.per_slot_enforce_unique.get(slot_name, False)
            max_fills = self.per_slot_max_fills.get(slot_name, None)
            logger.info(
                f"[MLMFillingStrategy] enforce_unique={enforce_unique}, max_fills={max_fills}"  # noqa: E501
            )

            # BATCHED: Get MLM predictions for all beam hypotheses at once
            if beam:
                # Collect masked texts for all hypotheses
                masked_start = time.time()
                masked_texts = []
                for filled_slots, _ in beam:
                    masked_text = self._create_masked_text(
                        template, slot_names, filled_slots, slot_idx
                    )
                    masked_texts.append(masked_text)
                masked_elapsed = time.time() - masked_start

                # Batch predict - single model call for entire beam
                logger.info(
                    f"[MLMFillingStrategy] Batch predicting for {len(masked_texts)} hypotheses..."  # noqa: E501
                )
                batch_start = time.time()
                predictions_batch = self._get_mlm_predictions_batch(masked_texts)
                batch_elapsed = time.time() - batch_start
                logger.info(
                    f"[MLMFillingStrategy] Batch prediction took {batch_elapsed:.2f}s (masking took {masked_elapsed:.3f}s)"  # noqa: E501
                )

                # Expand each hypothesis with its predictions
                expand_start = time.time()
                total_candidates = 0
                for (filled_slots, cum_log_prob), predictions in zip(
                    beam, predictions_batch, strict=True
                ):
                    # Filter predictions to get candidates
                    candidates = self._filter_mlm_predictions(
                        predictions,
                        slot,
                        lexicons,
                        language_code,
                        seen_items=seen_items_per_slot[slot_name]
                        if enforce_unique
                        else None,
                        max_fills=max_fills,
                    )
                    total_candidates += len(candidates)

                    # Add each candidate to beam
                    for item, log_prob in candidates:
                        new_filled = filled_slots.copy()
                        new_filled[slot_name] = item
                        new_log_prob = cum_log_prob + log_prob
                        new_beam.append((new_filled, new_log_prob))

                        # Track seen items if uniqueness is enforced
                        if enforce_unique:
                            seen_items_per_slot[slot_name].add(item.id)
                expand_elapsed = time.time() - expand_start
                logger.info(
                    f"[MLMFillingStrategy] Expanded beam with {total_candidates} total candidates in {expand_elapsed:.3f}s"  # noqa: E501
                )

            # Prune beam to top-K by score (length-normalized)
            prune_start = time.time()
            if new_beam:
                # Length-normalize scores
                num_filled = len(new_beam[0][0])
                scored_beam = [
                    (filled, log_prob / num_filled, log_prob)
                    for filled, log_prob in new_beam
                ]
                scored_beam.sort(key=lambda x: x[1], reverse=True)

                # Keep top beam_size
                beam = [
                    (filled, cum_log_prob)
                    for filled, _, cum_log_prob in scored_beam[: self.beam_size]
                ]
                prune_elapsed = time.time() - prune_start
                logger.info(
                    f"[MLMFillingStrategy] Pruned {len(new_beam)} hypotheses to {len(beam)} in {prune_elapsed:.3f}s"  # noqa: E501
                )
            else:
                # No valid candidates - empty beam
                logger.warning(
                    "[MLMFillingStrategy] No valid candidates found! Beam is empty."
                )
                beam = []
                break

            step_elapsed = time.time() - step_start
            logger.info(
                f"[MLMFillingStrategy] Step {step_num} completed in {step_elapsed:.2f}s\n"  # noqa: E501
            )

        beam_elapsed = time.time() - beam_start
        logger.info(
            f"[MLMFillingStrategy] Beam search complete in {beam_elapsed:.2f}s, yielding {len(beam)} hypotheses"  # noqa: E501
        )

        # Yield final hypotheses
        count = 0
        for filled_slots, _ in beam:
            if self.budget and count >= self.budget:
                break
            yield filled_slots
            count += 1

    def _get_fill_order(self, num_slots: int) -> list[int]:
        """Get slot fill order based on fill_direction.

        Parameters
        ----------
        num_slots : int
            Number of slots

        Returns
        -------
        list[int]
            Slot indices in fill order
        """
        if self.fill_direction == "custom":
            if self.custom_order is None:
                raise ValueError("custom_order not set")
            return self.custom_order

        indices = list(range(num_slots))

        if self.fill_direction == "left_to_right":
            return indices
        elif self.fill_direction == "right_to_left":
            return list(reversed(indices))
        elif self.fill_direction == "inside_out":
            # Alternate from center outward
            mid = num_slots // 2
            order: list[int] = []
            for i in range(num_slots):
                if i % 2 == 0:
                    order.append(mid + i // 2)
                else:
                    order.append(mid - (i + 1) // 2)
            return [idx for idx in order if 0 <= idx < num_slots]
        elif self.fill_direction == "outside_in":
            # Alternate from edges inward
            order: list[int] = []
            left, right = 0, num_slots - 1
            while left <= right:
                order.append(left)
                if left != right:
                    order.append(right)
                left += 1
                right -= 1
            return order
        else:
            raise ValueError(f"Unknown fill_direction: {self.fill_direction}")

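Editorial aside (not part of the packaged file): the fill-order arithmetic above is easiest to see on a concrete slot count. This standalone sketch mirrors `_get_fill_order` for the four built-in directions:

def fill_order(num_slots: int, direction: str) -> list[int]:
    indices = list(range(num_slots))
    if direction == "left_to_right":
        return indices
    if direction == "right_to_left":
        return list(reversed(indices))
    if direction == "inside_out":
        mid = num_slots // 2
        order = []
        for i in range(num_slots):
            order.append(mid + i // 2 if i % 2 == 0 else mid - (i + 1) // 2)
        return [i for i in order if 0 <= i < num_slots]
    order, left, right = [], 0, num_slots - 1  # outside_in
    while left <= right:
        order.append(left)
        if left != right:
            order.append(right)
        left, right = left + 1, right - 1
    return order

for d in ("left_to_right", "right_to_left", "inside_out", "outside_in"):
    print(d, fill_order(5, d))
# left_to_right [0, 1, 2, 3, 4]
# right_to_left [4, 3, 2, 1, 0]
# inside_out    [2, 1, 3, 0, 4]
# outside_in    [0, 4, 1, 3, 2]
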
    def _get_mlm_candidates(
        self,
        template: Template,
        slot_names: list[str],
        slot_idx: int,
        filled_slots: dict[str, LexicalItem],
        slot: Slot,
        lexicons: list[Lexicon],
        language_code: LanguageCode | None,
        seen_items: set[UUID] | None = None,
        max_fills: int | None = None,
    ) -> list[tuple[LexicalItem, float]]:
        """Get MLM candidates for a slot.

        Parameters
        ----------
        template : Template
            Template being filled
        slot_names : list[str]
            Ordered slot names
        slot_idx : int
            Index of slot to fill
        filled_slots : dict[str, LexicalItem]
            Already-filled slots
        slot : Slot
            Slot object
        lexicons : list[Lexicon]
            Lexicons for lookup
        language_code : LanguageCode | None
            Language filter
        seen_items : set | None
            Set of item IDs already used for this slot (for uniqueness enforcement)
        max_fills : int | None
            Maximum number of candidates to return (applied after filtering)

        Returns
        -------
        list[tuple[LexicalItem, float]]
            (item, log_prob) pairs, limited by max_fills and uniqueness
        """
        # Normalize language code to ISO 639-3
        if language_code is not None:
            language_code = validate_iso639_code(language_code)

        # Create masked text
        masked_text = self._create_masked_text(
            template, slot_names, filled_slots, slot_idx
        )

        # Get predictions from MLM (with cache)
        if self.cache:
            predictions = self.cache.get(
                self.model_adapter.model_name,
                masked_text,
                0,  # First mask position
                self.top_k,
            )
        else:
            predictions = None

        if predictions is None:
            predictions = self.model_adapter.predict_masked_token(
                masked_text,
                mask_position=0,
                top_k=self.top_k,
            )
            if self.cache:
                self.cache.set(
                    self.model_adapter.model_name,
                    masked_text,
                    0,
                    self.top_k,
                    predictions,
                )

        # Filter by constraints and find matching lexical items
        candidates: list[tuple[LexicalItem, float]] = []
        for token, log_prob in predictions:
            # Find matching items in lexicons
            for lexicon in lexicons:
                for item in lexicon.items.values():
                    # Skip if already seen (uniqueness enforcement)
                    if seen_items is not None and item.id in seen_items:
                        continue

                    # Match lemma and language
                    if item.lemma.lower() == token.lower():
                        if language_code is None or item.language_code == language_code:
                            # Check slot constraints
                            if slot.constraints:
                                # Evaluate constraints using resolver
                                if self.resolver.evaluate_slot_constraints(
                                    item, slot.constraints
                                ):
                                    candidates.append((item, log_prob))
                            else:
                                candidates.append((item, log_prob))

        # Apply max_fills limit (take top-N by log probability)
        if max_fills is not None and len(candidates) > max_fills:
            # Already sorted by log_prob (descending) from MLM predictions
            # But need to ensure we take highest scoring ones
            candidates.sort(key=lambda x: x[1], reverse=True)
            candidates = candidates[:max_fills]

        return candidates

    def _get_mlm_predictions_batch(
        self, masked_texts: list[str]
    ) -> list[list[tuple[str, float]]]:
        """Get MLM predictions for a batch of masked texts.

        Parameters
        ----------
        masked_texts : list[str]
            List of texts with mask tokens

        Returns
        -------
        list[list[tuple[str, float]]]
            Predictions for each text: list of (token, log_prob) tuples
        """
        cache_start = time.time()

        # Check cache for each text first
        predictions_batch: list[list[tuple[str, float]] | None] = []
        texts_to_predict: list[int] = []  # Indices needing prediction

        for i, masked_text in enumerate(masked_texts):
            if self.cache:
                predictions = self.cache.get(
                    self.model_adapter.model_name,
                    masked_text,
                    0,  # First mask position
                    self.top_k,
                )
            else:
                predictions = None

            predictions_batch.append(predictions)
            if predictions is None:
                texts_to_predict.append(i)

        cache_elapsed = time.time() - cache_start
        cache_hits = len(masked_texts) - len(texts_to_predict)
        logger.info(
            f"[MLMFillingStrategy] Cache: {cache_hits}/"
            f"{len(masked_texts)} hits in {cache_elapsed:.3f}s"
        )

        # Batch predict uncached texts
        if texts_to_predict:
            logger.info(
                f"[MLMFillingStrategy] Calling model for "
                f"{len(texts_to_predict)} uncached texts..."
            )
            model_start = time.time()
            uncached_texts = [masked_texts[i] for i in texts_to_predict]
            new_predictions = self.model_adapter.predict_masked_token_batch(
                uncached_texts,
                mask_position=0,
                top_k=self.top_k,
            )
            model_elapsed = time.time() - model_start
            per_text = model_elapsed / len(texts_to_predict)
            logger.info(
                f"[MLMFillingStrategy] Model inference took "
                f"{model_elapsed:.2f}s ({per_text:.3f}s per text)"
            )

            # Fill in predictions and cache them
            cache_write_start = time.time()
            for idx, predictions in zip(texts_to_predict, new_predictions, strict=True):
                predictions_batch[idx] = predictions
                if self.cache:
                    self.cache.set(
                        self.model_adapter.model_name,
                        masked_texts[idx],
                        0,
                        self.top_k,
                        predictions,
                    )
            cache_write_elapsed = time.time() - cache_write_start
            logger.info(
                f"[MLMFillingStrategy] Cache writes took {cache_write_elapsed:.3f}s"
            )

        # Convert None to empty list (shouldn't happen but for type safety)
        return [p if p is not None else [] for p in predictions_batch]

    def _filter_mlm_predictions(
        self,
        predictions: list[tuple[str, float]],
        slot: Slot,
        lexicons: list[Lexicon],
        language_code: LanguageCode | None,
        seen_items: set[UUID] | None = None,
        max_fills: int | None = None,
    ) -> list[tuple[LexicalItem, float]]:
        """Filter MLM predictions to valid lexical items.

        Parameters
        ----------
        predictions : list[tuple[str, float]]
            Raw (token, log_prob) predictions from MLM
        slot : Slot
            Slot object with constraints
        lexicons : list[Lexicon]
            Lexicons for lookup
        language_code : LanguageCode | None
            Language filter
        seen_items : set[UUID] | None
            Set of item IDs already used (for uniqueness enforcement)
        max_fills : int | None
            Maximum number of candidates to return

        Returns
        -------
        list[tuple[LexicalItem, float]]
            Filtered (item, log_prob) pairs
        """
        # Normalize language code
        if language_code is not None:
            language_code = validate_iso639_code(language_code)

        # Filter by constraints and find matching lexical items
        candidates: list[tuple[LexicalItem, float]] = []
        for token, log_prob in predictions:
            # Find matching items in lexicons
            for lexicon in lexicons:
                for item in lexicon.items.values():
                    # Skip if already seen (uniqueness enforcement)
                    if seen_items is not None and item.id in seen_items:
                        continue

                    # Match lemma and language
                    if item.lemma.lower() == token.lower():
                        if language_code is None or item.language_code == language_code:
                            # Check slot constraints
                            if slot.constraints:
                                # Evaluate constraints using resolver
                                if self.resolver.evaluate_slot_constraints(
                                    item, slot.constraints
                                ):
                                    candidates.append((item, log_prob))
                            else:
                                candidates.append((item, log_prob))

        # Apply max_fills limit (take top-N by log probability)
        if max_fills is not None and len(candidates) > max_fills:
            # Already sorted by log_prob (descending) from MLM predictions
            # But need to ensure we take highest scoring ones
            candidates.sort(key=lambda x: x[1], reverse=True)
            candidates = candidates[:max_fills]

        return candidates

    def _create_masked_text(
        self,
        template: Template,
        slot_names: list[str],
        filled_slots: dict[str, LexicalItem],
        current_slot_idx: int,
    ) -> str:
        """Create text with mask token for current slot.

        Parameters
        ----------
        template : Template
            Template
        slot_names : list[str]
            Slot names
        filled_slots : dict[str, LexicalItem]
            Filled slots
        current_slot_idx : int
            Current slot index

        Returns
        -------
        str
            Text with [MASK] token
        """
        mask_token = self.model_adapter.get_mask_token()
        text = template.template_string

        # Replace filled slots with lemmas
        for slot_name, item in filled_slots.items():
            placeholder = f"{{{slot_name}}}"
            text = text.replace(placeholder, item.lemma)

        # Replace current slot with mask
        current_slot_name = slot_names[current_slot_idx]
        current_placeholder = f"{{{current_slot_name}}}"
        text = text.replace(current_placeholder, mask_token)

        # Replace remaining unfilled slots with mask for context
        for slot_name in slot_names:
            placeholder = f"{{{slot_name}}}"
            if placeholder in text:
                text = text.replace(placeholder, mask_token)

        return text

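Editorial aside (not part of the packaged file): a standalone sketch of the masking step performed by `_create_masked_text`, assuming a BERT-style `[MASK]` token and a toy template string. Filled slots are rendered as lemmas while the current slot and any still-unfilled slots become mask tokens; the adapter is then queried at `mask_position=0`, i.e. the left-most mask in the resulting string.

template_string = "the {adjective} {noun} {verb} loudly"
slot_names = ["adjective", "noun", "verb"]
filled = {"noun": "dog"}             # lemmas chosen so far on this hypothesis
current = slot_names.index("verb")   # slot being predicted next

text = template_string
for slot, lemma in filled.items():
    text = text.replace(f"{{{slot}}}", lemma)
text = text.replace(f"{{{slot_names[current]}}}", "[MASK]")
for slot in slot_names:
    text = text.replace(f"{{{slot}}}", "[MASK]")  # remaining context slots

print(text)  # the [MASK] dog [MASK] loudly
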
class StrategyFiller(TemplateFiller):
    """Strategy-based template filling for simple templates.

    Fast filling using enumeration strategies (exhaustive, random, stratified).
    Does NOT handle template-level multi-slot constraints (Template.constraints).

    For templates with multi-slot constraints requiring agreement or
    relational checks, use CSPFiller instead.

    Parameters
    ----------
    lexicon : Lexicon
        Lexicon containing candidate items.
    strategy : FillingStrategy
        Strategy for generating combinations.

    Examples
    --------
    >>> from bead.templates.strategies import StrategyFiller, ExhaustiveStrategy
    >>> filler = StrategyFiller(lexicon, ExhaustiveStrategy())
    >>> filled = filler.fill(template)
    >>> len(filled)
    12
    """

    def __init__(self, lexicon: Lexicon, strategy: FillingStrategy) -> None:
        self.lexicon = lexicon
        self.strategy = strategy
        self.resolver = ConstraintResolver()

    def fill(
        self,
        template: Template,
        language_code: LanguageCode | None = None,
    ) -> list[FilledTemplate]:
        """Fill template with lexical items using strategy.

        Parameters
        ----------
        template : Template
            Template to fill.
        language_code : LanguageCode | None
            Optional language code to filter items.

        Returns
        -------
        list[FilledTemplate]
            List of all filled template instances.

        Raises
        ------
        ValueError
            If any slot has no valid items.
        """
        # 1. Resolve slot constraints
        slot_items = self._resolve_slot_constraints(template, language_code)

        # 2. Check for empty slots
        empty_slots = [name for name, items in slot_items.items() if not items]
        if empty_slots:
            raise ValueError(f"No valid items for slots: {empty_slots}")

        # 3. Generate combinations using strategy
        combinations = self.strategy.generate_combinations(slot_items)

        # 4. Create FilledTemplate instances
        filled_templates: list[FilledTemplate] = []
        for combo in combinations:
            rendered = self._render_template(template, combo)
            filled = FilledTemplate(
                template_id=str(template.id),
                template_name=template.name,
                slot_fillers=combo,
                rendered_text=rendered,
                strategy_name=self.strategy.name,
            )
            filled_templates.append(filled)

        return filled_templates

    def _resolve_slot_constraints(
        self,
        template: Template,
        language_code: LanguageCode | None,
    ) -> dict[str, list[LexicalItem]]:
        """Resolve constraints for each slot.

        Parameters
        ----------
        template : Template
            Template with slots and constraints.
        language_code : LanguageCode | None
            Optional language filter.

        Returns
        -------
        dict[str, list[LexicalItem]]
            Mapping of slot names to valid items.
        """
        slot_items: dict[str, list[LexicalItem]] = {}

        # Normalize language code if provided
        normalized_lang = validate_iso639_code(language_code) if language_code else None

        for slot_name, slot in template.slots.items():
            candidates = list(self.lexicon.items.values())

            # Filter by language code
            if normalized_lang:
                candidates = [
                    item for item in candidates if item.language_code == normalized_lang
                ]

            # Apply slot constraints
            if slot.constraints:
                filtered: list[LexicalItem] = []
                for item in candidates:
                    eval_context: dict[
                        str, ContextValue | LexicalItem | FilledTemplate | Item
                    ] = {"self": item}

                    # Check all constraints
                    passes_all_constraints = True
                    for constraint in slot.constraints:
                        if constraint.context:
                            eval_context.update(constraint.context)

                        evaluator = DSLEvaluator()
                        if not evaluator.evaluate(constraint.expression, eval_context):
                            passes_all_constraints = False
                            break

                    # Only add if passed ALL constraints
                    if passes_all_constraints:
                        filtered.append(item)

                candidates = filtered

            slot_items[slot_name] = candidates

        return slot_items

    def _render_template(
        self, template: Template, slot_fillers: dict[str, LexicalItem]
    ) -> str:
        """Render template string with slot fillers.

        Parameters
        ----------
        template : Template
            Template with template_string.
        slot_fillers : dict[str, LexicalItem]
            Items filling each slot.

        Returns
        -------
        str
            Rendered template string.
        """
        rendered = template.template_string
        for slot_name, item in slot_fillers.items():
            placeholder = f"{{{slot_name}}}"
            rendered = rendered.replace(placeholder, item.lemma)
        return rendered

    def count_combinations(self, template: Template) -> int:
        """Count total possible combinations for template.

        Parameters
        ----------
        template : Template
            Template to count combinations for.

        Returns
        -------
        int
            Total number of possible combinations.
        """
        slot_items = self._resolve_slot_constraints(template, None)

        if not slot_items:
            return 0

        count = 1
        for items in slot_items.values():
            count *= len(items)

        return count

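Editorial aside (not part of the packaged file): an end-to-end sketch of `StrategyFiller`, assuming `lexicon` and `template` are already-constructed `Lexicon` and `Template` objects (their constructors are not shown in this file); only APIs visible in this module are used, and the `rendered_text` attribute is assumed from the `FilledTemplate(...)` construction above.

filler = StrategyFiller(lexicon, RandomStrategy(n_samples=25, seed=7))
print(filler.count_combinations(template))   # size of the unconstrained space
for filled in filler.fill(template):
    print(filled.rendered_text)              # placeholder-substituted sentence
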
1219
|
+
class MixedFillingStrategy(FillingStrategy):
|
|
1220
|
+
"""Fill different template slots using different strategies.
|
|
1221
|
+
|
|
1222
|
+
Allows per-slot strategy specification, enabling workflows like:
|
|
1223
|
+
- Fill nouns/verbs exhaustively
|
|
1224
|
+
- Fill adjectives via MLM based on noun context
|
|
1225
|
+
|
|
1226
|
+
This strategy operates in two steps:
|
|
1227
|
+
1. First pass: Fill slots assigned to non-MLM strategies (exhaustive, random, etc.)
|
|
1228
|
+
2. Second pass: For each first pass combination, fill remaining slots via MLM
|
|
1229
|
+
|
|
1230
|
+
Parameters
|
|
1231
|
+
----------
|
|
1232
|
+
slot_strategies : dict[str, tuple[FillingStrategy, dict]]
|
|
1233
|
+
Mapping of slot names to (strategy, config) tuples.
|
|
1234
|
+
Config is strategy-specific kwargs.
|
|
1235
|
+
default_strategy : FillingStrategy | None
|
|
1236
|
+
Default strategy for slots not explicitly specified.
|
|
1237
|
+
|
|
1238
|
+
Examples
|
|
1239
|
+
--------
|
|
1240
|
+
>>> exhaustive = ExhaustiveStrategy()
|
|
1241
|
+
>>> mlm_config = {
|
|
1242
|
+
... "resolver": resolver,
|
|
1243
|
+
... "model_adapter": mlm_adapter,
|
|
1244
|
+
... "top_k": 5
|
|
1245
|
+
... }
|
|
1246
|
+
>>> strategy = MixedFillingStrategy(
|
|
1247
|
+
... slot_strategies={
|
|
1248
|
+
... "noun": (exhaustive, {}),
|
|
1249
|
+
... "verb": (exhaustive, {}),
|
|
1250
|
+
... "adjective": ("mlm", mlm_config)
|
|
1251
|
+
... }
|
|
1252
|
+
... )
|
|
1253
|
+
"""
|
|
1254
|
+
|
|
1255
|
+
def __init__(
|
|
1256
|
+
self,
|
|
1257
|
+
slot_strategies: dict[str, tuple[str | FillingStrategy, StrategyConfig]],
|
|
1258
|
+
default_strategy: FillingStrategy | None = None,
|
|
1259
|
+
) -> None:
|
|
1260
|
+
"""Initialize mixed strategy.
|
|
1261
|
+
|
|
1262
|
+
Parameters
|
|
1263
|
+
----------
|
|
1264
|
+
slot_strategies : dict[str, tuple[str | FillingStrategy, StrategyConfig]]
|
|
1265
|
+
Mapping slot names to (strategy_name, config) or
|
|
1266
|
+
(strategy_instance, config). strategy_name can be:
|
|
1267
|
+
"exhaustive", "random", "stratified", "mlm"
|
|
1268
|
+
default_strategy : FillingStrategy | None
|
|
1269
|
+
Default strategy for unspecified slots.
|
|
1270
|
+
"""
|
|
1271
|
+
self.slot_strategies = slot_strategies
|
|
1272
|
+
self.default_strategy = default_strategy or ExhaustiveStrategy()
|
|
1273
|
+
|
|
1274
|
+
# Separate slots by strategy type
|
|
1275
|
+
self.non_mlm_slots: list[str] = [] # Non-MLM slots
|
|
1276
|
+
self.mlm_slots: list[str] = [] # MLM slots
|
|
1277
|
+
self.non_mlm_strategies: dict[str, FillingStrategy] = {}
|
|
1278
|
+
self.mlm_configs: dict[str, StrategyConfig] = {}
|
|
1279
|
+
|
|
1280
|
+
for slot_name, (strategy, config) in slot_strategies.items():
|
|
1281
|
+
strategy_name = strategy if isinstance(strategy, str) else strategy.name
|
|
1282
|
+
|
|
1283
|
+
if strategy_name == "mlm":
|
|
1284
|
+
self.mlm_slots.append(slot_name)
|
|
1285
|
+
self.mlm_configs[slot_name] = config
|
|
1286
|
+
else:
|
|
1287
|
+
self.non_mlm_slots.append(slot_name)
|
|
1288
|
+
# Instantiate strategy if needed
|
|
1289
|
+
if isinstance(strategy, str):
|
|
1290
|
+
self.non_mlm_strategies[slot_name] = self._instantiate_strategy(
|
|
1291
|
+
strategy, config
|
|
1292
|
+
)
|
|
1293
|
+
else:
|
|
1294
|
+
self.non_mlm_strategies[slot_name] = strategy
|
|
1295
|
+
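__init__ splits the slot map into a first-pass group and an MLM group purely by strategy name. A sketch of that partition with invented slot names and configs (plain tuples stand in for the strategy instances):

# Illustrative sketch, not part of the diff: partitioning slots by strategy name.
slot_strategies = {
    "noun": ("exhaustive", {}),
    "verb": ("random", {"n_samples": 10, "seed": 0}),
    "adjective": ("mlm", {"top_k": 5}),
}
mlm_slots = [s for s, (kind, _) in slot_strategies.items() if kind == "mlm"]
non_mlm_slots = [s for s, (kind, _) in slot_strategies.items() if kind != "mlm"]
assert mlm_slots == ["adjective"]
assert non_mlm_slots == ["noun", "verb"]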
    def _instantiate_strategy(
        self, strategy_name: str, config: StrategyConfig
    ) -> FillingStrategy:
        """Instantiate strategy from name and config.

        Parameters
        ----------
        strategy_name : str
            Strategy name: "exhaustive", "random", "stratified"
        config : dict
            Strategy-specific configuration

        Returns
        -------
        FillingStrategy
            Instantiated strategy

        Raises
        ------
        ValueError
            If strategy name is unknown
        """
        if strategy_name == "exhaustive":
            return ExhaustiveStrategy()
        elif strategy_name == "random":
            return RandomStrategy(
                n_samples=cast(int, config.get("n_samples", 100)),
                seed=cast(int | None, config.get("seed")),
            )
        elif strategy_name == "stratified":
            return StratifiedStrategy(
                n_samples=cast(int, config.get("n_samples", 100)),
                grouping_property=cast(str, config.get("grouping_property", "pos")),
                seed=cast(int | None, config.get("seed")),
            )
        else:
            raise ValueError(f"Unknown strategy: {strategy_name}")

    @property
    def name(self) -> str:
        """Get strategy name."""
        return "mixed"

    def generate_combinations(
        self,
        slot_items: dict[str, list[LexicalItem]],
    ) -> list[dict[str, LexicalItem]]:
        """Generate combinations using mixed strategies.

        Note: This method signature is required by FillingStrategy,
        but MixedFillingStrategy with MLM requires template context.
        Use generate_from_template instead.

        Parameters
        ----------
        slot_items : dict[str, list[LexicalItem]]
            Mapping of slot names to valid items

        Returns
        -------
        list[dict[str, LexicalItem]]
            Generated combinations

        Raises
        ------
        NotImplementedError
            If any slot uses MLM strategy (requires template context)
        """
        if self.mlm_slots:
            raise NotImplementedError(
                "MixedFillingStrategy with MLM slots requires template context. "
                "Use StrategyFiller or CSPFiller, which call generate_from_template."
            )

        # If no MLM slots, just use non-MLM strategies
        # This is a simplified case: all slots use non-MLM strategies

        # For each slot, generate its combinations independently
        slot_combinations: dict[str, list[LexicalItem]] = {}

        for slot_name, items in slot_items.items():
            if slot_name in self.non_mlm_strategies:
                strategy = self.non_mlm_strategies[slot_name]
                # Generate combinations for just this slot
                combos = strategy.generate_combinations({slot_name: items})
                slot_combinations[slot_name] = [c[slot_name] for c in combos]
            else:
                # Use default strategy
                combos = self.default_strategy.generate_combinations({slot_name: items})
                slot_combinations[slot_name] = [c[slot_name] for c in combos]

        # Generate cartesian product of all slot combinations
        slot_names = list(slot_items.keys())
        item_lists = [slot_combinations[name] for name in slot_names]

        combinations: list[dict[str, LexicalItem]] = []
        for combo_tuple in cartesian_product(*item_lists):
            combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
            combinations.append(combo_dict)

        return combinations

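When no slot uses MLM, generate_combinations is an independent per-slot sample followed by a cartesian product. A runnable sketch of that product step, with itertools.product standing in for the cartesian_product helper this module imports (assumed to behave the same way on lists) and strings standing in for LexicalItem:

# Illustrative sketch, not part of the diff: per-slot lists combined by product.
from itertools import product

slot_combinations = {"noun": ["dog", "cat"], "adjective": ["red", "big"]}
slot_names = list(slot_combinations)
combos = [
    dict(zip(slot_names, combo, strict=True))
    for combo in product(*(slot_combinations[name] for name in slot_names))
]
assert len(combos) == 4
assert {"noun": "cat", "adjective": "red"} in combos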
    def generate_from_template(
        self,
        template: Template,
        lexicons: list[Lexicon],
        language_code: LanguageCode | None = None,
    ) -> Iterator[dict[str, LexicalItem]]:
        """Generate combinations from template using mixed strategies.

        First pass: Fill non-MLM slots using their assigned strategies
        Second pass: For each first pass combination, fill MLM slots using beam search

        Parameters
        ----------
        template : Template
            Template to fill
        lexicons : list[Lexicon]
            Lexicons for constraint resolution
        language_code : LanguageCode | None
            Language filter

        Yields
        ------
        dict[str, LexicalItem]
            Complete slot-to-item mappings
        """
        logger.info(f"[MixedFillingStrategy] Starting template: {template.name}")
        logger.info(f"[MixedFillingStrategy] Non-MLM slots: {self.non_mlm_slots}")
        logger.info(f"[MixedFillingStrategy] MLM slots: {self.mlm_slots}")

        # First pass: Fill non-MLM slots
        first_pass_start = time.time()
        if not self.non_mlm_slots:
            # No non-MLM slots - just use MLM for all MLM slots
            first_pass_combinations: list[dict[str, LexicalItem]] = [{}]
        else:
            first_pass_combinations = self._generate_non_mlm_combinations(
                template, lexicons, language_code
            )
        first_pass_elapsed = time.time() - first_pass_start
        logger.info(
            f"[MixedFillingStrategy] First pass generated "
            f"{len(first_pass_combinations)} combinations in {first_pass_elapsed:.2f}s"
        )

        # Second pass: Fill MLM slots for each first pass combination
        if not self.mlm_slots:
            # No MLM slots - just yield first pass combinations
            logger.info(
                "[MixedFillingStrategy] No MLM slots, yielding first pass combinations"
            )
            yield from first_pass_combinations
        else:
            logger.info(
                f"[MixedFillingStrategy] Starting second pass for "
                f"{len(first_pass_combinations)} combinations..."
            )
            second_pass_start = time.time()
            total_yielded = 0
            for i, partial_combo in enumerate(first_pass_combinations):
                combo_start = time.time()
                if i == 0:
                    # Debug first combo to see what's in it
                    combo_slots = list(partial_combo.keys())
                    combo_values = {k: v.lemma for k, v in partial_combo.items()}
                    logger.info(
                        f"[MixedFillingStrategy] First combination has "
                        f"slots: {combo_slots}"
                    )
                    logger.info(
                        f"[MixedFillingStrategy] First combination "
                        f"values: {combo_values}"
                    )
                logger.info(
                    f"[MixedFillingStrategy] Processing combination "
                    f"{i + 1}/{len(first_pass_combinations)}"
                )
                # Fill remaining slots with MLM
                n_yielded_for_combo = 0
                for filled in self._fill_mlm_slots(
                    template, partial_combo, lexicons, language_code
                ):
                    # Filter by template-level constraints
                    if self._check_template_constraints(template, filled):
                        n_yielded_for_combo += 1
                        total_yielded += 1
                        yield filled
                combo_elapsed = time.time() - combo_start
                logger.info(
                    f"[MixedFillingStrategy] Combination {i + 1} yielded "
                    f"{n_yielded_for_combo} complete fillings in "
                    f"{combo_elapsed:.2f}s"
                )
            second_pass_elapsed = time.time() - second_pass_start
            logger.info(
                f"[MixedFillingStrategy] Second pass complete: {total_yielded} "
                f"total fillings in {second_pass_elapsed:.2f}s"
            )

    def _generate_non_mlm_combinations(
        self,
        template: Template,
        lexicons: list[Lexicon],
        language_code: LanguageCode | None,
    ) -> list[dict[str, LexicalItem]]:
        """Generate combinations for non-MLM slots.

        Parameters
        ----------
        template : Template
            Template being filled
        lexicons : list[Lexicon]
            Lexicons for items
        language_code : LanguageCode | None
            Language filter

        Returns
        -------
        list[dict[str, LexicalItem]]
            Partial combinations (only non-MLM slots filled)
        """
        # Get valid items for each non-MLM slot
        slot_items: dict[str, list[LexicalItem]] = {}
        normalized_lang = validate_iso639_code(language_code) if language_code else None

        for slot_name in self.non_mlm_slots:
            if slot_name not in template.slots:
                continue

            slot = template.slots[slot_name]
            candidates: list[LexicalItem] = []

            # Collect items from all lexicons
            for lexicon in lexicons:
                for item in lexicon.items.values():
                    # Filter by language
                    if normalized_lang and item.language_code != normalized_lang:
                        continue
                    # Check slot constraints
                    if slot.constraints:
                        eval_context: dict[str, ContextValue | LexicalItem] = {
                            "self": item
                        }
                        # Check ALL constraints - item must pass every one
                        passes_all_constraints = True
                        for constraint in slot.constraints:
                            if constraint.context:
                                eval_context.update(constraint.context)
                            # Evaluate
                            evaluator = DSLEvaluator()
                            # Cast to expected context type
                            typed_context = cast(
                                dict[
                                    str,
                                    ContextValue | LexicalItem | FilledTemplate | Item,
                                ],
                                eval_context,
                            )
                            if not evaluator.evaluate(
                                constraint.expression, typed_context
                            ):
                                passes_all_constraints = False
                                break

                        # Only add item if it passed ALL constraints
                        if not passes_all_constraints:
                            continue

                    candidates.append(item)

            slot_items[slot_name] = candidates

        # Generate combinations using per-slot strategies
        # For each slot, we need to apply its strategy independently,
        # then take cartesian product

        # Collect combinations per slot
        slot_combos: dict[str, list[LexicalItem]] = {}

        for slot_name in self.non_mlm_slots:
            if slot_name not in slot_items:
                continue

            items = slot_items[slot_name]
            strategy = self.non_mlm_strategies.get(slot_name, self.default_strategy)

            # Generate combinations for this slot
            combos = strategy.generate_combinations({slot_name: items})
            slot_combos[slot_name] = [c[slot_name] for c in combos]

        # Cartesian product of all non-MLM slots
        if not slot_combos:
            return [{}]

        slot_names = list(slot_combos.keys())
        item_lists = [slot_combos[name] for name in slot_names]

        combinations: list[dict[str, LexicalItem]] = []
        for combo_tuple in cartesian_product(*item_lists):
            combo_dict = dict(zip(slot_names, combo_tuple, strict=True))
            # Filter by template-level constraints
            if self._check_template_constraints(template, combo_dict):
                combinations.append(combo_dict)

        return combinations

    def _fill_mlm_slots(
        self,
        template: Template,
        partial_filling: dict[str, LexicalItem],
        lexicons: list[Lexicon],
        language_code: LanguageCode | None,
    ) -> Iterator[dict[str, LexicalItem]]:
        """Fill MLM slots given a partial filling from first pass.

        Parameters
        ----------
        template : Template
            Template being filled
        partial_filling : dict[str, LexicalItem]
            Already-filled slots from first pass
        lexicons : list[Lexicon]
            Lexicons for items
        language_code : LanguageCode | None
            Language filter

        Yields
        ------
        dict[str, LexicalItem]
            Complete fillings with MLM slots added
        """
        if not self.mlm_slots or not self.mlm_configs:
            yield partial_filling
            return

        # Get base config from first MLM slot (model adapter, resolver, etc.)
        first_mlm_slot = self.mlm_slots[0]
        base_config = self.mlm_configs[first_mlm_slot].copy()

        # Extract per-slot max_fills and enforce_unique settings
        per_slot_max_fills: dict[str, int] = {}
        per_slot_enforce_unique: dict[str, bool] = {}

        for slot_name in self.mlm_slots:
            config = self.mlm_configs[slot_name]
            if "max_fills" in config:
                per_slot_max_fills[slot_name] = cast(int, config["max_fills"])
            if "enforce_unique" in config:
                per_slot_enforce_unique[slot_name] = cast(
                    bool, config["enforce_unique"]
                )

        # Remove per-slot settings from base config
        # (they're not MLMFillingStrategy params)
        base_config.pop("max_fills", None)
        base_config.pop("enforce_unique", None)

        # Add per-slot dicts to config
        base_config["per_slot_max_fills"] = per_slot_max_fills
        base_config["per_slot_enforce_unique"] = per_slot_enforce_unique

        # Create MLM strategy with properly typed config
        mlm_strategy = MLMFillingStrategy(
            resolver=cast(ConstraintResolver, base_config["resolver"]),
            model_adapter=cast(HuggingFaceMLMAdapter, base_config["model_adapter"]),
            beam_size=cast(int, base_config.get("beam_size", 5)),
            fill_direction=cast(
                Literal[
                    "left_to_right",
                    "right_to_left",
                    "inside_out",
                    "outside_in",
                    "custom",
                ],
                base_config.get("fill_direction", "left_to_right"),
            ),
            custom_order=cast(list[int] | None, base_config.get("custom_order")),
            top_k=cast(int, base_config.get("top_k", 20)),
            cache=cast(ModelOutputCache | None, base_config.get("cache")),
            budget=cast(int | None, base_config.get("budget")),
            per_slot_max_fills=per_slot_max_fills,
            per_slot_enforce_unique=per_slot_enforce_unique,
        )

        # Create a modified template with only MLM slots
        mlm_template = self._create_mlm_template(template, partial_filling)

        # Generate completions via MLM
        for mlm_filling in mlm_strategy.generate_from_template(
            mlm_template, lexicons, language_code
        ):
            # Combine partial + MLM fillings
            complete = partial_filling.copy()
            complete.update(mlm_filling)
            yield complete

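Before the MLM strategy runs, the slots filled in the first pass are baked into the template string so that only the MLM slots remain as placeholders (the real rewriting is done by _create_mlm_template, defined below). A sketch of that rewriting with invented slot values and surface forms:

# Illustrative sketch, not part of the diff: reduce the template for the MLM pass.
template_string = "the {adjective} {noun} {verb}"
first_pass_surface_forms = {"noun": "dog", "verb": "sleeps"}

reduced = template_string
for slot_name, surface in first_pass_surface_forms.items():
    reduced = reduced.replace(f"{{{slot_name}}}", surface)

assert reduced == "the {adjective} dog sleeps"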
    def _check_template_constraints(
        self,
        template: Template,
        slot_fillers: dict[str, LexicalItem],
    ) -> bool:
        """Check if slot fillers satisfy template-level constraints.

        Only evaluates constraints where all referenced slots are present.
        Constraints referencing missing slots are skipped (deferred).

        Parameters
        ----------
        template : Template
            Template with multi-slot constraints
        slot_fillers : dict[str, LexicalItem]
            Complete or partial slot fillings

        Returns
        -------
        bool
            True if all evaluable template constraints are satisfied
        """
        logger.info(
            f"[TemplateConstraints] Called with template '{template.name}', "
            f"{len(template.constraints)} constraints, {len(slot_fillers)} fillers"
        )
        if not template.constraints:
            logger.info("[TemplateConstraints] No constraints, returning True")
            return True

        # Extract slot names referenced in each constraint
        # Pattern matches "slot_name." but NOT "something.property." (no dot before)
        slot_pattern = re.compile(r"(?<![.])\b([a-zA-Z_][a-zA-Z0-9_]*)\.")
        filled_slots = set(slot_fillers.keys())

        # Filter to only constraints where all referenced slots are filled
        evaluable_constraints = []
        for constraint in template.constraints:
            # Remove string literals before matching to avoid false matches
            # (e.g., 'V.PTCP' should not match slot 'V')
            expr_no_strings = re.sub(r"'[^']*'|\"[^\"]*\"", '""', constraint.expression)
            referenced_slots = set(slot_pattern.findall(expr_no_strings))
            if referenced_slots.issubset(filled_slots):
                evaluable_constraints.append(constraint)
                logger.info(
                    f"[TemplateConstraints] Will evaluate: {constraint.description}"
                )
            else:
                missing = referenced_slots - filled_slots
                logger.info(
                    f"[TemplateConstraints] Deferring (missing {missing}): "
                    f"{constraint.description}"
                )

        if not evaluable_constraints:
            return True  # No constraints can be evaluated yet

        # Use ConstraintResolver to evaluate constraints properly
        n_constraints = len(evaluable_constraints)
        n_slots = len(filled_slots)
        logger.info(
            f"[TemplateConstraints] Evaluating {n_constraints} constraints "
            f"with {n_slots} filled slots"
        )
        resolver = ConstraintResolver()
        result = resolver.evaluate_template_constraints(
            slot_fillers, evaluable_constraints
        )
        if not result:
            logger.info("[TemplateConstraints] Combination REJECTED by constraints")
        return result

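The deferral logic above hinges on knowing which slots a constraint expression mentions. A self-contained sketch of the regex trick it uses (blank out string literals, then match identifier-dot prefixes); the constraint expression is invented for illustration:

# Illustrative sketch, not part of the diff: which slots does an expression reference?
import re

slot_pattern = re.compile(r"(?<![.])\b([a-zA-Z_][a-zA-Z0-9_]*)\.")
expression = "verb.features['form'] != 'V.PTCP' and noun.pos == 'N'"
expr_no_strings = re.sub(r"'[^']*'|\"[^\"]*\"", '""', expression)

# 'V.PTCP' sits inside quotes, so it is blanked out and never matched as slot "V".
assert set(slot_pattern.findall(expr_no_strings)) == {"verb", "noun"}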
    def _create_mlm_template(
        self, template: Template, partial_filling: dict[str, LexicalItem]
    ) -> Template:
        """Create template with non-MLM slots already filled.

        Parameters
        ----------
        template : Template
            Original template
        partial_filling : dict[str, LexicalItem]
            Items filling non-MLM slots

        Returns
        -------
        Template
            Modified template with non-MLM slots replaced by text
        """
        # Replace non-MLM slots in template string with their fillings
        modified_string = template.template_string
        for slot_name, item in partial_filling.items():
            placeholder = f"{{{slot_name}}}"
            # Use actual form if available (e.g., "is" not "be"), otherwise lemma
            surface_form = item.form if item.form is not None else item.lemma
            modified_string = modified_string.replace(placeholder, surface_form)

        # Create new template with only MLM slots
        mlm_slots = {
            name: slot
            for name, slot in template.slots.items()
            if name in self.mlm_slots
        }

        # Create modified template
        modified_template = Template(
            name=f"{template.name}_mlm",
            template_string=modified_string,
            slots=mlm_slots,
            constraints=template.constraints,
            language_code=template.language_code,
        )

        return modified_template
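Taken together, the class amounts to "enumerate the cheap slots first, then ask the masked language model to finish each partial filling". A compact, self-contained sketch of that two-pass control flow, with plain dicts and strings standing in for the package's Template, Lexicon, LexicalItem, and MLM adapter (all names and data below are invented for illustration):

# Illustrative sketch, not part of the diff: the two-pass filling flow.
from itertools import product

def first_pass(non_mlm_candidates):
    # Exhaustive fill of the non-MLM slots.
    names = list(non_mlm_candidates)
    for combo in product(*(non_mlm_candidates[name] for name in names)):
        yield dict(zip(names, combo, strict=True))

def second_pass(partial, mlm_slot, proposals):
    # Extend each partial filling with one model-proposed filler per yield.
    for word in proposals:
        yield {**partial, mlm_slot: word}

non_mlm = {"noun": ["dog", "cat"], "verb": ["sleeps"]}
fillings = [
    complete
    for partial in first_pass(non_mlm)
    for complete in second_pass(partial, "adjective", ["lazy", "gray"])
]
assert len(fillings) == 4
assert {"noun": "cat", "verb": "sleeps", "adjective": "gray"} in fillings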