bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,862 @@
|
|
|
1
|
+
"""Cloze model for fill-in-the-blank tasks with GLMM support.

Implements masked language modeling with participant-level random effects for
predicting tokens at unfilled slots in partially-filled templates. Unfilled
slots are marked in the rendered text with ``"___"`` placeholders, which are
replaced by the tokenizer's mask token before prediction.

Three mixed-effects modes are supported:

- fixed effects: standard MLM;
- random intercepts: participant-specific bias on the output logits;
- random slopes: participant-specific MLM heads (trained with a custom loop
  rather than the HuggingFace Trainer).

Architecture: Masked LM (BERT/RoBERTa) for token prediction
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import copy
|
|
13
|
+
import tempfile
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import torch
|
|
18
|
+
from transformers import AutoModelForMaskedLM, AutoTokenizer, TrainingArguments
|
|
19
|
+
|
|
20
|
+
from bead.active_learning.config import VarianceComponents
|
|
21
|
+
from bead.active_learning.models.base import ActiveLearningModel, ModelPrediction
|
|
22
|
+
from bead.active_learning.models.random_effects import RandomEffectsManager
|
|
23
|
+
from bead.active_learning.trainers.data_collator import ClozeDataCollator
|
|
24
|
+
from bead.active_learning.trainers.dataset_utils import cloze_items_to_dataset
|
|
25
|
+
from bead.config.active_learning import ClozeModelConfig
|
|
26
|
+
from bead.items.item import Item
|
|
27
|
+
from bead.items.item_template import ItemTemplate, TaskType
|
|
28
|
+
|
|
29
|
+
# Public API of this module: only the model class is exported.
__all__ = ["ClozeModel"]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ClozeModel(ActiveLearningModel):
|
|
33
|
+
"""Model for cloze tasks with participant-level random effects.
|
|
34
|
+
|
|
35
|
+
Uses masked language modeling (BERT/RoBERTa) to predict tokens at unfilled
|
|
36
|
+
slots in partially-filled templates. Supports three GLMM modes:
|
|
37
|
+
- Fixed effects: Standard MLM
|
|
38
|
+
- Random intercepts: Participant-specific bias on output logits
|
|
39
|
+
- Random slopes: Participant-specific MLM heads
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
config : ClozeModelConfig
|
|
44
|
+
Configuration object containing all model parameters.
|
|
45
|
+
|
|
46
|
+
Attributes
|
|
47
|
+
----------
|
|
48
|
+
config : ClozeModelConfig
|
|
49
|
+
Model configuration.
|
|
50
|
+
tokenizer : AutoTokenizer
|
|
51
|
+
Masked LM tokenizer.
|
|
52
|
+
model : AutoModelForMaskedLM
|
|
53
|
+
Masked language model (BERT or RoBERTa).
|
|
54
|
+
encoder : nn.Module
|
|
55
|
+
Encoder module from the model.
|
|
56
|
+
mlm_head : nn.Module
|
|
57
|
+
MLM prediction head.
|
|
58
|
+
random_effects : RandomEffectsManager
|
|
59
|
+
Manager for participant-level random effects.
|
|
60
|
+
variance_history : list[VarianceComponents]
|
|
61
|
+
Variance component estimates over training.
|
|
62
|
+
_is_fitted : bool
|
|
63
|
+
Whether model has been trained.
|
|
64
|
+
|
|
65
|
+
Examples
|
|
66
|
+
--------
|
|
67
|
+
>>> from uuid import uuid4
|
|
68
|
+
>>> from bead.items.item import Item, UnfilledSlot
|
|
69
|
+
>>> from bead.config.active_learning import ClozeModelConfig
|
|
70
|
+
>>> items = [
|
|
71
|
+
... Item(
|
|
72
|
+
... item_template_id=uuid4(),
|
|
73
|
+
... rendered_elements={"text": "The cat ___."},
|
|
74
|
+
... unfilled_slots=[
|
|
75
|
+
... UnfilledSlot(slot_name="verb", position=2, constraint_ids=[])
|
|
76
|
+
... ]
|
|
77
|
+
... )
|
|
78
|
+
... for _ in range(6)
|
|
79
|
+
... ]
|
|
80
|
+
>>> labels = [["ran"], ["jumped"], ["slept"]] * 2 # One token per unfilled slot
|
|
81
|
+
>>> config = ClozeModelConfig( # doctest: +SKIP
|
|
82
|
+
... num_epochs=1, batch_size=2, device="cpu"
|
|
83
|
+
... )
|
|
84
|
+
>>> model = ClozeModel(config=config) # doctest: +SKIP
|
|
85
|
+
>>> metrics = model.train(items, labels, participant_ids=None) # doctest: +SKIP
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
def __init__(
    self,
    config: ClozeModelConfig | None = None,
) -> None:
    """Initialize cloze model.

    Loads the tokenizer and masked language model named by
    ``config.model_name``. Exposes the model's encoder and MLM head as
    separate attributes and moves the model to ``config.device``.

    Parameters
    ----------
    config : ClozeModelConfig | None
        Configuration object. If None, uses default configuration.

    Raises
    ------
    ValueError
        If the loaded model is neither BERT- nor RoBERTa-style and exposes
        no ``lm_head`` or ``cls`` attribute to use as the MLM head.
    """
    self.config = config if config is not None else ClozeModelConfig()

    # Base-class initialization with the resolved config (expected to
    # validate the mixed_effects settings).
    super().__init__(self.config)

    # Load tokenizer and masked LM model
    self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
    self.model = AutoModelForMaskedLM.from_pretrained(self.config.model_name)

    # Extract encoder and MLM head.
    # BERT-style models use model.bert and model.cls;
    # RoBERTa-style models use model.roberta and model.lm_head.
    if hasattr(self.model, "bert"):
        self.encoder = self.model.bert
        self.mlm_head = self.model.cls
    elif hasattr(self.model, "roberta"):
        self.encoder = self.model.roberta
        self.mlm_head = self.model.lm_head
    else:
        # Fallback: generic base model plus whichever head attribute exists.
        # Checking both names (instead of assuming ``lm_head``) also covers
        # BERT-derived architectures that keep their head under ``cls``.
        # It turns a bare AttributeError into an actionable message.
        self.encoder = self.model.base_model
        mlm_head = getattr(self.model, "lm_head", None)
        if mlm_head is None:
            mlm_head = getattr(self.model, "cls", None)
        if mlm_head is None:
            raise ValueError(
                f"Cannot locate the MLM head of model "
                f"'{self.config.model_name}' ({type(self.model).__name__}): "
                "expected an 'lm_head' or 'cls' attribute."
            )
        self.mlm_head = mlm_head

    self._is_fitted = False

    # Random effects manager is created later, during training.
    self.random_effects: RandomEffectsManager | None = None
    self.variance_history: list[VarianceComponents] = []

    self.model.to(self.config.device)
|
|
129
|
+
|
|
130
|
+
@property
def supported_task_types(self) -> list[TaskType]:
    """List the task types this model can handle.

    Returns
    -------
    list[TaskType]
        A new single-element list, ``["cloze"]``, on every access.
    """
    task_types = ["cloze"]
    return task_types
|
|
140
|
+
|
|
141
|
+
def validate_item_compatibility(
    self, item: Item, item_template: ItemTemplate
) -> None:
    """Check that an item/template pair can be handled by the cloze model.

    The template's task type is checked first, then the item's slots.

    Parameters
    ----------
    item : Item
        Item to check; it must carry at least one unfilled slot.
    item_template : ItemTemplate
        Template the item was built from; its ``task_type`` must be
        ``"cloze"``.

    Raises
    ------
    ValueError
        When the template's task type is anything other than ``"cloze"``,
        or when the item has an empty ``unfilled_slots``.
    """
    task_type = item_template.task_type
    if task_type != "cloze":
        raise ValueError(f"Expected task_type 'cloze', got '{task_type}'")

    # Guard clause: a non-empty slot list means the item is acceptable.
    if item.unfilled_slots:
        return

    raise ValueError(
        "Cloze items must have at least one unfilled slot. "
        f"Item {item.id} has no unfilled_slots."
    )
|
|
170
|
+
|
|
171
|
+
def _prepare_inputs_and_masks(
|
|
172
|
+
self, items: list[Item]
|
|
173
|
+
) -> tuple[dict[str, torch.Tensor], list[list[int]]]:
|
|
174
|
+
"""Prepare tokenized inputs with masked positions.
|
|
175
|
+
|
|
176
|
+
Extracts text from items, tokenizes, and replaces tokens at unfilled_slots
|
|
177
|
+
positions with [MASK] token.
|
|
178
|
+
|
|
179
|
+
Parameters
|
|
180
|
+
----------
|
|
181
|
+
items : list[Item]
|
|
182
|
+
Items to prepare.
|
|
183
|
+
|
|
184
|
+
Returns
|
|
185
|
+
-------
|
|
186
|
+
tuple[dict[str, torch.Tensor], list[list[int]]]
|
|
187
|
+
- Tokenized inputs (input_ids, attention_mask)
|
|
188
|
+
- List of masked token positions per item (token-level indices)
|
|
189
|
+
"""
|
|
190
|
+
texts = []
|
|
191
|
+
n_slots_per_item = []
|
|
192
|
+
|
|
193
|
+
for item in items:
|
|
194
|
+
# Get rendered text
|
|
195
|
+
text = item.rendered_elements.get("text", "")
|
|
196
|
+
texts.append(text)
|
|
197
|
+
n_slots_per_item.append(len(item.unfilled_slots))
|
|
198
|
+
|
|
199
|
+
# Tokenize all texts
|
|
200
|
+
tokenized = self.tokenizer(
|
|
201
|
+
texts,
|
|
202
|
+
padding=True,
|
|
203
|
+
truncation=True,
|
|
204
|
+
max_length=self.config.max_length,
|
|
205
|
+
return_tensors="pt",
|
|
206
|
+
).to(self.config.device)
|
|
207
|
+
|
|
208
|
+
mask_token_id = self.tokenizer.mask_token_id
|
|
209
|
+
|
|
210
|
+
# Find and replace "___" placeholders with [MASK]
|
|
211
|
+
# Track ONE position per unfilled slot (even if "___" spans multiple tokens)
|
|
212
|
+
token_masked_positions = []
|
|
213
|
+
for i, text in enumerate(texts):
|
|
214
|
+
# Tokenize individually to find "___" positions
|
|
215
|
+
tokens = self.tokenizer.tokenize(text)
|
|
216
|
+
masked_indices = []
|
|
217
|
+
|
|
218
|
+
# Track which tokens are part of "___" to avoid duplicates
|
|
219
|
+
in_blank = False
|
|
220
|
+
for j, token in enumerate(tokens):
|
|
221
|
+
# Check if this token is part of a "___" placeholder
|
|
222
|
+
if "_" in token and not in_blank:
|
|
223
|
+
# Start of a new blank - record this position
|
|
224
|
+
token_idx = j + 1 # Add 1 for [CLS] token
|
|
225
|
+
masked_indices.append(token_idx)
|
|
226
|
+
in_blank = True
|
|
227
|
+
# Replace with [MASK]
|
|
228
|
+
if token_idx < tokenized["input_ids"].shape[1]:
|
|
229
|
+
tokenized["input_ids"][i, token_idx] = mask_token_id
|
|
230
|
+
elif "_" in token and in_blank:
|
|
231
|
+
# Continuation of current blank - also mask but don't record
|
|
232
|
+
token_idx = j + 1
|
|
233
|
+
if token_idx < tokenized["input_ids"].shape[1]:
|
|
234
|
+
tokenized["input_ids"][i, token_idx] = mask_token_id
|
|
235
|
+
else:
|
|
236
|
+
# Not a blank token - reset in_blank
|
|
237
|
+
in_blank = False
|
|
238
|
+
|
|
239
|
+
# Verify we found the expected number of masked positions
|
|
240
|
+
expected_slots = n_slots_per_item[i]
|
|
241
|
+
if len(masked_indices) != expected_slots:
|
|
242
|
+
raise ValueError(
|
|
243
|
+
f"Mismatch between masked positions and unfilled_slots "
|
|
244
|
+
f"for item {i}: found {len(masked_indices)} '___' "
|
|
245
|
+
f"placeholders in text but item has {expected_slots} "
|
|
246
|
+
f"unfilled_slots. Ensure rendered text uses exactly one "
|
|
247
|
+
f"'___' per unfilled_slot. Text: '{text}'"
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
token_masked_positions.append(masked_indices)
|
|
251
|
+
|
|
252
|
+
return tokenized, token_masked_positions
|
|
253
|
+
|
|
254
|
+
def _prepare_training_data(
|
|
255
|
+
self,
|
|
256
|
+
items: list[Item],
|
|
257
|
+
labels: list[str] | list[list[str]],
|
|
258
|
+
participant_ids: list[str],
|
|
259
|
+
validation_items: list[Item] | None,
|
|
260
|
+
validation_labels: list[str] | list[list[str]] | None,
|
|
261
|
+
) -> tuple[
|
|
262
|
+
list[Item],
|
|
263
|
+
list[list[str]],
|
|
264
|
+
list[Item] | None,
|
|
265
|
+
list[list[str]] | None,
|
|
266
|
+
]:
|
|
267
|
+
"""Prepare data for training, including validation of label format.
|
|
268
|
+
|
|
269
|
+
Parameters
|
|
270
|
+
----------
|
|
271
|
+
items : list[Item]
|
|
272
|
+
Training items.
|
|
273
|
+
labels : list[list[str]]
|
|
274
|
+
Training labels as list of lists (one token per unfilled slot).
|
|
275
|
+
participant_ids : list[str]
|
|
276
|
+
Participant IDs (already normalized).
|
|
277
|
+
validation_items : list[Item] | None
|
|
278
|
+
Validation items.
|
|
279
|
+
validation_labels : list[list[str]] | None
|
|
280
|
+
Validation labels.
|
|
281
|
+
|
|
282
|
+
Returns
|
|
283
|
+
-------
|
|
284
|
+
tuple[list[Item], list[list[str]], list[Item] | None, list[list[str]] | None]
|
|
285
|
+
Prepared items, labels, validation items, validation labels.
|
|
286
|
+
"""
|
|
287
|
+
# Validate labels format: each label must be a list matching unfilled_slots
|
|
288
|
+
labels_list = list(labels) # Type: list[list[str]]
|
|
289
|
+
for i, (item, label) in enumerate(zip(items, labels_list, strict=True)):
|
|
290
|
+
if not isinstance(label, list):
|
|
291
|
+
raise ValueError(
|
|
292
|
+
f"ClozeModel requires labels to be list[list[str]], "
|
|
293
|
+
f"but got {type(label)} for item {i}"
|
|
294
|
+
)
|
|
295
|
+
if len(label) != len(item.unfilled_slots):
|
|
296
|
+
raise ValueError(
|
|
297
|
+
f"Label length mismatch for item {i}: "
|
|
298
|
+
f"expected {len(item.unfilled_slots)} tokens "
|
|
299
|
+
f"(matching unfilled_slots), got {len(label)} tokens. "
|
|
300
|
+
f"Ensure each label is a list with one token per unfilled slot."
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
val_labels_list: list[list[str]] | None = None
|
|
304
|
+
if validation_items is not None and validation_labels is not None:
|
|
305
|
+
val_labels_list = list(validation_labels) # Type: list[list[str]]
|
|
306
|
+
for i, (item, label) in enumerate(
|
|
307
|
+
zip(validation_items, val_labels_list, strict=True)
|
|
308
|
+
):
|
|
309
|
+
if not isinstance(label, list):
|
|
310
|
+
raise ValueError(
|
|
311
|
+
f"ClozeModel requires validation_labels to be list[list[str]], "
|
|
312
|
+
f"but got {type(label)} for validation item {i}"
|
|
313
|
+
)
|
|
314
|
+
if len(label) != len(item.unfilled_slots):
|
|
315
|
+
raise ValueError(
|
|
316
|
+
f"Validation label length mismatch for item {i}: "
|
|
317
|
+
f"expected {len(item.unfilled_slots)} tokens, "
|
|
318
|
+
f"got {len(label)} tokens."
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
return items, labels_list, participant_ids, validation_items, val_labels_list
|
|
322
|
+
|
|
323
|
+
def _do_training(
|
|
324
|
+
self,
|
|
325
|
+
items: list[Item],
|
|
326
|
+
labels_numeric: list[list[str]],
|
|
327
|
+
participant_ids: list[str],
|
|
328
|
+
validation_items: list[Item] | None,
|
|
329
|
+
validation_labels_numeric: list[list[str]] | None,
|
|
330
|
+
) -> dict[str, float]:
|
|
331
|
+
"""Perform the actual training logic (HuggingFace Trainer or custom loop).
|
|
332
|
+
|
|
333
|
+
Parameters
|
|
334
|
+
----------
|
|
335
|
+
items : list[Item]
|
|
336
|
+
Training items.
|
|
337
|
+
labels_numeric : list[list[str]]
|
|
338
|
+
Training labels (already validated).
|
|
339
|
+
participant_ids : list[str]
|
|
340
|
+
Participant IDs.
|
|
341
|
+
validation_items : list[Item] | None
|
|
342
|
+
Validation items.
|
|
343
|
+
validation_labels_numeric : list[list[str]] | None
|
|
344
|
+
Validation labels.
|
|
345
|
+
|
|
346
|
+
Returns
|
|
347
|
+
-------
|
|
348
|
+
dict[str, float]
|
|
349
|
+
Training metrics.
|
|
350
|
+
"""
|
|
351
|
+
# Use HuggingFace Trainer for fixed and random_intercepts modes
|
|
352
|
+
# random_slopes requires custom loop due to per-participant MLM heads
|
|
353
|
+
use_huggingface_trainer = self.config.mixed_effects.mode in (
|
|
354
|
+
"fixed",
|
|
355
|
+
"random_intercepts",
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
if use_huggingface_trainer:
|
|
359
|
+
return self._train_with_huggingface_trainer(
|
|
360
|
+
items,
|
|
361
|
+
labels_numeric,
|
|
362
|
+
participant_ids,
|
|
363
|
+
validation_items,
|
|
364
|
+
validation_labels_numeric,
|
|
365
|
+
)
|
|
366
|
+
else:
|
|
367
|
+
# Use custom training loop for random_slopes
|
|
368
|
+
return self._train_with_custom_loop(
|
|
369
|
+
items,
|
|
370
|
+
labels_numeric,
|
|
371
|
+
participant_ids,
|
|
372
|
+
validation_items,
|
|
373
|
+
validation_labels_numeric,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
def _train_with_huggingface_trainer(
    self,
    items: list[Item],
    labels: list[list[str]],
    participant_ids: list[str],
    validation_items: list[Item] | None,
    validation_labels: list[list[str]] | None,
) -> dict[str, float]:
    """Fit the masked LM with ``ClozeMLMTrainer`` (fixed / random-intercept modes).

    Builds masked datasets for training and, optionally, validation. Trains
    for ``config.num_epochs`` epochs, writing checkpoints to a temporary
    directory that is deleted afterwards. Then re-evaluates on the training
    set and, if supplied, the validation set.

    Parameters
    ----------
    items : list[Item]
        Training items with unfilled_slots.
    labels : list[list[str]]
        One token per unfilled slot for each training item.
    participant_ids : list[str]
        Participant ID for each training item.
    validation_items : list[Item] | None
        Optional validation items.
    validation_labels : list[list[str]] | None
        Labels for ``validation_items``. Validation runs only when both are
        given.

    Returns
    -------
    dict[str, float]
        ``train_loss`` and ``train_accuracy``, plus ``val_accuracy`` when
        validation data was supplied. Missing accuracies default to 0.0.
    """
    train_ds = cloze_items_to_dataset(
        items=items,
        labels=labels,
        participant_ids=participant_ids,
        tokenizer=self.tokenizer,
        max_length=self.config.max_length,
    )

    eval_ds = None
    have_validation = validation_items is not None and validation_labels is not None
    if have_validation:
        # Validation rows get a placeholder participant id: "_fixed_" in
        # fixed-effects mode, "_validation_" in every other mode.
        placeholder_id = (
            "_fixed_" if self.config.mixed_effects.mode == "fixed" else "_validation_"
        )
        eval_ds = cloze_items_to_dataset(
            items=validation_items,
            labels=validation_labels,
            participant_ids=[placeholder_id] * len(validation_items),
            tokenizer=self.tokenizer,
            max_length=self.config.max_length,
        )

    collator = ClozeDataCollator(tokenizer=self.tokenizer)

    def compute_metrics_fn(eval_pred: object) -> dict[str, float]:
        # Deferred import, resolved only when metrics are first computed.
        from bead.active_learning.trainers.metrics import (  # noqa: PLC0415
            compute_cloze_metrics,
        )

        return compute_cloze_metrics(eval_pred, tokenizer=self.tokenizer)

    with tempfile.TemporaryDirectory() as scratch_dir:
        ckpt_path = Path(scratch_dir) / "checkpoints"
        ckpt_path.mkdir(parents=True, exist_ok=True)

        run_args = TrainingArguments(
            output_dir=str(ckpt_path),
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            learning_rate=self.config.learning_rate,
            logging_steps=10,
            eval_strategy="no" if eval_ds is None else "epoch",
            save_strategy="epoch",
            save_total_limit=1,
            load_best_model_at_end=False,
            report_to="none",
            remove_unused_columns=False,
            use_cpu=self.config.device == "cpu",
        )

        # Imported here to avoid a circular import with the trainers package.
        from bead.active_learning.trainers.mixed_effects import (  # noqa: PLC0415
            ClozeMLMTrainer,
        )

        # The MLM model is Trainer-compatible as-is, so it is passed unwrapped.
        trainer = ClozeMLMTrainer(
            model=self.model,
            args=run_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            data_collator=collator,
            tokenizer=self.tokenizer,
            random_effects_manager=self.random_effects,
            compute_metrics=compute_metrics_fn,
        )

        fit_output = trainer.train()

        on_train = trainer.evaluate(eval_dataset=train_ds)
        results: dict[str, float] = {
            "train_loss": float(fit_output.training_loss),
            "train_accuracy": on_train.get("eval_accuracy", 0.0),
        }

        if eval_ds is not None:
            on_val = trainer.evaluate(eval_dataset=eval_ds)
            results["val_accuracy"] = on_val.get("eval_accuracy", 0.0)

    return results
|
|
497
|
+
|
|
498
|
+
def _train_with_custom_loop(
|
|
499
|
+
self,
|
|
500
|
+
items: list[Item],
|
|
501
|
+
labels: list[list[str]],
|
|
502
|
+
participant_ids: list[str],
|
|
503
|
+
validation_items: list[Item] | None,
|
|
504
|
+
validation_labels: list[list[str]] | None,
|
|
505
|
+
) -> dict[str, float]:
|
|
506
|
+
"""Train using custom loop for random_slopes mode.
|
|
507
|
+
|
|
508
|
+
Parameters
|
|
509
|
+
----------
|
|
510
|
+
items : list[Item]
|
|
511
|
+
Training items with unfilled_slots.
|
|
512
|
+
labels : list[list[str]]
|
|
513
|
+
Training labels as list of lists.
|
|
514
|
+
participant_ids : list[str]
|
|
515
|
+
Participant IDs.
|
|
516
|
+
validation_items : list[Item] | None
|
|
517
|
+
Validation items.
|
|
518
|
+
validation_labels : list[list[str]] | None
|
|
519
|
+
Validation labels.
|
|
520
|
+
|
|
521
|
+
Returns
|
|
522
|
+
-------
|
|
523
|
+
dict[str, float]
|
|
524
|
+
Training metrics.
|
|
525
|
+
"""
|
|
526
|
+
# Build optimizer parameters
|
|
527
|
+
params_to_optimize = list(self.model.parameters())
|
|
528
|
+
|
|
529
|
+
# Add random effects parameters for random_slopes
|
|
530
|
+
for head in self.random_effects.slopes.values():
|
|
531
|
+
params_to_optimize.extend(head.parameters())
|
|
532
|
+
|
|
533
|
+
optimizer = torch.optim.AdamW(params_to_optimize, lr=self.config.learning_rate)
|
|
534
|
+
|
|
535
|
+
self.model.train()
|
|
536
|
+
|
|
537
|
+
for _epoch in range(self.config.num_epochs):
|
|
538
|
+
n_batches = (
|
|
539
|
+
len(items) + self.config.batch_size - 1
|
|
540
|
+
) // self.config.batch_size
|
|
541
|
+
epoch_loss = 0.0
|
|
542
|
+
|
|
543
|
+
for i in range(n_batches):
|
|
544
|
+
start_idx = i * self.config.batch_size
|
|
545
|
+
end_idx = min(start_idx + self.config.batch_size, len(items))
|
|
546
|
+
|
|
547
|
+
batch_items = items[start_idx:end_idx]
|
|
548
|
+
batch_labels = labels[start_idx:end_idx]
|
|
549
|
+
batch_participant_ids = participant_ids[start_idx:end_idx]
|
|
550
|
+
|
|
551
|
+
# Prepare inputs with masking
|
|
552
|
+
tokenized, masked_positions = self._prepare_inputs_and_masks(
|
|
553
|
+
batch_items
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
# Tokenize labels to get target token IDs
|
|
557
|
+
target_token_ids = []
|
|
558
|
+
for label_list in batch_labels:
|
|
559
|
+
token_ids = []
|
|
560
|
+
for token in label_list:
|
|
561
|
+
tid = self.tokenizer.encode(token, add_special_tokens=False)[0]
|
|
562
|
+
token_ids.append(tid)
|
|
563
|
+
target_token_ids.append(token_ids)
|
|
564
|
+
|
|
565
|
+
# Use participant-specific MLM heads for random_slopes
|
|
566
|
+
all_logits = []
|
|
567
|
+
for j, pid in enumerate(batch_participant_ids):
|
|
568
|
+
# Get participant-specific MLM head
|
|
569
|
+
participant_head = self.random_effects.get_slopes(
|
|
570
|
+
pid,
|
|
571
|
+
fixed_head=copy.deepcopy(self.mlm_head),
|
|
572
|
+
create_if_missing=True,
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
# Get encoder outputs for this item
|
|
576
|
+
item_inputs = {k: v[j : j + 1] for k, v in tokenized.items()}
|
|
577
|
+
encoder_outputs_j = self.encoder(**item_inputs)
|
|
578
|
+
|
|
579
|
+
# Run participant-specific MLM head
|
|
580
|
+
logits_j = participant_head(encoder_outputs_j.last_hidden_state)
|
|
581
|
+
all_logits.append(logits_j)
|
|
582
|
+
|
|
583
|
+
logits = torch.cat(all_logits, dim=0)
|
|
584
|
+
|
|
585
|
+
# Compute loss only on masked positions
|
|
586
|
+
losses = []
|
|
587
|
+
for j, (masked_pos, target_ids) in enumerate(
|
|
588
|
+
zip(masked_positions, target_token_ids, strict=True)
|
|
589
|
+
):
|
|
590
|
+
for pos, target_id in zip(masked_pos, target_ids, strict=True):
|
|
591
|
+
if pos < logits.shape[1]:
|
|
592
|
+
loss_j = torch.nn.functional.cross_entropy(
|
|
593
|
+
logits[j, pos : pos + 1],
|
|
594
|
+
torch.tensor([target_id], device=self.config.device),
|
|
595
|
+
)
|
|
596
|
+
losses.append(loss_j)
|
|
597
|
+
|
|
598
|
+
if losses:
|
|
599
|
+
loss_nll = torch.stack(losses).mean()
|
|
600
|
+
else:
|
|
601
|
+
loss_nll = torch.tensor(0.0, device=self.config.device)
|
|
602
|
+
|
|
603
|
+
# Add prior regularization
|
|
604
|
+
loss_prior = self.random_effects.compute_prior_loss()
|
|
605
|
+
loss = loss_nll + loss_prior
|
|
606
|
+
|
|
607
|
+
optimizer.zero_grad()
|
|
608
|
+
loss.backward()
|
|
609
|
+
optimizer.step()
|
|
610
|
+
|
|
611
|
+
epoch_loss += loss.item()
|
|
612
|
+
|
|
613
|
+
epoch_loss = epoch_loss / n_batches
|
|
614
|
+
|
|
615
|
+
metrics: dict[str, float] = {
|
|
616
|
+
"train_loss": epoch_loss,
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
# Compute training accuracy
|
|
620
|
+
train_predictions = self._do_predict(items, participant_ids)
|
|
621
|
+
correct = 0
|
|
622
|
+
total = 0
|
|
623
|
+
for pred, label in zip(train_predictions, labels, strict=True):
|
|
624
|
+
# pred.predicted_class is comma-separated tokens
|
|
625
|
+
pred_tokens = pred.predicted_class.split(", ")
|
|
626
|
+
for pt, lt in zip(pred_tokens, label, strict=True):
|
|
627
|
+
if pt.lower() == lt.lower():
|
|
628
|
+
correct += 1
|
|
629
|
+
total += 1
|
|
630
|
+
if total > 0:
|
|
631
|
+
metrics["train_accuracy"] = correct / total
|
|
632
|
+
|
|
633
|
+
return metrics
|
|
634
|
+
|
|
635
|
+
def _do_predict(
|
|
636
|
+
self, items: list[Item], participant_ids: list[str]
|
|
637
|
+
) -> list[ModelPrediction]:
|
|
638
|
+
"""Perform cloze model prediction.
|
|
639
|
+
|
|
640
|
+
Parameters
|
|
641
|
+
----------
|
|
642
|
+
items : list[Item]
|
|
643
|
+
Items to predict.
|
|
644
|
+
participant_ids : list[str]
|
|
645
|
+
Normalized participant IDs.
|
|
646
|
+
|
|
647
|
+
Returns
|
|
648
|
+
-------
|
|
649
|
+
list[ModelPrediction]
|
|
650
|
+
Predictions with predicted_class as comma-separated tokens.
|
|
651
|
+
"""
|
|
652
|
+
self.model.eval()
|
|
653
|
+
|
|
654
|
+
# Prepare inputs with masking
|
|
655
|
+
tokenized, masked_positions = self._prepare_inputs_and_masks(items)
|
|
656
|
+
|
|
657
|
+
with torch.no_grad():
|
|
658
|
+
if self.config.mixed_effects.mode == "fixed":
|
|
659
|
+
# Standard MLM prediction
|
|
660
|
+
outputs = self.model(**tokenized)
|
|
661
|
+
logits = outputs.logits
|
|
662
|
+
|
|
663
|
+
elif self.config.mixed_effects.mode == "random_intercepts":
|
|
664
|
+
# Get encoder outputs
|
|
665
|
+
encoder_outputs = self.encoder(**tokenized)
|
|
666
|
+
logits = self.mlm_head(encoder_outputs.last_hidden_state)
|
|
667
|
+
|
|
668
|
+
# Add participant-specific bias
|
|
669
|
+
vocab_size = self.tokenizer.vocab_size
|
|
670
|
+
for j, pid in enumerate(participant_ids):
|
|
671
|
+
bias = self.random_effects.get_intercepts(
|
|
672
|
+
pid,
|
|
673
|
+
n_classes=vocab_size,
|
|
674
|
+
param_name="mu",
|
|
675
|
+
create_if_missing=False,
|
|
676
|
+
)
|
|
677
|
+
# Add to all masked positions
|
|
678
|
+
for pos in masked_positions[j]:
|
|
679
|
+
if pos < logits.shape[1]:
|
|
680
|
+
logits[j, pos] = logits[j, pos] + bias
|
|
681
|
+
|
|
682
|
+
elif self.config.mixed_effects.mode == "random_slopes":
|
|
683
|
+
# Use participant-specific MLM heads
|
|
684
|
+
all_logits = []
|
|
685
|
+
for j, pid in enumerate(participant_ids):
|
|
686
|
+
# Get participant-specific MLM head
|
|
687
|
+
participant_head = self.random_effects.get_slopes(
|
|
688
|
+
pid,
|
|
689
|
+
fixed_head=copy.deepcopy(self.mlm_head),
|
|
690
|
+
create_if_missing=False,
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
# Get encoder outputs
|
|
694
|
+
item_inputs = {k: v[j : j + 1] for k, v in tokenized.items()}
|
|
695
|
+
encoder_outputs_j = self.encoder(**item_inputs)
|
|
696
|
+
|
|
697
|
+
# Run participant-specific MLM head
|
|
698
|
+
logits_j = participant_head(encoder_outputs_j.last_hidden_state)
|
|
699
|
+
all_logits.append(logits_j)
|
|
700
|
+
|
|
701
|
+
logits = torch.cat(all_logits, dim=0)
|
|
702
|
+
|
|
703
|
+
# Get argmax at masked positions
|
|
704
|
+
predictions = []
|
|
705
|
+
for i, masked_pos in enumerate(masked_positions):
|
|
706
|
+
predicted_tokens = []
|
|
707
|
+
for pos in masked_pos:
|
|
708
|
+
if pos < logits.shape[1]:
|
|
709
|
+
# Get token ID with highest probability
|
|
710
|
+
token_id = torch.argmax(logits[i, pos]).item()
|
|
711
|
+
# Decode token
|
|
712
|
+
token = self.tokenizer.decode([token_id])
|
|
713
|
+
predicted_tokens.append(token.strip())
|
|
714
|
+
|
|
715
|
+
# Join with comma for multi-slot items
|
|
716
|
+
predicted_class = ", ".join(predicted_tokens)
|
|
717
|
+
|
|
718
|
+
predictions.append(
|
|
719
|
+
ModelPrediction(
|
|
720
|
+
item_id=str(items[i].id),
|
|
721
|
+
probabilities={}, # Not applicable for generation
|
|
722
|
+
predicted_class=predicted_class,
|
|
723
|
+
confidence=1.0, # Not applicable for generation
|
|
724
|
+
)
|
|
725
|
+
)
|
|
726
|
+
|
|
727
|
+
return predictions
|
|
728
|
+
|
|
729
|
+
def _do_predict_proba(
|
|
730
|
+
self, items: list[Item], participant_ids: list[str]
|
|
731
|
+
) -> np.ndarray:
|
|
732
|
+
"""Perform cloze model probability prediction.
|
|
733
|
+
|
|
734
|
+
For cloze tasks, returns empty array as probabilities are not typically
|
|
735
|
+
used for evaluation.
|
|
736
|
+
|
|
737
|
+
Parameters
|
|
738
|
+
----------
|
|
739
|
+
items : list[Item]
|
|
740
|
+
Items to predict.
|
|
741
|
+
participant_ids : list[str]
|
|
742
|
+
Normalized participant IDs.
|
|
743
|
+
|
|
744
|
+
Returns
|
|
745
|
+
-------
|
|
746
|
+
np.ndarray
|
|
747
|
+
Empty array of shape (n_items, 0).
|
|
748
|
+
"""
|
|
749
|
+
return np.zeros((len(items), 0))
|
|
750
|
+
|
|
751
|
+
def _save_model_components(self, save_path: Path) -> None:
|
|
752
|
+
"""Save model-specific components (model, tokenizer).
|
|
753
|
+
|
|
754
|
+
Parameters
|
|
755
|
+
----------
|
|
756
|
+
save_path : Path
|
|
757
|
+
Directory to save to.
|
|
758
|
+
"""
|
|
759
|
+
self.model.save_pretrained(save_path / "model")
|
|
760
|
+
self.tokenizer.save_pretrained(save_path / "model")
|
|
761
|
+
|
|
762
|
+
def _get_save_state(self) -> dict[str, object]:
|
|
763
|
+
"""Get model-specific state to save in config.json.
|
|
764
|
+
|
|
765
|
+
Returns
|
|
766
|
+
-------
|
|
767
|
+
dict[str, object]
|
|
768
|
+
State dictionary to include in config.json.
|
|
769
|
+
"""
|
|
770
|
+
return {}
|
|
771
|
+
|
|
772
|
+
def _load_model_components(self, load_path: Path) -> None:
|
|
773
|
+
"""Load model-specific components.
|
|
774
|
+
|
|
775
|
+
Parameters
|
|
776
|
+
----------
|
|
777
|
+
load_path : Path
|
|
778
|
+
Directory to load from.
|
|
779
|
+
"""
|
|
780
|
+
# Load config.json to reconstruct config
|
|
781
|
+
with open(load_path / "config.json") as f:
|
|
782
|
+
import json # noqa: PLC0415
|
|
783
|
+
|
|
784
|
+
config_dict = json.load(f)
|
|
785
|
+
|
|
786
|
+
# Reconstruct MixedEffectsConfig if needed
|
|
787
|
+
if "mixed_effects" in config_dict and isinstance(
|
|
788
|
+
config_dict["mixed_effects"], dict
|
|
789
|
+
):
|
|
790
|
+
from bead.active_learning.config import MixedEffectsConfig # noqa: PLC0415
|
|
791
|
+
|
|
792
|
+
config_dict["mixed_effects"] = MixedEffectsConfig(
|
|
793
|
+
**config_dict["mixed_effects"]
|
|
794
|
+
)
|
|
795
|
+
|
|
796
|
+
# Reconstruct ClozeModelConfig
|
|
797
|
+
self.config = ClozeModelConfig(**config_dict)
|
|
798
|
+
|
|
799
|
+
# Load model
|
|
800
|
+
self.model = AutoModelForMaskedLM.from_pretrained(load_path / "model")
|
|
801
|
+
self.tokenizer = AutoTokenizer.from_pretrained(load_path / "model")
|
|
802
|
+
|
|
803
|
+
# Re-extract components
|
|
804
|
+
if hasattr(self.model, "bert"):
|
|
805
|
+
self.encoder = self.model.bert
|
|
806
|
+
self.mlm_head = self.model.cls
|
|
807
|
+
elif hasattr(self.model, "roberta"):
|
|
808
|
+
self.encoder = self.model.roberta
|
|
809
|
+
self.mlm_head = self.model.lm_head
|
|
810
|
+
else:
|
|
811
|
+
self.encoder = self.model.base_model
|
|
812
|
+
self.mlm_head = self.model.lm_head
|
|
813
|
+
|
|
814
|
+
self.model.to(self.config.device)
|
|
815
|
+
|
|
816
|
+
def _restore_training_state(self, config_dict: dict[str, object]) -> None:
|
|
817
|
+
"""Restore model-specific training state from config_dict.
|
|
818
|
+
|
|
819
|
+
Parameters
|
|
820
|
+
----------
|
|
821
|
+
config_dict : dict[str, object]
|
|
822
|
+
Configuration dictionary with training state.
|
|
823
|
+
"""
|
|
824
|
+
# ClozeModel doesn't have additional training state to restore
|
|
825
|
+
pass
|
|
826
|
+
|
|
827
|
+
def _get_n_classes_for_random_effects(self) -> int:
|
|
828
|
+
"""Get the number of classes for initializing RandomEffectsManager.
|
|
829
|
+
|
|
830
|
+
For cloze models, this is the vocabulary size.
|
|
831
|
+
|
|
832
|
+
Returns
|
|
833
|
+
-------
|
|
834
|
+
int
|
|
835
|
+
Vocabulary size.
|
|
836
|
+
"""
|
|
837
|
+
return self.tokenizer.vocab_size
|
|
838
|
+
|
|
839
|
+
def _initialize_random_effects(self, n_classes: int) -> None:
|
|
840
|
+
"""Initialize the RandomEffectsManager.
|
|
841
|
+
|
|
842
|
+
Parameters
|
|
843
|
+
----------
|
|
844
|
+
n_classes : int
|
|
845
|
+
Vocabulary size for cloze models.
|
|
846
|
+
"""
|
|
847
|
+
self.random_effects = RandomEffectsManager(
|
|
848
|
+
self.config.mixed_effects,
|
|
849
|
+
vocab_size=n_classes, # For random intercepts (bias on logits)
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
def _get_random_effects_fixed_head(self) -> torch.nn.Module | None:
|
|
853
|
+
"""Get the fixed head for random effects (classifier_head, etc.).
|
|
854
|
+
|
|
855
|
+
For cloze models, this is the MLM head.
|
|
856
|
+
|
|
857
|
+
Returns
|
|
858
|
+
-------
|
|
859
|
+
torch.nn.Module | None
|
|
860
|
+
The MLM head, or None if not applicable.
|
|
861
|
+
"""
|
|
862
|
+
return self.mlm_head
|