bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/cli/models.py ADDED
@@ -0,0 +1,877 @@
1
+ """Model training commands for bead CLI.
2
+
3
+ This module provides commands for training GLMM models across all 8 task types
4
+ with support for fixed effects, random intercepts, and random slopes.
5
+ """
6
+
7
from __future__ import annotations

import importlib
import json
from pathlib import Path
from typing import Literal, cast

import click
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from bead.active_learning.config import MixedEffectsConfig
from bead.cli.display import (
    print_error,
    print_info,
    print_success,
)
from bead.data.serialization import read_jsonlines
from bead.items.item import Item
26
+
27
# Shared Rich console used by this module's commands for progress spinners,
# section rules, and result tables.
console = Console()
28
+
29
# Single ordered source of truth: task type -> CamelCase class stem.
# The model class lives at bead.active_learning.models.<task>.<Stem>Model and
# its config class at bead.config.active_learning.<Stem>ModelConfig. Insertion
# order is preserved so CLI choices and error messages list tasks in this order.
_TASK_CLASS_STEMS = (
    ("forced_choice", "ForcedChoice"),
    ("categorical", "Categorical"),
    ("binary", "Binary"),
    ("multi_select", "MultiSelect"),
    ("ordinal_scale", "OrdinalScale"),
    ("magnitude", "Magnitude"),
    ("free_text", "FreeText"),
    ("cloze", "Cloze"),
)

# Task type to fully qualified model class path.
TASK_TYPE_MODELS = {
    task: f"bead.active_learning.models.{task}.{stem}Model"
    for task, stem in _TASK_CLASS_STEMS
}

# Task type to fully qualified model config class path.
TASK_TYPE_CONFIGS = {
    task: f"bead.config.active_learning.{stem}ModelConfig"
    for task, stem in _TASK_CLASS_STEMS
}
52
+
53
+
54
+ def _import_class(module_path: str) -> type:
55
+ """Dynamically import a class from module path.
56
+
57
+ Parameters
58
+ ----------
59
+ module_path : str
60
+ Fully qualified path to class (e.g., 'bead.models.forced_choice.Model').
61
+
62
+ Returns
63
+ -------
64
+ type
65
+ Imported class.
66
+ """
67
+ module_name, class_name = module_path.rsplit(".", 1)
68
+ module = __import__(module_name, fromlist=[class_name])
69
+ return getattr(module, class_name)
70
+
71
+
72
# NOTE: this docstring is deliberately NOT a raw string. Click only honours a
# no-rewrap marker when it is the real backspace character ("\b" in a normal
# string literal), and "\\" must render as a single shell line-continuation
# backslash in the printed help.
@click.group()
def models() -> None:
    """Model training commands.

    Commands for training GLMM models for judgment prediction across all 8
    task types with support for mixed effects modeling.

    \b
    Task Types:
      • forced_choice - 2AFC, 3AFC, N-way forced choice
      • categorical - Unordered categories (NLI, semantic relations)
      • binary - Yes/No, True/False
      • multi_select - Multiple selection (checkboxes)
      • ordinal_scale - Likert scales, sliders
      • magnitude - Unbounded numeric (reading time, confidence)
      • free_text - Open-ended text responses
      • cloze - Fill-in-the-blank

    \b
    Mixed Effects Modes:
      • fixed - Fixed effects only (no participant variability)
      • random_intercepts - Participant-specific biases
      • random_slopes - Participant-specific model parameters

    \b
    Examples:
        # Train forced choice model with fixed effects
        $ bead models train-model \\
            --task-type forced_choice \\
            --items items.jsonl \\
            --labels labels.jsonl \\
            --output-dir models/fc_model/

    \b
        # Train with random intercepts
        $ bead models train-model \\
            --task-type ordinal_scale \\
            --items items.jsonl \\
            --labels labels.jsonl \\
            --participant-ids participant_ids.txt \\
            --mixed-effects-mode random_intercepts \\
            --output-dir models/os_model/

    \b
        # Make predictions
        $ bead models predict \\
            --model-dir models/fc_model/ \\
            --items test_items.jsonl \\
            --output predictions.jsonl
    """
120
+
121
+
122
# NOTE: the docstring below is deliberately not a raw string: "\b" must be the
# backspace character for Click's no-rewrap marker, "\f" (form feed) makes
# Click truncate --help output before the numpy-style sections, and "\\"
# renders as a single shell line-continuation backslash.
@click.command()
@click.option(
    "--task-type",
    required=True,
    type=click.Choice(list(TASK_TYPE_MODELS.keys())),
    help="Task type for model",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--labels",
    "labels_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to labels JSONL file (list of response strings)",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (one ID per line, aligned with labels)",
)
@click.option(
    "--validation-items",
    type=click.Path(exists=True, path_type=Path),
    help="Path to validation items JSONL file (optional)",
)
@click.option(
    "--validation-labels",
    type=click.Path(exists=True, path_type=Path),
    help="Path to validation labels JSONL file (optional)",
)
@click.option(
    "--output-dir",
    required=True,
    type=click.Path(path_type=Path),
    help="Output directory for trained model",
)
@click.option(
    "--model-name",
    default="bert-base-uncased",
    help="HuggingFace model name",
)
@click.option(
    "--mixed-effects-mode",
    type=click.Choice(["fixed", "random_intercepts", "random_slopes"]),
    default="fixed",
    help="Mixed effects mode",
)
@click.option(
    "--max-length",
    type=int,
    default=128,
    help="Maximum sequence length for tokenization",
)
@click.option(
    "--learning-rate",
    type=float,
    default=2e-5,
    help="Learning rate for AdamW optimizer",
)
@click.option(
    "--batch-size",
    type=int,
    default=16,
    help="Batch size for training",
)
@click.option(
    "--num-epochs",
    type=int,
    default=3,
    help="Number of training epochs",
)
@click.option(
    "--device",
    type=click.Choice(["cpu", "cuda", "mps"]),
    default="cpu",
    help="Device to train on",
)
@click.option(
    "--use-lora",
    is_flag=True,
    help="Use LoRA parameter-efficient fine-tuning",
)
@click.option(
    "--lora-rank",
    type=int,
    default=8,
    help="LoRA rank (r)",
)
@click.option(
    "--lora-alpha",
    type=int,
    default=16,
    help="LoRA alpha scaling parameter",
)
@click.pass_context
def train_model(
    ctx: click.Context,
    task_type: str,
    items_file: Path,
    labels_file: Path,
    participant_ids_file: Path | None,
    validation_items: Path | None,
    validation_labels: Path | None,
    output_dir: Path,
    model_name: str,
    mixed_effects_mode: str,
    max_length: int,
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    device: str,
    use_lora: bool,
    lora_rank: int,
    lora_alpha: int,
) -> None:
    """Train GLMM model for judgment prediction.

    Trains a generalized linear mixed model (GLMM) with support for:

    \b
    - Fixed effects (population-level parameters)
    - Random intercepts (participant-specific biases)
    - Random slopes (participant-specific model parameters)

    The model uses a transformer encoder (default: BERT) with optional
    LoRA parameter-efficient fine-tuning. Validation items and labels are
    optional but must be supplied together. The output directory receives
    model.pt, config.json (including the task type) and
    training_metrics.json.

    \b
    Examples:
        $ bead models train-model \\
            --task-type forced_choice \\
            --items items.jsonl \\
            --labels labels.jsonl \\
            --output-dir models/fc_model/ \\
            --num-epochs 5

    \b
        $ bead models train-model \\
            --task-type ordinal_scale \\
            --items items.jsonl \\
            --labels labels.jsonl \\
            --participant-ids participant_ids.txt \\
            --mixed-effects-mode random_intercepts \\
            --output-dir models/os_model/ \\
            --device cuda \\
            --use-lora \\
            --lora-rank 8

    \f
    Parameters
    ----------
    ctx : click.Context
        Click context object.
    task_type : str
        Task type (forced_choice, categorical, binary, etc.).
    items_file : Path
        Path to items JSONL file.
    labels_file : Path
        Path to labels file; each non-blank line is one label (stripped).
    participant_ids_file : Path | None
        Path to participant IDs file (required for random effects).
    validation_items : Path | None
        Path to validation items JSONL file (optional; requires
        ``validation_labels``).
    validation_labels : Path | None
        Path to validation labels file (optional; requires
        ``validation_items``).
    output_dir : Path
        Output directory for trained model (created if missing).
    model_name : str
        HuggingFace model name.
    mixed_effects_mode : str
        Mixed effects mode (fixed, random_intercepts, random_slopes).
    max_length : int
        Maximum sequence length for tokenization.
    learning_rate : float
        Learning rate for AdamW optimizer.
    batch_size : int
        Batch size for training.
    num_epochs : int
        Number of training epochs.
    device : str
        Device to train on (cpu, cuda, mps).
    use_lora : bool
        Whether to use LoRA fine-tuning.
    lora_rank : int
        LoRA rank.
    lora_alpha : int
        LoRA alpha scaling parameter.
    """
    try:
        # Random-effects modes cannot be fit without per-label participant IDs.
        if mixed_effects_mode != "fixed" and participant_ids_file is None:
            print_error(
                f"Mixed effects mode '{mixed_effects_mode}' requires "
                "--participant-ids parameter"
            )
            print_info(
                "Provide a file with one participant ID per line, "
                "aligned with the labels file"
            )
            ctx.exit(1)

        # Validation data is only usable as a pair; reject a lone flag rather
        # than silently training without validation.
        if (validation_items is None) != (validation_labels is None):
            print_error(
                "--validation-items and --validation-labels must be "
                "provided together"
            )
            ctx.exit(1)

        print_info(f"Training {task_type} model with {mixed_effects_mode} mode")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Labels: one label per non-blank line, surrounding whitespace stripped.
        with open(labels_file, encoding="utf-8") as f:
            labels = [line.strip() for line in f if line.strip()]

        if len(labels) != len(items):
            print_error(
                f"Number of labels ({len(labels)}) does not match "
                f"number of items ({len(items)})"
            )
            ctx.exit(1)

        print_success(f"Loaded {len(labels)} labels")

        # Load participant IDs if provided
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            unique_participants = len(set(participant_ids))
            print_success(
                f"Loaded {len(participant_ids)} participant IDs "
                f"({unique_participants} unique participants)"
            )

        # Load validation data if provided (both paths are present here, or
        # neither, thanks to the pairing check above).
        val_items = None
        val_labels = None
        if validation_items is not None and validation_labels is not None:
            val_items = read_jsonlines(validation_items, Item)

            with open(validation_labels, encoding="utf-8") as f:
                val_labels = [line.strip() for line in f if line.strip()]

            if len(val_labels) != len(val_items):
                print_error(
                    f"Number of validation labels ({len(val_labels)}) does not "
                    f"match number of validation items ({len(val_items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(val_items)} validation items")

        # Build mixed effects config
        # Cast to proper Literal type since Click validates the value
        mode = cast(
            Literal["fixed", "random_intercepts", "random_slopes"],
            mixed_effects_mode,
        )
        mixed_effects_config = MixedEffectsConfig(mode=mode)

        # Import model class and config dynamically
        model_class = _import_class(TASK_TYPE_MODELS[task_type])
        config_class = _import_class(TASK_TYPE_CONFIGS[task_type])

        # Build model config
        config_dict = {
            "model_name": model_name,
            "max_length": max_length,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "device": device,
            "mixed_effects": mixed_effects_config,
        }

        # Add LoRA config if enabled
        if use_lora:
            config_dict["use_lora"] = True
            config_dict["lora_rank"] = lora_rank
            config_dict["lora_alpha"] = lora_alpha

        model_config = config_class(**config_dict)

        # Initialize model
        console.rule("[bold]Initializing Model[/bold]")
        model = model_class(config=model_config)

        # Train model
        console.rule("[bold]Training Model[/bold]")
        print_info(
            f"Training for {num_epochs} epochs on {device} "
            f"(batch_size={batch_size}, lr={learning_rate})"
        )

        if use_lora:
            print_info(f"Using LoRA fine-tuning (rank={lora_rank}, alpha={lora_alpha})")

        metrics = model.train(
            items=items,
            labels=labels,
            participant_ids=participant_ids,
            validation_items=val_items,
            validation_labels=val_labels,
        )

        # Display training metrics
        console.rule("[bold]Training Results[/bold]")
        table = Table(title="Training Metrics")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        for metric_name, metric_value in metrics.items():
            if isinstance(metric_value, float):
                table.add_row(metric_name, f"{metric_value:.4f}")
            else:
                table.add_row(metric_name, str(metric_value))

        console.print(table)

        # Save model
        console.rule("[bold]Saving Model[/bold]")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model weights
        model_path = output_dir / "model.pt"
        model.save(model_path)
        print_success(f"Saved model weights: {model_path}")

        # Save config with task_type for later inference. mode="json" makes
        # pydantic emit JSON-compatible values (e.g. Paths/enums as strings),
        # so json.dump cannot fail on non-serializable field types.
        config_path = output_dir / "config.json"
        config_with_task_type = model_config.model_dump(mode="json")
        config_with_task_type["task_type"] = task_type  # Add task type to config
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_with_task_type, f, indent=2)
        print_success(f"Saved config: {config_path}")

        # Save training metrics
        metrics_path = output_dir / "training_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print_success(f"Saved training metrics: {metrics_path}")

        console.rule("[bold green]✓ Training Complete[/bold green]")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except (ImportError, AttributeError) as e:
        print_error(f"Failed to import model class: {e}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)
497
+
498
+
499
# NOTE: the docstring below is deliberately not a raw string: "\b" must be the
# backspace character for Click's no-rewrap marker, "\f" (form feed) makes
# Click truncate --help output before the numpy-style sections, and "\\"
# renders as a single shell line-continuation backslash.
@click.command()
@click.option(
    "--model-dir",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to trained model directory",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (required for random effects models)",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for predictions JSONL",
)
@click.pass_context
def predict(
    ctx: click.Context,
    model_dir: Path,
    items_file: Path,
    participant_ids_file: Path | None,
    output_file: Path,
) -> None:
    """Make predictions with trained model.

    Predicts class labels for items using a trained GLMM model directory
    (as written by train-model: config.json plus model.pt). For random
    effects models, participant IDs are required to compute
    participant-specific predictions.

    \b
    Examples:
        $ bead models predict \\
            --model-dir models/fc_model/ \\
            --items test_items.jsonl \\
            --output predictions.jsonl

    \b
        $ bead models predict \\
            --model-dir models/os_model/ \\
            --items test_items.jsonl \\
            --participant-ids participant_ids.txt \\
            --output predictions.jsonl

    \f
    Parameters
    ----------
    ctx : click.Context
        Click context object.
    model_dir : Path
        Path to trained model directory.
    items_file : Path
        Path to items JSONL file.
    participant_ids_file : Path | None
        Path to participant IDs file (one per non-blank line, aligned with
        items; required when the model was trained with random effects).
    output_file : Path
        Output path for predictions JSONL (parent directories are created).
    """
    try:
        print_info(f"Loading model from {model_dir}")

        # Load config
        config_path = model_dir / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found: {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)

        # Get task type from config
        if "task_type" not in config_dict:
            print_error(
                "Model config missing 'task_type' field. "
                "This model may have been trained with an older version of bead."
            )
            print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys()))
            ctx.exit(1)

        task_type = config_dict["task_type"]
        if task_type not in TASK_TYPE_MODELS:
            print_error(
                f"Unknown task type '{task_type}' in model config. "
                f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}"
            )
            ctx.exit(1)

        print_success(f"Detected task type: {task_type}")

        # train-model stores the mixed effects config under "mixed_effects";
        # random-effects models need participant IDs (mirrors train-model).
        mixed_effects = config_dict.get("mixed_effects")
        mixed_effects_mode = (
            mixed_effects.get("mode", "fixed")
            if isinstance(mixed_effects, dict)
            else "fixed"
        )
        if mixed_effects_mode != "fixed" and participant_ids_file is None:
            print_error(
                f"Model was trained with mixed effects mode "
                f"'{mixed_effects_mode}', which requires --participant-ids parameter"
            )
            print_info(
                "Provide a file with one participant ID per line, "
                "aligned with the items file"
            )
            ctx.exit(1)

        # Check for weights before constructing the model, since construction
        # may be expensive (e.g. loading a pretrained encoder).
        model_path = model_dir / "model.pt"
        if not model_path.exists():
            print_error(f"Model weights not found: {model_path}")
            ctx.exit(1)

        # Import model class
        model_class = _import_class(TASK_TYPE_MODELS[task_type])
        config_class = _import_class(TASK_TYPE_CONFIGS[task_type])

        # "task_type" is bookkeeping added by train-model; only forward it if
        # the config class actually declares such a field, so configs that
        # forbid extra fields still load.
        config_kwargs = dict(config_dict)
        if "task_type" not in getattr(config_class, "model_fields", {}):
            config_kwargs.pop("task_type", None)
        model_config = config_class(**config_kwargs)

        # Initialize model and load weights
        model = model_class(config=model_config)
        model.load(model_path)
        print_success(f"Loaded model: {model_path}")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Load participant IDs if provided
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(participant_ids)} participant IDs")

        # Make predictions
        console.rule("[bold]Making Predictions[/bold]")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Predicting...", total=None)
            predictions = model.predict(items=items, participant_ids=participant_ids)

        # Save predictions
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            for pred in predictions:
                f.write(pred.model_dump_json() + "\n")

        print_success(f"Saved {len(predictions)} predictions: {output_file}")

        # Display sample predictions
        console.rule("[bold]Sample Predictions[/bold]")
        table = Table(title="First 5 Predictions")
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Predicted Label", style="green")
        table.add_column("Confidence", style="yellow", justify="right")

        for i, pred in enumerate(predictions[:5]):
            # Missing or None confidence both display as "N/A".
            confidence = getattr(pred, "confidence", None)
            if isinstance(confidence, float):
                confidence_str = f"{confidence:.3f}"
            elif confidence is None:
                confidence_str = "N/A"
            else:
                confidence_str = str(confidence)
            table.add_row(str(i), str(pred.predicted_label), confidence_str)

        console.print(table)

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except (ImportError, AttributeError) as e:
        print_error(f"Failed to import model class: {e}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)
689
+
690
+
691
@click.command()
@click.option(
    "--model-dir",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to trained model directory",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (required for random effects models)",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for probabilities JSON",
)
@click.pass_context
def predict_proba(
    ctx: click.Context,
    model_dir: Path,
    items_file: Path,
    participant_ids_file: Path | None,
    output_file: Path,
) -> None:
    r"""Predict class probabilities with trained model.

    Predicts class probability distributions for items using a trained GLMM
    model. For random effects models, participant IDs are required.

    The model directory must contain ``config.json`` (which must include a
    ``task_type`` key naming one of the registered task types) and the
    weights file ``model.pt``. If a participant IDs file is given, it holds
    one ID per non-blank line and must contain exactly one ID per item.
    The probabilities are written to ``output_file`` as a JSON list with one
    entry per item, and the first five entries are shown in a table.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    model_dir : Path
        Path to trained model directory.
    items_file : Path
        Path to items JSONL file.
    participant_ids_file : Path | None
        Path to participant IDs file (required for random effects).
    output_file : Path
        Output path for probabilities JSON.

    Raises
    ------
    click.exceptions.Exit
        Via ``ctx.exit(1)`` on handled errors: missing config or weights,
        missing/unknown task type, participant ID count mismatch, missing
        files, invalid JSON, invalid configuration/data, or import failures.

    Examples
    --------
    $ bead models predict-proba \\
        --model-dir models/fc_model/ \\
        --items test_items.jsonl \\
        --output probabilities.json
    """

    def _fail_import(exc: Exception) -> None:
        # Shared report for failures to import model/config classes.
        print_error(f"Failed to import model class: {exc}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)

    try:
        print_info(f"Loading model from {model_dir}")

        # Load config
        config_path = model_dir / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found: {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)

        # The task type selects which model/config classes to instantiate.
        if "task_type" not in config_dict:
            print_error(
                "Model config missing 'task_type' field. "
                "This model may have been trained with an older version of bead."
            )
            print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys()))
            ctx.exit(1)

        task_type = config_dict["task_type"]
        if task_type not in TASK_TYPE_MODELS:
            print_error(
                f"Unknown task type '{task_type}' in model config. "
                f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}"
            )
            ctx.exit(1)

        print_success(f"Detected task type: {task_type}")

        # Only the dynamic class imports are guarded for AttributeError, so an
        # unrelated AttributeError later on is not misreported as a broken
        # installation.
        try:
            model_class = _import_class(TASK_TYPE_MODELS[task_type])
            config_class = _import_class(TASK_TYPE_CONFIGS[task_type])
        except (ImportError, AttributeError) as e:
            _fail_import(e)

        model_config = config_class(**config_dict)

        # Check for the weights before constructing the model, so a missing
        # file is reported without first building the model.
        model_path = model_dir / "model.pt"
        if not model_path.exists():
            print_error(f"Model weights not found: {model_path}")
            ctx.exit(1)

        model = model_class(config=model_config)
        model.load(model_path)
        print_success(f"Loaded model: {model_path}")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Load participant IDs if provided: one ID per non-blank line, aligned
        # positionally with the items.
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(participant_ids)} participant IDs")

        # Predict probabilities
        console.rule("[bold]Predicting Probabilities[/bold]")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Predicting...", total=None)
            probabilities = model.predict_proba(
                items=items, participant_ids=participant_ids
            )

        # Convert to JSON-serializable Python lists. Array-likes (e.g. numpy
        # arrays) expose ``tolist()``; otherwise rows are converted one by one.
        if hasattr(probabilities, "tolist"):
            prob_rows = probabilities.tolist()
        else:
            prob_rows = [
                row.tolist() if hasattr(row, "tolist") else row
                for row in probabilities
            ]

        # Save probabilities
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(prob_rows, f, indent=2)

        print_success(
            f"Saved {len(prob_rows)} probability distributions: {output_file}"
        )

        # Display sample probabilities
        console.rule("[bold]Sample Probabilities[/bold]")
        table = Table(title="First 5 Probability Distributions")
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Probabilities", style="green")

        for i, row in enumerate(prob_rows[:5]):
            # A 1-D output yields scalar rows; show them as one-element lists
            # instead of failing to iterate a float.
            values = row if isinstance(row, (list, tuple)) else [row]
            prob_str = ", ".join(f"{p:.3f}" for p in values)
            table.add_row(str(i), f"[{prob_str}]")

        console.print(table)

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except ImportError as e:
        # Lazy imports during model construction/loading surface here.
        _fail_import(e)
872
+
873
+
874
# Attach each model sub-command to the ``models`` group, in this order.
for _command in (train_model, predict, predict_proba):
    models.add_command(_command)
del _command