bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/cli/models.py
ADDED
@@ -0,0 +1,877 @@
"""Model training commands for bead CLI.

This module provides commands for training GLMM models across all 8 task types
with support for fixed effects, random intercepts, and random slopes.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Literal, cast

import click
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from bead.active_learning.config import MixedEffectsConfig
from bead.cli.display import (
    print_error,
    print_info,
    print_success,
)
from bead.data.serialization import read_jsonlines
from bead.items.item import Item

console = Console()

# Task type to model class mapping
TASK_TYPE_MODELS = {
    "forced_choice": "bead.active_learning.models.forced_choice.ForcedChoiceModel",
    "categorical": "bead.active_learning.models.categorical.CategoricalModel",
    "binary": "bead.active_learning.models.binary.BinaryModel",
    "multi_select": "bead.active_learning.models.multi_select.MultiSelectModel",
    "ordinal_scale": "bead.active_learning.models.ordinal_scale.OrdinalScaleModel",
    "magnitude": "bead.active_learning.models.magnitude.MagnitudeModel",
    "free_text": "bead.active_learning.models.free_text.FreeTextModel",
    "cloze": "bead.active_learning.models.cloze.ClozeModel",
}

# Config classes for each task type
TASK_TYPE_CONFIGS = {
    "forced_choice": "bead.config.active_learning.ForcedChoiceModelConfig",
    "categorical": "bead.config.active_learning.CategoricalModelConfig",
    "binary": "bead.config.active_learning.BinaryModelConfig",
    "multi_select": "bead.config.active_learning.MultiSelectModelConfig",
    "ordinal_scale": "bead.config.active_learning.OrdinalScaleModelConfig",
    "magnitude": "bead.config.active_learning.MagnitudeModelConfig",
    "free_text": "bead.config.active_learning.FreeTextModelConfig",
    "cloze": "bead.config.active_learning.ClozeModelConfig",
}


def _import_class(module_path: str) -> type:
    """Dynamically import a class from module path.

    Parameters
    ----------
    module_path : str
        Fully qualified path to class (e.g., 'bead.models.forced_choice.Model').

    Returns
    -------
    type
        Imported class.
    """
    module_name, class_name = module_path.rsplit(".", 1)
    module = __import__(module_name, fromlist=[class_name])
    return getattr(module, class_name)
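# For example (illustrative), _import_class("bead.active_learning.models.binary.BinaryModel")
# splits the dotted path into the module "bead.active_learning.models.binary" and the
# attribute "BinaryModel", imports the module, and returns the class object, so the CLI
# can resolve the entries in TASK_TYPE_MODELS / TASK_TYPE_CONFIGS lazily.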


@click.group()
def models() -> None:
    r"""Model training commands.

    Commands for training GLMM models for judgment prediction across all 8
    task types with support for mixed effects modeling.

    \b
    Task Types:
      • forced_choice - 2AFC, 3AFC, N-way forced choice
      • categorical - Unordered categories (NLI, semantic relations)
      • binary - Yes/No, True/False
      • multi_select - Multiple selection (checkboxes)
      • ordinal_scale - Likert scales, sliders
      • magnitude - Unbounded numeric (reading time, confidence)
      • free_text - Open-ended text responses
      • cloze - Fill-in-the-blank

    \b
    Mixed Effects Modes:
      • fixed - Fixed effects only (no participant variability)
      • random_intercepts - Participant-specific biases
      • random_slopes - Participant-specific model parameters

    \b
    Examples:
      # Train forced choice model with fixed effects
      $ bead models train-model \\
          --task-type forced_choice \\
          --items items.jsonl \\
          --labels labels.jsonl \\
          --output-dir models/fc_model/

      # Train with random intercepts
      $ bead models train-model \\
          --task-type ordinal_scale \\
          --items items.jsonl \\
          --labels labels.jsonl \\
          --participant-ids participant_ids.txt \\
          --mixed-effects-mode random_intercepts \\
          --output-dir models/os_model/

      # Make predictions
      $ bead models predict \\
          --model-dir models/fc_model/ \\
          --items test_items.jsonl \\
          --output predictions.jsonl
    """


@click.command()
@click.option(
    "--task-type",
    required=True,
    type=click.Choice(list(TASK_TYPE_MODELS.keys())),
    help="Task type for model",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--labels",
    "labels_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to labels JSONL file (list of response strings)",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (one ID per line, aligned with labels)",
)
@click.option(
    "--validation-items",
    type=click.Path(exists=True, path_type=Path),
    help="Path to validation items JSONL file (optional)",
)
@click.option(
    "--validation-labels",
    type=click.Path(exists=True, path_type=Path),
    help="Path to validation labels JSONL file (optional)",
)
@click.option(
    "--output-dir",
    required=True,
    type=click.Path(path_type=Path),
    help="Output directory for trained model",
)
@click.option(
    "--model-name",
    default="bert-base-uncased",
    help="HuggingFace model name",
)
@click.option(
    "--mixed-effects-mode",
    type=click.Choice(["fixed", "random_intercepts", "random_slopes"]),
    default="fixed",
    help="Mixed effects mode",
)
@click.option(
    "--max-length",
    type=int,
    default=128,
    help="Maximum sequence length for tokenization",
)
@click.option(
    "--learning-rate",
    type=float,
    default=2e-5,
    help="Learning rate for AdamW optimizer",
)
@click.option(
    "--batch-size",
    type=int,
    default=16,
    help="Batch size for training",
)
@click.option(
    "--num-epochs",
    type=int,
    default=3,
    help="Number of training epochs",
)
@click.option(
    "--device",
    type=click.Choice(["cpu", "cuda", "mps"]),
    default="cpu",
    help="Device to train on",
)
@click.option(
    "--use-lora",
    is_flag=True,
    help="Use LoRA parameter-efficient fine-tuning",
)
@click.option(
    "--lora-rank",
    type=int,
    default=8,
    help="LoRA rank (r)",
)
@click.option(
    "--lora-alpha",
    type=int,
    default=16,
    help="LoRA alpha scaling parameter",
)
@click.pass_context
def train_model(
    ctx: click.Context,
    task_type: str,
    items_file: Path,
    labels_file: Path,
    participant_ids_file: Path | None,
    validation_items: Path | None,
    validation_labels: Path | None,
    output_dir: Path,
    model_name: str,
    mixed_effects_mode: str,
    max_length: int,
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    device: str,
    use_lora: bool,
    lora_rank: int,
    lora_alpha: int,
) -> None:
    r"""Train GLMM model for judgment prediction.

    Trains a generalized linear mixed model (GLMM) with support for:
    - Fixed effects (population-level parameters)
    - Random intercepts (participant-specific biases)
    - Random slopes (participant-specific model parameters)

    The model uses a transformer encoder (default: BERT) with optional
    LoRA parameter-efficient fine-tuning.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    task_type : str
        Task type (forced_choice, categorical, binary, etc.).
    items_file : Path
        Path to items JSONL file.
    labels_file : Path
        Path to labels JSONL file (one label per line).
    participant_ids_file : Path | None
        Path to participant IDs file (required for random effects).
    validation_items : Path | None
        Path to validation items JSONL file (optional).
    validation_labels : Path | None
        Path to validation labels JSONL file (optional).
    output_dir : Path
        Output directory for trained model.
    model_name : str
        HuggingFace model name.
    mixed_effects_mode : str
        Mixed effects mode (fixed, random_intercepts, random_slopes).
    max_length : int
        Maximum sequence length for tokenization.
    learning_rate : float
        Learning rate for AdamW optimizer.
    batch_size : int
        Batch size for training.
    num_epochs : int
        Number of training epochs.
    device : str
        Device to train on (cpu, cuda, mps).
    use_lora : bool
        Whether to use LoRA fine-tuning.
    lora_rank : int
        LoRA rank.
    lora_alpha : int
        LoRA alpha scaling parameter.

    Examples
    --------
    $ bead models train-model \\
        --task-type forced_choice \\
        --items items.jsonl \\
        --labels labels.jsonl \\
        --output-dir models/fc_model/ \\
        --num-epochs 5

    $ bead models train-model \\
        --task-type ordinal_scale \\
        --items items.jsonl \\
        --labels labels.jsonl \\
        --participant-ids participant_ids.txt \\
        --mixed-effects-mode random_intercepts \\
        --output-dir models/os_model/ \\
        --device cuda \\
        --use-lora \\
        --lora-rank 8
    """
    try:
        # Validate mixed effects mode requirements
        if mixed_effects_mode != "fixed" and participant_ids_file is None:
            print_error(
                f"Mixed effects mode '{mixed_effects_mode}' requires "
                "--participant-ids parameter"
            )
            print_info(
                "Provide a file with one participant ID per line, "
                "aligned with the labels file"
            )
            ctx.exit(1)

        print_info(f"Training {task_type} model with {mixed_effects_mode} mode")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Load labels
        with open(labels_file, encoding="utf-8") as f:
            labels = [line.strip() for line in f if line.strip()]

        if len(labels) != len(items):
            print_error(
                f"Number of labels ({len(labels)}) does not match "
                f"number of items ({len(items)})"
            )
            ctx.exit(1)

        print_success(f"Loaded {len(labels)} labels")

        # Load participant IDs if provided
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            unique_participants = len(set(participant_ids))
            print_success(
                f"Loaded {len(participant_ids)} participant IDs "
                f"({unique_participants} unique participants)"
            )

        # Load validation data if provided
        val_items = None
        val_labels = None
        if validation_items and validation_labels:
            val_items = read_jsonlines(validation_items, Item)

            with open(validation_labels, encoding="utf-8") as f:
                val_labels = [line.strip() for line in f if line.strip()]

            if len(val_labels) != len(val_items):
                print_error(
                    f"Number of validation labels ({len(val_labels)}) does not "
                    f"match number of validation items ({len(val_items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(val_items)} validation items")

        # Build mixed effects config
        # Cast to proper Literal type since Click validates the value
        mode = cast(
            Literal["fixed", "random_intercepts", "random_slopes"],
            mixed_effects_mode,
        )
        mixed_effects_config = MixedEffectsConfig(mode=mode)

        # Import model class and config dynamically
        model_class = _import_class(TASK_TYPE_MODELS[task_type])
        config_class = _import_class(TASK_TYPE_CONFIGS[task_type])

        # Build model config
        config_dict = {
            "model_name": model_name,
            "max_length": max_length,
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "device": device,
            "mixed_effects": mixed_effects_config,
        }

        # Add LoRA config if enabled
        if use_lora:
            config_dict["use_lora"] = True
            config_dict["lora_rank"] = lora_rank
            config_dict["lora_alpha"] = lora_alpha

        model_config = config_class(**config_dict)

        # Initialize model
        console.rule("[bold]Initializing Model[/bold]")
        model = model_class(config=model_config)

        # Train model
        console.rule("[bold]Training Model[/bold]")
        print_info(
            f"Training for {num_epochs} epochs on {device} "
            f"(batch_size={batch_size}, lr={learning_rate})"
        )

        if use_lora:
            print_info(f"Using LoRA fine-tuning (rank={lora_rank}, alpha={lora_alpha})")

        metrics = model.train(
            items=items,
            labels=labels,
            participant_ids=participant_ids,
            validation_items=val_items,
            validation_labels=val_labels,
        )

        # Display training metrics
        console.rule("[bold]Training Results[/bold]")
        table = Table(title="Training Metrics")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green", justify="right")

        for metric_name, metric_value in metrics.items():
            if isinstance(metric_value, float):
                table.add_row(metric_name, f"{metric_value:.4f}")
            else:
                table.add_row(metric_name, str(metric_value))

        console.print(table)

        # Save model
        console.rule("[bold]Saving Model[/bold]")
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model weights
        model_path = output_dir / "model.pt"
        model.save(model_path)
        print_success(f"Saved model weights: {model_path}")

        # Save config with task_type for later inference
        config_path = output_dir / "config.json"
        config_with_task_type = model_config.model_dump()
        config_with_task_type["task_type"] = task_type  # Add task type to config
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_with_task_type, f, indent=2)
        print_success(f"Saved config: {config_path}")

        # Save training metrics
        metrics_path = output_dir / "training_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print_success(f"Saved training metrics: {metrics_path}")

        console.rule("[bold green]✓ Training Complete[/bold green]")

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except (ImportError, AttributeError) as e:
        print_error(f"Failed to import model class: {e}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)


@click.command()
@click.option(
    "--model-dir",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to trained model directory",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (required for random effects models)",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for predictions JSONL",
)
@click.pass_context
def predict(
    ctx: click.Context,
    model_dir: Path,
    items_file: Path,
    participant_ids_file: Path | None,
    output_file: Path,
) -> None:
    r"""Make predictions with trained model.

    Predicts class labels for items using a trained GLMM model.
    For random effects models, participant IDs are required to compute
    participant-specific predictions.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    model_dir : Path
        Path to trained model directory.
    items_file : Path
        Path to items JSONL file.
    participant_ids_file : Path | None
        Path to participant IDs file (required for random effects).
    output_file : Path
        Output path for predictions JSONL.

    Examples
    --------
    $ bead models predict \\
        --model-dir models/fc_model/ \\
        --items test_items.jsonl \\
        --output predictions.jsonl

    $ bead models predict \\
        --model-dir models/os_model/ \\
        --items test_items.jsonl \\
        --participant-ids participant_ids.txt \\
        --output predictions.jsonl
    """
    try:
        print_info(f"Loading model from {model_dir}")

        # Load config
        config_path = model_dir / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found: {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)

        # Get task type from config
        if "task_type" not in config_dict:
            print_error(
                "Model config missing 'task_type' field. "
                "This model may have been trained with an older version of bead."
            )
            print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys()))
            ctx.exit(1)

        task_type = config_dict["task_type"]
        if task_type not in TASK_TYPE_MODELS:
            print_error(
                f"Unknown task type '{task_type}' in model config. "
                f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}"
            )
            ctx.exit(1)

        print_success(f"Detected task type: {task_type}")

        # Import model class
        model_class = _import_class(TASK_TYPE_MODELS[task_type])
        config_class = _import_class(TASK_TYPE_CONFIGS[task_type])
        model_config = config_class(**config_dict)

        # Initialize model and load weights
        model = model_class(config=model_config)
        model_path = model_dir / "model.pt"
        if not model_path.exists():
            print_error(f"Model weights not found: {model_path}")
            ctx.exit(1)

        model.load(model_path)
        print_success(f"Loaded model: {model_path}")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Load participant IDs if provided
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(participant_ids)} participant IDs")

        # Make predictions
        console.rule("[bold]Making Predictions[/bold]")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Predicting...", total=None)
            predictions = model.predict(items=items, participant_ids=participant_ids)

        # Save predictions
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            for pred in predictions:
                f.write(pred.model_dump_json() + "\n")

        print_success(f"Saved {len(predictions)} predictions: {output_file}")

        # Display sample predictions
        console.rule("[bold]Sample Predictions[/bold]")
        table = Table(title="First 5 Predictions")
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Predicted Label", style="green")
        table.add_column("Confidence", style="yellow", justify="right")

        for i, pred in enumerate(predictions[:5]):
            confidence = pred.confidence if hasattr(pred, "confidence") else "N/A"
            if isinstance(confidence, float):
                confidence_str = f"{confidence:.3f}"
            else:
                confidence_str = str(confidence)
            table.add_row(str(i), str(pred.predicted_label), confidence_str)

        console.print(table)

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except (ImportError, AttributeError) as e:
        print_error(f"Failed to import model class: {e}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)


@click.command()
@click.option(
    "--model-dir",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to trained model directory",
)
@click.option(
    "--items",
    "items_file",
    required=True,
    type=click.Path(exists=True, path_type=Path),
    help="Path to items JSONL file",
)
@click.option(
    "--participant-ids",
    "participant_ids_file",
    type=click.Path(exists=True, path_type=Path),
    help="Path to participant IDs file (required for random effects models)",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for probabilities JSON",
)
@click.pass_context
def predict_proba(
    ctx: click.Context,
    model_dir: Path,
    items_file: Path,
    participant_ids_file: Path | None,
    output_file: Path,
) -> None:
    r"""Predict class probabilities with trained model.

    Predicts class probability distributions for items using a trained GLMM
    model. For random effects models, participant IDs are required.

    Parameters
    ----------
    ctx : click.Context
        Click context object.
    model_dir : Path
        Path to trained model directory.
    items_file : Path
        Path to items JSONL file.
    participant_ids_file : Path | None
        Path to participant IDs file (required for random effects).
    output_file : Path
        Output path for probabilities JSON.

    Examples
    --------
    $ bead models predict-proba \\
        --model-dir models/fc_model/ \\
        --items test_items.jsonl \\
        --output probabilities.json
    """
    try:
        print_info(f"Loading model from {model_dir}")

        # Load config
        config_path = model_dir / "config.json"
        if not config_path.exists():
            print_error(f"Model config not found: {config_path}")
            ctx.exit(1)

        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)

        # Get task type from config
        if "task_type" not in config_dict:
            print_error(
                "Model config missing 'task_type' field. "
                "This model may have been trained with an older version of bead."
            )
            print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys()))
            ctx.exit(1)

        task_type = config_dict["task_type"]
        if task_type not in TASK_TYPE_MODELS:
            print_error(
                f"Unknown task type '{task_type}' in model config. "
                f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}"
            )
            ctx.exit(1)

        print_success(f"Detected task type: {task_type}")

        # Import model class
        model_class = _import_class(TASK_TYPE_MODELS[task_type])
        config_class = _import_class(TASK_TYPE_CONFIGS[task_type])
        model_config = config_class(**config_dict)

        # Initialize model and load weights
        model = model_class(config=model_config)
        model_path = model_dir / "model.pt"
        if not model_path.exists():
            print_error(f"Model weights not found: {model_path}")
            ctx.exit(1)

        model.load(model_path)
        print_success(f"Loaded model: {model_path}")

        # Load items
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Loading items...", total=None)
            items = read_jsonlines(items_file, Item)

        print_success(f"Loaded {len(items)} items")

        # Load participant IDs if provided
        participant_ids = None
        if participant_ids_file:
            with open(participant_ids_file, encoding="utf-8") as f:
                participant_ids = [line.strip() for line in f if line.strip()]

            if len(participant_ids) != len(items):
                print_error(
                    f"Number of participant IDs ({len(participant_ids)}) does not "
                    f"match number of items ({len(items)})"
                )
                ctx.exit(1)

            print_success(f"Loaded {len(participant_ids)} participant IDs")

        # Predict probabilities
        console.rule("[bold]Predicting Probabilities[/bold]")
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Predicting...", total=None)
            probabilities = model.predict_proba(
                items=items, participant_ids=participant_ids
            )

        # Save probabilities
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(probabilities.tolist(), f, indent=2)

        print_success(
            f"Saved {len(probabilities)} probability distributions: {output_file}"
        )

        # Display sample probabilities
        console.rule("[bold]Sample Probabilities[/bold]")
        table = Table(title="First 5 Probability Distributions")
        table.add_column("Index", style="cyan", justify="right")
        table.add_column("Probabilities", style="green")

        for i, prob in enumerate(probabilities[:5]):
            prob_str = ", ".join([f"{p:.3f}" for p in prob])
            table.add_row(str(i), f"[{prob_str}]")

        console.print(table)

    except FileNotFoundError as e:
        print_error(f"File not found: {e}")
        ctx.exit(1)
    except json.JSONDecodeError as e:
        print_error(f"Invalid JSON in file: {e}")
        ctx.exit(1)
    except ValueError as e:
        print_error(f"Invalid configuration or data: {e}")
        ctx.exit(1)
    except (ImportError, AttributeError) as e:
        print_error(f"Failed to import model class: {e}")
        print_info(
            "This may indicate a corrupted installation. "
            "Try reinstalling bead with: pip install --force-reinstall bead"
        )
        ctx.exit(1)


# Register commands
models.add_command(train_model)
models.add_command(predict)
models.add_command(predict_proba)
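
A quick way to exercise the command wiring above is Click's test runner. The following is a minimal sketch, not part of the wheel, assuming bead 0.1.0 is installed and importable; the subcommand names (`train-model`, `predict`, `predict-proba`) follow the docstrings in the module.

```python
from click.testing import CliRunner

from bead.cli.models import models

runner = CliRunner()

# Show the group help, which lists the registered subcommands.
result = runner.invoke(models, ["--help"])
assert result.exit_code == 0
print(result.output)

# Inspect a single command's options without running any training.
result = runner.invoke(models, ["train-model", "--help"])
assert result.exit_code == 0
```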