bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/data_collator.py
@@ -0,0 +1,172 @@
+"""Data collator for mixed effects training.
+
+This module provides a custom data collator that handles participant_ids
+along with standard tokenization and padding.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from transformers import DataCollatorWithPadding
+
+if TYPE_CHECKING:
+    import torch
+
+
+class MixedEffectsDataCollator(DataCollatorWithPadding):
+    """Data collator that preserves participant_ids for mixed effects.
+
+    Extends DataCollatorWithPadding to handle participant_ids as strings
+    (not tensors) and pass them through to the training batch.
+
+    Parameters
+    ----------
+    tokenizer : PreTrainedTokenizerBase
+        HuggingFace tokenizer.
+    padding : bool | str
+        Padding strategy (default: True).
+    max_length : int | None
+        Maximum sequence length (optional).
+    pad_to_multiple_of : int | None
+        Pad to multiple of this value (optional).
+
+    Examples
+    --------
+    >>> from transformers import AutoTokenizer
+    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    >>> collator = MixedEffectsDataCollator(tokenizer)
+    >>> batch = collator([{'input_ids': [1, 2, 3], 'participant_id': 'alice'}])
+    >>> 'participant_id' in batch
+    True
+    """
+
+    def __call__(
+        self, features: list[dict[str, torch.Tensor | str | int | float]]
+    ) -> dict[str, torch.Tensor | list[str]]:
+        """Collate batch with participant_ids preserved.
+
+        Parameters
+        ----------
+        features : list[dict[str, torch.Tensor | str | int | float]]
+            List of feature dictionaries from dataset.
+
+        Returns
+        -------
+        dict[str, torch.Tensor | list[str]]
+            Collated batch with participant_ids as list[str].
+        """
+        # Extract participant_ids before padding
+        participant_ids: list[str] = []
+        for feat in features:
+            pid = feat.get("participant_id", "_fixed_")
+            participant_ids.append(str(pid))
+
+        # Remove participant_id from features for standard collation
+        features_for_collation = [
+            {k: v for k, v in feat.items() if k != "participant_id"}
+            for feat in features
+        ]
+
+        # Use parent collator for tokenization/padding
+        batch = super().__call__(features_for_collation)
+
+        # Add participant_ids back as list (not tensor)
+        batch["participant_id"] = participant_ids
+
+        return batch
+
+
+class ClozeDataCollator(MixedEffectsDataCollator):
+    """Data collator for cloze (MLM) tasks with custom masking.
+
+    Extends MixedEffectsDataCollator to handle:
+    - masked_positions: List of masked token positions per item
+    - target_token_ids: List of target token IDs per masked position
+    - Preserves these for loss computation in the trainer
+
+    Parameters
+    ----------
+    tokenizer : PreTrainedTokenizerBase
+        HuggingFace tokenizer.
+    padding : bool | str
+        Padding strategy (default: True).
+    max_length : int | None
+        Maximum sequence length (optional).
+    pad_to_multiple_of : int | None
+        Pad to multiple of this value (optional).
+
+    Examples
+    --------
+    >>> from transformers import AutoTokenizer
+    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    >>> collator = ClozeDataCollator(tokenizer)
+    >>> batch = collator([{
+    ...     'input_ids': [1, 2, 103, 4], # 103 is [MASK]
+    ...     'masked_positions': [2],
+    ...     'target_token_ids': [1234],
+    ...     'participant_id': 'alice'
+    ... }])
+    >>> 'masked_positions' in batch
+    True
+    """
+
+    def __call__(
+        self, features: list[dict[str, torch.Tensor | str | int | float | list[int]]]
+    ) -> dict[str, torch.Tensor | list[str] | list[list[int]]]:
+        """Collate batch with masked positions and target token IDs preserved.
+
+        Parameters
+        ----------
+        features : list[dict[str, torch.Tensor | str | int | float | list[int]]]
+            List of feature dictionaries from dataset.
+
+        Returns
+        -------
+        dict[str, torch.Tensor | list[str] | list[list[int]]]
+            Collated batch with:
+            - Standard tokenized inputs (input_ids, attention_mask, etc.)
+            - participant_ids as list[str]
+            - masked_positions as list[list[int]]
+            - target_token_ids as list[list[int]]
+        """
+        # Extract cloze-specific fields before padding
+        participant_ids: list[str] = []
+        masked_positions: list[list[int]] = []
+        target_token_ids: list[list[int]] = []
+
+        for feat in features:
+            pid = feat.get("participant_id", "_fixed_")
+            participant_ids.append(str(pid))
+
+            masked_pos = feat.get("masked_positions", [])
+            if isinstance(masked_pos, list):
+                masked_positions.append(masked_pos)
+            else:
+                masked_positions.append([])
+
+            target_ids = feat.get("target_token_ids", [])
+            if isinstance(target_ids, list):
+                target_token_ids.append(target_ids)
+            else:
+                target_token_ids.append([])
+
+        # Remove cloze-specific fields for standard collation
+        features_for_collation = [
+            {
+                k: v
+                for k, v in feat.items()
+                if k not in ("participant_id", "masked_positions", "target_token_ids")
+            }
+            for feat in features
+        ]
+
+        # Use parent collator for tokenization/padding
+        batch = super().__call__(features_for_collation)
+
+        # Add cloze-specific fields back
+        batch["participant_id"] = participant_ids
+        batch["masked_positions"] = masked_positions
+        batch["target_token_ids"] = target_token_ids
+
+        return batch
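
Because these collators are plain callables, they can be handed to a standard PyTorch DataLoader as collate_fn. The sketch below is illustrative only and not part of the package: the checkpoint name, the toy pre-tokenized features, and the import path (inferred from the wheel's file layout) are assumptions.

# Minimal usage sketch (assumptions noted above), not the package's own code.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bead.active_learning.trainers.data_collator import MixedEffectsDataCollator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = MixedEffectsDataCollator(tokenizer=tokenizer)

# Toy pre-tokenized features; participant_id stays a string through collation.
features = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "participant_id": "p1"},
    {"input_ids": [101, 7592, 2088, 102], "attention_mask": [1, 1, 1, 1], "participant_id": "p2"},
]

loader = DataLoader(features, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
# input_ids/attention_mask come back as padded tensors; participant_id is
# passed through untouched as a list of strings.
print(batch["input_ids"].shape, batch["participant_id"])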
bead/active_learning/trainers/dataset_utils.py
@@ -0,0 +1,261 @@
+"""Utilities for converting items to HuggingFace datasets.
+
+This module provides functions to convert bead Items to HuggingFace Dataset
+format for use with HuggingFace Trainer.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from datasets import Dataset
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+    from bead.items.item import Item
+
+
+def items_to_dataset(
+    items: list[Item],
+    labels: list[str | int | float],
+    participant_ids: list[str] | None,
+    tokenizer: PreTrainedTokenizerBase,
+    max_length: int = 128,
+    text_key: str = "text",
+) -> Dataset:
+    """Convert items and labels to HuggingFace Dataset.
+
+    Parameters
+    ----------
+    items : list[Item]
+        Items to convert.
+    labels : list[str | int | float]
+        Labels for items.
+    participant_ids : list[str] | None
+        Participant IDs for each item (required for mixed effects).
+    tokenizer : PreTrainedTokenizerBase
+        HuggingFace tokenizer.
+    max_length : int
+        Maximum sequence length for tokenization.
+    text_key : str
+        Key in rendered_elements to use as text (default: "text").
+
+    Returns
+    -------
+    Dataset
+        HuggingFace Dataset with tokenized inputs, labels, and participant_ids.
+
+    Examples
+    --------
+    >>> from transformers import AutoTokenizer
+    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    >>> dataset = items_to_dataset(
+    ...     items=items,
+    ...     labels=['yes', 'no', 'yes'],
+    ...     participant_ids=['p1', 'p1', 'p2'],
+    ...     tokenizer=tokenizer
+    ... )
+    >>> len(dataset)
+    3
+    """
+    # Extract texts from items
+    texts: list[str] = []
+    for item in items:
+        # Try to get text from rendered_elements
+        if text_key in item.rendered_elements:
+            text = item.rendered_elements[text_key]
+        else:
+            # Fallback: concatenate all rendered elements
+            text = " ".join(str(v) for v in item.rendered_elements.values())
+        texts.append(text)
+
+    # Tokenize texts
+    tokenized = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        max_length=max_length,
+        return_tensors=None,  # Return lists, not tensors
+    )
+
+    # Build dataset dict
+    dataset_dict: dict[str, list[str | int | float]] = {
+        "input_ids": tokenized["input_ids"],
+        "attention_mask": tokenized["attention_mask"],
+    }
+
+    # Add token_type_ids if present
+    if "token_type_ids" in tokenized:
+        dataset_dict["token_type_ids"] = tokenized["token_type_ids"]
+
+    # Add labels
+    dataset_dict["labels"] = labels
+
+    # Add participant IDs if provided
+    if participant_ids is not None:
+        dataset_dict["participant_id"] = participant_ids
+
+    return Dataset.from_dict(dataset_dict)
+
+
+def cloze_items_to_dataset(
+    items: list[Item],
+    labels: list[list[str]],
+    participant_ids: list[str] | None,
+    tokenizer: PreTrainedTokenizerBase,
+    max_length: int = 128,
+    text_key: str = "text",
+) -> Dataset:
+    """Convert cloze items and labels to HuggingFace Dataset with masking.
+
+    For cloze tasks, this function:
+    1. Extracts text from items
+    2. Tokenizes and identifies masked positions (from "___" placeholders)
+    3. Replaces "___" with [MASK] tokens
+    4. Stores masked positions and target token IDs for loss computation
+
+    Parameters
+    ----------
+    items : list[Item]
+        Items with unfilled_slots (cloze items).
+    labels : list[list[str]]
+        Labels as list of lists. Each inner list contains one token per unfilled slot.
+    participant_ids : list[str] | None
+        Participant IDs for each item.
+    tokenizer : PreTrainedTokenizerBase
+        HuggingFace tokenizer.
+    max_length : int
+        Maximum sequence length for tokenization.
+    text_key : str
+        Key in rendered_elements to use as text (default: "text").
+
+    Returns
+    -------
+    Dataset
+        HuggingFace Dataset with:
+        - input_ids: Tokenized text with [MASK] tokens
+        - attention_mask: Attention mask
+        - masked_positions: List of masked token positions per item
+        - target_token_ids: List of target token IDs per masked position
+        - participant_id: Participant IDs
+
+    Examples
+    --------
+    >>> from transformers import AutoTokenizer
+    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+    >>> items = [Item(..., rendered_elements={"text": "The cat ___."}, ...)]
+    >>> labels = [["ran"]]
+    >>> dataset = cloze_items_to_dataset(
+    ...     items=items,
+    ...     labels=[["ran"]],
+    ...     participant_ids=['p1'],
+    ...     tokenizer=tokenizer
+    ... )
+    >>> len(dataset)
+    1
+    """
+    mask_token_id = tokenizer.mask_token_id
+    texts: list[str] = []
+    all_masked_positions: list[list[int]] = []
+    all_target_token_ids: list[list[int]] = []
+
+    for item, label_list in zip(items, labels, strict=True):
+        # Get text
+        if text_key in item.rendered_elements:
+            text = item.rendered_elements[text_key]
+        else:
+            text = " ".join(str(v) for v in item.rendered_elements.values())
+        texts.append(text)
+
+        # Tokenize to find "___" positions
+        # First tokenize the full text to get the actual token IDs
+        full_tokenized = tokenizer(
+            text, add_special_tokens=True, return_offsets_mapping=False
+        )
+        full_tokens = tokenizer.convert_ids_to_tokens(full_tokenized["input_ids"])
+
+        # Now find "___" positions in the tokenized sequence
+        masked_indices: list[int] = []
+        target_ids: list[int] = []
+
+        # Track which tokens are part of "___" to avoid duplicates
+        in_blank = False
+        label_idx = 0
+
+        # Skip [CLS] token (index 0)
+        for j in range(1, len(full_tokens)):
+            token = full_tokens[j]
+            # Check if this token is part of a "___" placeholder
+            if "_" in token and not in_blank:
+                # Start of a new blank - record this position
+                masked_indices.append(j)
+                in_blank = True
+
+                # Get target token ID for this label
+                if label_idx < len(label_list):
+                    target_token = label_list[label_idx]
+                    # Tokenize the target token
+                    target_tokenized = tokenizer.encode(
+                        target_token, add_special_tokens=False
+                    )
+                    if target_tokenized:
+                        target_ids.append(target_tokenized[0])
+                    else:
+                        # Fallback: use UNK token
+                        target_ids.append(tokenizer.unk_token_id)
+                    label_idx += 1
+            elif "_" in token and in_blank:
+                # Continuation of current blank - also mask but don't record again
+                pass
+            else:
+                # Not a blank token - reset in_blank
+                in_blank = False
+
+        # Verify we found the expected number of masked positions
+        expected_slots = len(item.unfilled_slots)
+        if len(masked_indices) != expected_slots:
+            raise ValueError(
+                f"Mismatch between masked positions and unfilled_slots "
+                f"for item: found {len(masked_indices)} '___' "
+                f"placeholders in text but item has {expected_slots} "
+                f"unfilled_slots. Ensure rendered text uses exactly one "
+                f"'___' per unfilled_slot. Text: '{text}'"
+            )
+
+        all_masked_positions.append(masked_indices)
+        all_target_token_ids.append(target_ids)
+
+    # Tokenize all texts (this will include "___" which we'll replace)
+    tokenized = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        max_length=max_length,
+        return_tensors=None,  # Return lists, not tensors
+    )
+
+    # Replace "___" tokens with [MASK] in input_ids
+    input_ids = tokenized["input_ids"]
+    for i, masked_pos in enumerate(all_masked_positions):
+        for pos in masked_pos:
+            if pos < len(input_ids[i]):
+                input_ids[i][pos] = mask_token_id
+
+    # Build dataset dict
+    dataset_dict: dict[str, list[str | int | float | list[int]]] = {
+        "input_ids": input_ids,
+        "attention_mask": tokenized["attention_mask"],
+        "masked_positions": all_masked_positions,
+        "target_token_ids": all_target_token_ids,
+    }
+
+    # Add token_type_ids if present
+    if "token_type_ids" in tokenized:
+        dataset_dict["token_type_ids"] = tokenized["token_type_ids"]
+
+    # Add participant IDs if provided
+    if participant_ids is not None:
+        dataset_dict["participant_id"] = participant_ids
+
+    return Dataset.from_dict(dataset_dict)
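
cloze_items_to_dataset stores masked_positions and target_token_ids as plain lists so that a trainer can score only the blanked positions. As a rough illustration of how such fields could be consumed downstream, the sketch below gathers masked-LM logits at each masked index and computes a cross-entropy loss against the target IDs. This is not the package's trainer code; the function name masked_token_loss is hypothetical.

# Illustrative sketch only (hypothetical helper, not part of the package).
import torch
import torch.nn.functional as F


def masked_token_loss(
    logits: torch.Tensor,                # (batch, seq_len, vocab_size) from an MLM head
    masked_positions: list[list[int]],   # per-item masked token indices
    target_token_ids: list[list[int]],   # per-item gold token IDs, aligned with positions
) -> torch.Tensor:
    losses = []
    for i, (positions, targets) in enumerate(zip(masked_positions, target_token_ids)):
        if not positions:
            continue
        pos = torch.tensor(positions, dtype=torch.long, device=logits.device)
        tgt = torch.tensor(targets, dtype=torch.long, device=logits.device)
        # Logits at the masked slots of item i: (n_blanks, vocab_size)
        selected = logits[i].index_select(0, pos)
        losses.append(F.cross_entropy(selected, tgt))
    if not losses:
        # No masked positions in the batch: return a zero scalar
        return logits.new_zeros(())
    return torch.stack(losses).mean()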