bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/data_collator.py
@@ -0,0 +1,172 @@
"""Data collator for mixed effects training.

This module provides a custom data collator that handles participant_ids
along with standard tokenization and padding.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from transformers import DataCollatorWithPadding

if TYPE_CHECKING:
    import torch


class MixedEffectsDataCollator(DataCollatorWithPadding):
    """Data collator that preserves participant_ids for mixed effects.

    Extends DataCollatorWithPadding to handle participant_ids as strings
    (not tensors) and pass them through to the training batch.

    Parameters
    ----------
    tokenizer : PreTrainedTokenizerBase
        HuggingFace tokenizer.
    padding : bool | str
        Padding strategy (default: True).
    max_length : int | None
        Maximum sequence length (optional).
    pad_to_multiple_of : int | None
        Pad to multiple of this value (optional).

    Examples
    --------
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> collator = MixedEffectsDataCollator(tokenizer)
    >>> batch = collator([{'input_ids': [1, 2, 3], 'participant_id': 'alice'}])
    >>> 'participant_id' in batch
    True
    """

    def __call__(
        self, features: list[dict[str, torch.Tensor | str | int | float]]
    ) -> dict[str, torch.Tensor | list[str]]:
        """Collate batch with participant_ids preserved.

        Parameters
        ----------
        features : list[dict[str, torch.Tensor | str | int | float]]
            List of feature dictionaries from dataset.

        Returns
        -------
        dict[str, torch.Tensor | list[str]]
            Collated batch with participant_ids as list[str].
        """
        # Extract participant_ids before padding
        participant_ids: list[str] = []
        for feat in features:
            pid = feat.get("participant_id", "_fixed_")
            participant_ids.append(str(pid))

        # Remove participant_id from features for standard collation
        features_for_collation = [
            {k: v for k, v in feat.items() if k != "participant_id"}
            for feat in features
        ]

        # Use parent collator for tokenization/padding
        batch = super().__call__(features_for_collation)

        # Add participant_ids back as list (not tensor)
        batch["participant_id"] = participant_ids

        return batch


class ClozeDataCollator(MixedEffectsDataCollator):
    """Data collator for cloze (MLM) tasks with custom masking.

    Extends MixedEffectsDataCollator to handle:
    - masked_positions: List of masked token positions per item
    - target_token_ids: List of target token IDs per masked position
    - Preserves these for loss computation in the trainer

    Parameters
    ----------
    tokenizer : PreTrainedTokenizerBase
        HuggingFace tokenizer.
    padding : bool | str
        Padding strategy (default: True).
    max_length : int | None
        Maximum sequence length (optional).
    pad_to_multiple_of : int | None
        Pad to multiple of this value (optional).

    Examples
    --------
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> collator = ClozeDataCollator(tokenizer)
    >>> batch = collator([{
    ...     'input_ids': [1, 2, 103, 4],  # 103 is [MASK]
    ...     'masked_positions': [2],
    ...     'target_token_ids': [1234],
    ...     'participant_id': 'alice'
    ... }])
    >>> 'masked_positions' in batch
    True
    """

    def __call__(
        self, features: list[dict[str, torch.Tensor | str | int | float | list[int]]]
    ) -> dict[str, torch.Tensor | list[str] | list[list[int]]]:
        """Collate batch with masked positions and target token IDs preserved.

        Parameters
        ----------
        features : list[dict[str, torch.Tensor | str | int | float | list[int]]]
            List of feature dictionaries from dataset.

        Returns
        -------
        dict[str, torch.Tensor | list[str] | list[list[int]]]
            Collated batch with:
            - Standard tokenized inputs (input_ids, attention_mask, etc.)
            - participant_ids as list[str]
            - masked_positions as list[list[int]]
            - target_token_ids as list[list[int]]
        """
        # Extract cloze-specific fields before padding
        participant_ids: list[str] = []
        masked_positions: list[list[int]] = []
        target_token_ids: list[list[int]] = []

        for feat in features:
            pid = feat.get("participant_id", "_fixed_")
            participant_ids.append(str(pid))

            masked_pos = feat.get("masked_positions", [])
            if isinstance(masked_pos, list):
                masked_positions.append(masked_pos)
            else:
                masked_positions.append([])

            target_ids = feat.get("target_token_ids", [])
            if isinstance(target_ids, list):
                target_token_ids.append(target_ids)
            else:
                target_token_ids.append([])

        # Remove cloze-specific fields for standard collation
        features_for_collation = [
            {
                k: v
                for k, v in feat.items()
                if k not in ("participant_id", "masked_positions", "target_token_ids")
            }
            for feat in features
        ]

        # Use parent collator for tokenization/padding
        batch = super().__call__(features_for_collation)

        # Add cloze-specific fields back
        batch["participant_id"] = participant_ids
        batch["masked_positions"] = masked_positions
        batch["target_token_ids"] = target_token_ids

        return batch
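
Note (not part of the package): the sketch below shows one way a collator like MixedEffectsDataCollator can be wired into a plain PyTorch DataLoader. Features are pre-tokenized dictionaries; participant_id survives collation as a list of strings while the tokenized fields and labels are padded into tensors by the parent DataCollatorWithPadding. The checkpoint name and example sentences are illustrative assumptions.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator

# Illustrative checkpoint; any fast BERT-style tokenizer should behave the same.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = MixedEffectsDataCollator(tokenizer=tokenizer)

# Pre-tokenized features; participant_id stays a plain string per example.
features = [
    {**tokenizer("the cat sat"), "labels": 1, "participant_id": "p1"},
    {**tokenizer("the dog ran away"), "labels": 0, "participant_id": "p2"},
]

loader = DataLoader(features, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # padded 2-row tensor, e.g. torch.Size([2, 6])
print(batch["participant_id"])   # ['p1', 'p2'], passed through as a list of strings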
bead/active_learning/trainers/dataset_utils.py
@@ -0,0 +1,261 @@
"""Utilities for converting items to HuggingFace datasets.

This module provides functions to convert bead Items to HuggingFace Dataset
format for use with HuggingFace Trainer.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from datasets import Dataset

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

    from bead.items.item import Item


def items_to_dataset(
    items: list[Item],
    labels: list[str | int | float],
    participant_ids: list[str] | None,
    tokenizer: PreTrainedTokenizerBase,
    max_length: int = 128,
    text_key: str = "text",
) -> Dataset:
    """Convert items and labels to HuggingFace Dataset.

    Parameters
    ----------
    items : list[Item]
        Items to convert.
    labels : list[str | int | float]
        Labels for items.
    participant_ids : list[str] | None
        Participant IDs for each item (required for mixed effects).
    tokenizer : PreTrainedTokenizerBase
        HuggingFace tokenizer.
    max_length : int
        Maximum sequence length for tokenization.
    text_key : str
        Key in rendered_elements to use as text (default: "text").

    Returns
    -------
    Dataset
        HuggingFace Dataset with tokenized inputs, labels, and participant_ids.

    Examples
    --------
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> dataset = items_to_dataset(
    ...     items=items,
    ...     labels=['yes', 'no', 'yes'],
    ...     participant_ids=['p1', 'p1', 'p2'],
    ...     tokenizer=tokenizer
    ... )
    >>> len(dataset)
    3
    """
    # Extract texts from items
    texts: list[str] = []
    for item in items:
        # Try to get text from rendered_elements
        if text_key in item.rendered_elements:
            text = item.rendered_elements[text_key]
        else:
            # Fallback: concatenate all rendered elements
            text = " ".join(str(v) for v in item.rendered_elements.values())
        texts.append(text)

    # Tokenize texts
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors=None,  # Return lists, not tensors
    )

    # Build dataset dict
    dataset_dict: dict[str, list[str | int | float]] = {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
    }

    # Add token_type_ids if present
    if "token_type_ids" in tokenized:
        dataset_dict["token_type_ids"] = tokenized["token_type_ids"]

    # Add labels
    dataset_dict["labels"] = labels

    # Add participant IDs if provided
    if participant_ids is not None:
        dataset_dict["participant_id"] = participant_ids

    return Dataset.from_dict(dataset_dict)


def cloze_items_to_dataset(
    items: list[Item],
    labels: list[list[str]],
    participant_ids: list[str] | None,
    tokenizer: PreTrainedTokenizerBase,
    max_length: int = 128,
    text_key: str = "text",
) -> Dataset:
    """Convert cloze items and labels to HuggingFace Dataset with masking.

    For cloze tasks, this function:
    1. Extracts text from items
    2. Tokenizes and identifies masked positions (from "___" placeholders)
    3. Replaces "___" with [MASK] tokens
    4. Stores masked positions and target token IDs for loss computation

    Parameters
    ----------
    items : list[Item]
        Items with unfilled_slots (cloze items).
    labels : list[list[str]]
        Labels as list of lists. Each inner list contains one token per unfilled slot.
    participant_ids : list[str] | None
        Participant IDs for each item.
    tokenizer : PreTrainedTokenizerBase
        HuggingFace tokenizer.
    max_length : int
        Maximum sequence length for tokenization.
    text_key : str
        Key in rendered_elements to use as text (default: "text").

    Returns
    -------
    Dataset
        HuggingFace Dataset with:
        - input_ids: Tokenized text with [MASK] tokens
        - attention_mask: Attention mask
        - masked_positions: List of masked token positions per item
        - target_token_ids: List of target token IDs per masked position
        - participant_id: Participant IDs

    Examples
    --------
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> items = [Item(..., rendered_elements={"text": "The cat ___."}, ...)]
    >>> labels = [["ran"]]
    >>> dataset = cloze_items_to_dataset(
    ...     items=items,
    ...     labels=[["ran"]],
    ...     participant_ids=['p1'],
    ...     tokenizer=tokenizer
    ... )
    >>> len(dataset)
    1
    """
    mask_token_id = tokenizer.mask_token_id
    texts: list[str] = []
    all_masked_positions: list[list[int]] = []
    all_target_token_ids: list[list[int]] = []

    for item, label_list in zip(items, labels, strict=True):
        # Get text
        if text_key in item.rendered_elements:
            text = item.rendered_elements[text_key]
        else:
            text = " ".join(str(v) for v in item.rendered_elements.values())
        texts.append(text)

        # Tokenize to find "___" positions
        # First tokenize the full text to get the actual token IDs
        full_tokenized = tokenizer(
            text, add_special_tokens=True, return_offsets_mapping=False
        )
        full_tokens = tokenizer.convert_ids_to_tokens(full_tokenized["input_ids"])

        # Now find "___" positions in the tokenized sequence
        masked_indices: list[int] = []
        target_ids: list[int] = []

        # Track which tokens are part of "___" to avoid duplicates
        in_blank = False
        label_idx = 0

        # Skip [CLS] token (index 0)
        for j in range(1, len(full_tokens)):
            token = full_tokens[j]
            # Check if this token is part of a "___" placeholder
            if "_" in token and not in_blank:
                # Start of a new blank - record this position
                masked_indices.append(j)
                in_blank = True

                # Get target token ID for this label
                if label_idx < len(label_list):
                    target_token = label_list[label_idx]
                    # Tokenize the target token
                    target_tokenized = tokenizer.encode(
                        target_token, add_special_tokens=False
                    )
                    if target_tokenized:
                        target_ids.append(target_tokenized[0])
                    else:
                        # Fallback: use UNK token
                        target_ids.append(tokenizer.unk_token_id)
                    label_idx += 1
            elif "_" in token and in_blank:
                # Continuation of current blank - also mask but don't record again
                pass
            else:
                # Not a blank token - reset in_blank
                in_blank = False

        # Verify we found the expected number of masked positions
        expected_slots = len(item.unfilled_slots)
        if len(masked_indices) != expected_slots:
            raise ValueError(
                f"Mismatch between masked positions and unfilled_slots "
                f"for item: found {len(masked_indices)} '___' "
                f"placeholders in text but item has {expected_slots} "
                f"unfilled_slots. Ensure rendered text uses exactly one "
                f"'___' per unfilled_slot. Text: '{text}'"
            )

        all_masked_positions.append(masked_indices)
        all_target_token_ids.append(target_ids)

    # Tokenize all texts (this will include "___" which we'll replace)
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors=None,  # Return lists, not tensors
    )

    # Replace "___" tokens with [MASK] in input_ids
    input_ids = tokenized["input_ids"]
    for i, masked_pos in enumerate(all_masked_positions):
        for pos in masked_pos:
            if pos < len(input_ids[i]):
                input_ids[i][pos] = mask_token_id

    # Build dataset dict
    dataset_dict: dict[str, list[str | int | float | list[int]]] = {
        "input_ids": input_ids,
        "attention_mask": tokenized["attention_mask"],
        "masked_positions": all_masked_positions,
        "target_token_ids": all_target_token_ids,
    }

    # Add token_type_ids if present
    if "token_type_ids" in tokenized:
        dataset_dict["token_type_ids"] = tokenized["token_type_ids"]

    # Add participant IDs if provided
    if participant_ids is not None:
        dataset_dict["participant_id"] = participant_ids

    return Dataset.from_dict(dataset_dict)
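
Note (not part of the package): to make the blank-detection step in cloze_items_to_dataset concrete, the standalone sketch below shows how a "___" placeholder surfaces in BERT wordpieces and why a run of underscore tokens is collapsed into a single masked position. It does not construct a bead Item; the checkpoint and sentence are assumptions for illustration.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "The cat ___ over the fence."
encoded = tokenizer(text, add_special_tokens=True)
tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"])
# bert-base-uncased treats '_' as punctuation, so "___" typically splits into
# three consecutive '_' tokens:
# ['[CLS]', 'the', 'cat', '_', '_', '_', 'over', 'the', 'fence', '.', '[SEP]']

# Collapse each run of underscore tokens into one masked position, mirroring
# the loop in cloze_items_to_dataset: only the first token of a blank is recorded.
masked_positions: list[int] = []
in_blank = False
for j, tok in enumerate(tokens[1:], start=1):  # skip [CLS] at index 0
    if "_" in tok and not in_blank:
        masked_positions.append(j)
        in_blank = True
    elif "_" not in tok:
        in_blank = False

print(masked_positions)  # [3] for the tokenization shown above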