bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/templates/adapters/huggingface.py
@@ -0,0 +1,312 @@
"""HuggingFace masked language model adapter for template filling."""

from __future__ import annotations

from pathlib import Path

import torch
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin
from bead.templates.adapters.base import TemplateFillingModelAdapter


class HuggingFaceMLMAdapter(HuggingFaceAdapterMixin, TemplateFillingModelAdapter):
    """Adapter for HuggingFace masked language models.

    Supports BERT, RoBERTa, ALBERT, and other MLM architectures.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "bert-base-uncased")
    device : DeviceType
        Computation device ("cpu", "cuda", "mps")
    cache_dir : Path | None
        Directory for caching model files

    Examples
    --------
    >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased", device="cpu")
    >>> adapter.load_model()
    >>> predictions = adapter.predict_masked_token(
    ...     text="The cat [MASK] on the mat",
    ...     mask_position=0,
    ...     top_k=5
    ... )
    >>> for token, log_prob in predictions:
    ...     print(f"{token}: {log_prob:.2f}")
    >>> adapter.unload_model()
    """

    def __init__(
        self,
        model_name: str,
        device: DeviceType = "cpu",
        cache_dir: Path | None = None,
    ) -> None:
        # validate device before passing to parent
        validated_device = self._validate_device(device)
        super().__init__(model_name, validated_device, cache_dir)
        self.model: PreTrainedModel | None = None
        self.tokenizer: PreTrainedTokenizer | None = None

    def load_model(self) -> None:
        """Load model and tokenizer from HuggingFace.

        Raises
        ------
        RuntimeError
            If model loading fails
        """
        if self._model_loaded:
            return

        try:
            # load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )

            # load model
            self.model = AutoModelForMaskedLM.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )

            # move to device
            self.model.to(self.device)

            # set to evaluation mode
            self.model.eval()

            self._model_loaded = True

        except Exception as e:
            raise RuntimeError(f"Failed to load model {self.model_name}: {e}") from e

    def unload_model(self) -> None:
        """Unload model from memory."""
        if not self._model_loaded:
            return

        # move model to CPU and delete
        if self.model is not None:
            self.model.to("cpu")
            del self.model
            self.model = None

        del self.tokenizer
        self.tokenizer = None

        self._model_loaded = False

        # clear CUDA cache if using GPU
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def predict_masked_token(
        self,
        text: str,
        mask_position: int,
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Predict masked token at specified position.

        Parameters
        ----------
        text : str
            Text with mask token (e.g., "The cat [MASK] quickly")
        mask_position : int
            Index of the mask to predict (0-indexed, relative to mask tokens found)
        top_k : int
            Number of top predictions to return

        Returns
        -------
        list[tuple[str, float]]
            List of (token, log_probability) tuples, sorted by probability

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If mask_position is invalid or text has no mask token
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # tokenize input
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # find mask token ID
        mask_token_id = self.tokenizer.mask_token_id
        if mask_token_id is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        # find mask position in tokenized input
        input_ids = inputs["input_ids"][0]
        mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]

        if len(mask_positions) == 0:
            raise ValueError(f"No mask token found in text: {text}")

        if mask_position >= len(mask_positions):
            raise ValueError(
                f"mask_position {mask_position} out of range. "
                f"Found {len(mask_positions)} mask tokens in text."
            )

        # get actual token index
        mask_idx = mask_positions[mask_position].item()

        # forward pass
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        # get predictions for mask position
        mask_logits = logits[0, mask_idx]

        # convert to log probabilities
        log_probs = torch.log_softmax(mask_logits, dim=0)

        # get top-k predictions
        top_log_probs, top_indices = torch.topk(log_probs, k=min(top_k, len(log_probs)))

        # convert to tokens
        predictions: list[tuple[str, float]] = []
        for log_prob, idx in zip(top_log_probs.cpu(), top_indices.cpu(), strict=True):
            token = self.tokenizer.decode([idx], skip_special_tokens=True).strip()
            predictions.append((token, float(log_prob)))

        return predictions

    def predict_masked_token_batch(
        self,
        texts: list[str],
        mask_position: int = 0,
        top_k: int = 10,
    ) -> list[list[tuple[str, float]]]:
        """Predict masked tokens for multiple texts in a single batch.

        Parameters
        ----------
        texts : list[str]
            List of texts with mask tokens
        mask_position : int
            Token position of mask (0-indexed, relative to mask tokens found)
        top_k : int
            Number of top predictions to return per text

        Returns
        -------
        list[list[tuple[str, float]]]
            List of predictions for each text. Each element is a list of
            (token, log_probability) tuples.

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If any text has no mask token
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not texts:
            return []

        # tokenize all texts with padding
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # find mask token ID
        mask_token_id = self.tokenizer.mask_token_id
        if mask_token_id is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        # forward pass for entire batch
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits  # shape: (batch_size, seq_len, vocab_size)

        # process each text in batch
        results: list[list[tuple[str, float]]] = []
        for i, text in enumerate(texts):
            # find mask position in this text
            input_ids = inputs["input_ids"][i]
            mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]

            if len(mask_positions) == 0:
                raise ValueError(f"No mask token found in text: {text}")

            if mask_position >= len(mask_positions):
                raise ValueError(
                    f"mask_position {mask_position} out of range. "
                    f"Found {len(mask_positions)} mask tokens in text."
                )

            # get actual token index
            mask_idx = mask_positions[mask_position].item()

            # get predictions for this mask position
            mask_logits = logits[i, mask_idx]

            # convert to log probabilities
            log_probs = torch.log_softmax(mask_logits, dim=0)

            # get top-k predictions
            top_log_probs, top_indices = torch.topk(
                log_probs, k=min(top_k, len(log_probs))
            )

            # convert to tokens
            predictions: list[tuple[str, float]] = []
            for log_prob, idx in zip(
                top_log_probs.cpu(), top_indices.cpu(), strict=True
            ):
                token = self.tokenizer.decode([idx], skip_special_tokens=True).strip()
                predictions.append((token, float(log_prob)))

            results.append(predictions)

        return results

    def get_mask_token(self) -> str:
        """Get the mask token for this model.

        Returns
        -------
        str
            Mask token string (e.g., "[MASK]" for BERT)

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If model has no mask token
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        mask_token = self.tokenizer.mask_token
        if mask_token is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        return mask_token
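The docstring above shows single-text usage; here is a minimal sketch of driving the batch path. It is illustrative rather than part of the package: the model name and sentences are invented, and only methods defined in the file above (get_mask_token, predict_masked_token_batch) are used. Texts of different lengths are fine because the batch call pads internally.

# illustrative usage sketch, not part of the wheel contents
from bead.templates.adapters.huggingface import HuggingFaceMLMAdapter

adapter = HuggingFaceMLMAdapter("bert-base-uncased", device="cpu")
adapter.load_model()

mask = adapter.get_mask_token()  # "[MASK]" for BERT-style models
texts = [
    f"The cat {mask} on the mat.",
    f"The dog {mask} across the yard.",
]
# one forward pass for the whole batch; one (token, log_prob) list per text
for text, preds in zip(texts, adapter.predict_masked_token_batch(texts, top_k=3)):
    print(text, "->", [(tok, round(lp, 2)) for tok, lp in preds])

adapter.unload_model()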
bead/templates/combinatorics.py
@@ -0,0 +1,103 @@
"""Combinatorial utilities for template filling."""

from __future__ import annotations

import itertools
import random
from collections.abc import Iterator


def cartesian_product[T](*iterables: list[T]) -> Iterator[tuple[T, ...]]:
    """Generate Cartesian product of iterables.

    Equivalent to itertools.product but with explicit type hints
    and documentation for the template-filling use case.

    Parameters
    ----------
    *iterables : list[T]
        Variable number of iterables to combine.

    Yields
    ------
    tuple[T, ...]
        Each combination from the Cartesian product.

    Examples
    --------
    >>> list(cartesian_product([1, 2], ['a', 'b']))
    [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
    """
    return itertools.product(*iterables)


def count_combinations[T](*iterables: list[T]) -> int:
    """Count total combinations without generating them.

    Calculate the size of the Cartesian product space efficiently
    without actually generating combinations.

    Parameters
    ----------
    *iterables : list[T]
        Variable number of iterables.

    Returns
    -------
    int
        Total number of combinations.

    Examples
    --------
    >>> count_combinations([1, 2], ['a', 'b'], [True, False])
    8
    """
    count = 1
    for iterable in iterables:
        count *= len(iterable)
    return count


def stratified_sample[T](
    groups: dict[str, list[T]],
    n_per_group: int,
    seed: int | None = None,
) -> list[T]:
    """Sample items from groups with balanced representation.

    Ensure each group is represented approximately equally in the sample.

    Parameters
    ----------
    groups : dict[str, list[T]]
        Dictionary mapping group names to lists of items.
    n_per_group : int
        Number of items to sample from each group.
    seed : int | None
        Random seed for reproducibility.

    Returns
    -------
    list[T]
        Sampled items, balanced across groups.

    Examples
    --------
    >>> groups = {"verbs": [v1, v2, v3], "nouns": [n1, n2, n3]}
    >>> sample = stratified_sample(groups, n_per_group=2, seed=42)
    >>> len(sample)
    4
    """
    if seed is not None:
        random.seed(seed)

    sampled: list[T] = []
    for group_items in groups.values():
        # sample without replacement, capping at the group size
        # when n_per_group exceeds the number of available items
        k = min(n_per_group, len(group_items))
        sampled.extend(random.sample(group_items, k))

    return sampled
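For context, a sketch of how these three utilities might compose during template filling; the slot names and word lists are invented for illustration, and only the functions defined above are used.

# illustrative usage sketch, not part of the wheel contents
from bead.templates.combinatorics import (
    cartesian_product,
    count_combinations,
    stratified_sample,
)

slots = {
    "det": ["the", "a"],
    "noun": ["cat", "dog", "bird", "fox"],
    "verb": ["slept", "ran", "jumped"],
}

# size the fill space before enumerating it: 2 * 4 * 3 = 24
total = count_combinations(*slots.values())
print(f"{total} possible fills")

# draw a balanced, reproducible subset of fillers per slot,
# then enumerate combinations over the reduced lexicon
trimmed = {
    name: stratified_sample({name: values}, n_per_group=2, seed=42)
    for name, values in slots.items()
}
for fill in cartesian_product(*trimmed.values()):
    print(fill)  # 2 * 2 * 2 = 8 tuples like ('the', 'cat', 'ran')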