bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/huggingface.py
@@ -0,0 +1,304 @@
"""HuggingFace Transformers trainer implementation.

This module provides a trainer that uses the HuggingFace Transformers library
for model training with integrated TensorBoard logging.
"""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import TYPE_CHECKING

from bead.active_learning.trainers.base import BaseTrainer, ModelMetadata
from bead.data.base import BeadBaseModel
from bead.data.timestamps import format_iso8601, now_iso8601

if TYPE_CHECKING:
    from datasets import Dataset
    from transformers import PreTrainedModel, PreTrainedTokenizer


class HuggingFaceTrainer(BaseTrainer):
    """Trainer using HuggingFace Transformers.

    This trainer uses the HuggingFace Transformers library to train models
    for sequence classification and other NLP tasks. It supports TensorBoard
    logging and checkpoint management.

    Parameters
    ----------
    config : dict[str, int | str | float | bool | Path] | BeadBaseModel
        Training configuration with the following expected fields:
        - model_name: str - Base model name/path
        - task_type: str - Task type (classification, regression, etc.)
        - num_labels: int | None - Number of labels for classification
        - output_dir: Path - Directory for outputs
        - num_epochs: int - Number of training epochs
        - batch_size: int - Training batch size
        - learning_rate: float - Learning rate
        - weight_decay: float - Weight decay
        - warmup_steps: int - Warmup steps
        - evaluation_strategy: str - Evaluation strategy (epoch, steps, no)
        - save_strategy: str - Save strategy (epoch, steps, no)
        - load_best_model_at_end: bool - Load best model at end
        - logging_dir: Path | None - Logging directory
        - fp16: bool - Use mixed precision

    Attributes
    ----------
    config : dict[str, int | str | float | bool | Path] | BeadBaseModel
        Training configuration.
    model : PreTrainedModel | None
        The trained model.
    tokenizer : PreTrainedTokenizer | None
        The tokenizer.

    Examples
    --------
    >>> from pathlib import Path
    >>> config = {
    ...     "model_name": "bert-base-uncased",
    ...     "task_type": "classification",
    ...     "num_labels": 2,
    ...     "output_dir": Path("output"),
    ...     "num_epochs": 3,
    ...     "batch_size": 16,
    ...     "learning_rate": 2e-5,
    ...     "weight_decay": 0.01,
    ...     "warmup_steps": 0,
    ...     "evaluation_strategy": "epoch",
    ...     "save_strategy": "epoch",
    ...     "load_best_model_at_end": True,
    ...     "logging_dir": None,
    ...     "fp16": False
    ... }
    >>> trainer = HuggingFaceTrainer(config)
    >>> trainer.model is None
    True
    """

    def __init__(
        self, config: dict[str, int | str | float | bool | Path] | BeadBaseModel
    ) -> None:
        super().__init__(config)
        self.model: PreTrainedModel | None = None
        self.tokenizer: PreTrainedTokenizer | None = None

    def _get_config_value(
        self, key: str, default: int | str | float | bool | Path | None = None
    ) -> int | str | float | bool | Path | None:
        """Get configuration value with fallback to default.

        Parameters
        ----------
        key : str
            Configuration key.
        default : int | str | float | bool | Path | None
            Default value if key not found.

        Returns
        -------
        int | str | float | bool | Path | None
            Configuration value.
        """
        if hasattr(self.config, key):
            return getattr(self.config, key)
        if isinstance(self.config, dict):
            return self.config.get(key, default)
        return default

    def train(
        self, train_data: Dataset, eval_data: Dataset | None = None
    ) -> ModelMetadata:
        """Train model using HuggingFace Trainer.

        Parameters
        ----------
        train_data : Dataset
            HuggingFace Dataset for training.
        eval_data : Dataset | None
            HuggingFace Dataset for evaluation.

        Returns
        -------
        ModelMetadata
            Training metadata.

        Raises
        ------
        ValueError
            If task type is not supported.

        Examples
        --------
        >>> config = {"model_name": "bert-base-uncased"}  # doctest: +SKIP
        >>> trainer = HuggingFaceTrainer(config)  # doctest: +SKIP
        >>> metadata = trainer.train(train_dataset)  # doctest: +SKIP
        >>> metadata.framework  # doctest: +SKIP
        'huggingface'
        """
        from transformers import (  # noqa: PLC0415
            AutoModelForSequenceClassification,
            AutoTokenizer,
            DataCollatorWithPadding,
            Trainer,
            TrainingArguments,
        )

        start_time = time.time()

        # Get config values
        model_name = self._get_config_value("model_name", "bert-base-uncased")
        task_type = self._get_config_value("task_type", "classification")
        num_labels = self._get_config_value("num_labels", 2)
        output_dir = self._get_config_value("output_dir", Path("output"))
        num_epochs = self._get_config_value("num_epochs", 3)
        batch_size = self._get_config_value("batch_size", 16)
        learning_rate = self._get_config_value("learning_rate", 2e-5)
        weight_decay = self._get_config_value("weight_decay", 0.01)
        warmup_steps = self._get_config_value("warmup_steps", 0)
        evaluation_strategy = self._get_config_value("evaluation_strategy", "epoch")
        save_strategy = self._get_config_value("save_strategy", "epoch")
        load_best = self._get_config_value("load_best_model_at_end", True)
        logging_dir = self._get_config_value("logging_dir", None)
        fp16 = self._get_config_value("fp16", False)

        # Load model and tokenizer
        if task_type == "classification":
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels
            )
        else:
            msg = f"Task type not supported: {task_type}"
            raise ValueError(msg)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Create training arguments
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            warmup_steps=warmup_steps,
            eval_strategy=evaluation_strategy,  # type: ignore
            save_strategy=save_strategy,
            load_best_model_at_end=load_best,
            logging_dir=str(logging_dir) if logging_dir else None,
            fp16=fp16,
            report_to=["tensorboard"] if logging_dir else [],
        )

        # Create data collator
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            data_collator=data_collator,
        )

        # Train
        trainer.train()

        # Evaluate
        metrics = {}
        if eval_data is not None:
            eval_results = trainer.evaluate()
            metrics = {k: float(v) for k, v in eval_results.items()}

        training_time = time.time() - start_time

        # Get best checkpoint path
        best_checkpoint = None
        if trainer.state.best_model_checkpoint:
            best_checkpoint = Path(trainer.state.best_model_checkpoint)

        # Create metadata
        config_dict = (
            self.config
            if isinstance(self.config, dict)
            else (
                self.config.model_dump() if hasattr(self.config, "model_dump") else {}
            )
        )

        metadata = ModelMetadata(
            model_name=model_name,
            framework="huggingface",
            training_config=config_dict,
            training_data_path=Path("train.json"),
            eval_data_path=Path("eval.json") if eval_data else None,
            metrics=metrics,
            best_checkpoint=best_checkpoint,
            training_time=training_time,
            training_timestamp=format_iso8601(now_iso8601()),
        )

        return metadata

    def save_model(self, output_dir: Path, metadata: ModelMetadata) -> None:
        """Save model and metadata.

        Parameters
        ----------
        output_dir : Path
            Directory to save model and metadata.
        metadata : ModelMetadata
            Training metadata to save.

        Examples
        --------
        >>> trainer = HuggingFaceTrainer({})  # doctest: +SKIP
        >>> trainer.save_model(Path("output"), metadata)  # doctest: +SKIP
        """
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model
        if self.model is not None:
            self.model.save_pretrained(output_dir / "model")
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir / "model")

        # Save metadata
        with open(output_dir / "metadata.json", "w") as f:
            # Convert Path objects to strings for JSON serialization
            metadata_dict = metadata.model_dump()
            json.dump(metadata_dict, f, indent=2, default=str)

    def load_model(self, model_dir: Path) -> PreTrainedModel:
        """Load model.

        Parameters
        ----------
        model_dir : Path
            Directory containing saved model.

        Returns
        -------
        PreTrainedModel
            Loaded model.

        Examples
        --------
        >>> trainer = HuggingFaceTrainer({})  # doctest: +SKIP
        >>> model = trainer.load_model(Path("saved_model"))  # doctest: +SKIP
        """
        from transformers import (  # noqa: PLC0415
            AutoModelForSequenceClassification,
            AutoTokenizer,
        )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_dir / "model"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir / "model")

        return self.model
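HuggingFaceTrainer.train() expects a tokenized datasets.Dataset whose rows already carry input_ids, attention_mask, and labels, since batching is delegated to DataCollatorWithPadding. Below is a minimal usage sketch under that assumption; the toy data, hyperparameter values, and output paths are illustrative and are not taken from the package's own documentation.

# Usage sketch (illustrative, not part of the package).
from pathlib import Path

from datasets import Dataset
from transformers import AutoTokenizer

from bead.active_learning.trainers.huggingface import HuggingFaceTrainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Toy data: tokenize the text column and keep integer labels.
raw = Dataset.from_dict({"text": ["great item", "broken item"], "labels": [1, 0]})
train_data = raw.map(
    lambda ex: tokenizer(ex["text"], truncation=True), remove_columns=["text"]
)

config = {
    "model_name": "bert-base-uncased",
    "task_type": "classification",
    "num_labels": 2,
    "output_dir": Path("output"),
    "num_epochs": 1,
    "batch_size": 2,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "warmup_steps": 0,
    "evaluation_strategy": "no",
    "save_strategy": "no",
    "load_best_model_at_end": False,
    "logging_dir": None,
    "fp16": False,
}

trainer = HuggingFaceTrainer(config)
metadata = trainer.train(train_data)  # returns ModelMetadata
trainer.save_model(Path("output/final"), metadata)
reloaded = trainer.load_model(Path("output/final"))

save_model() writes the model and tokenizer under <output_dir>/model plus a metadata.json, and load_model() reads from that same layout, so the two calls above round-trip through one directory.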
bead/active_learning/trainers/lightning.py
@@ -0,0 +1,324 @@
"""PyTorch Lightning trainer implementation.

This module provides a trainer that uses PyTorch Lightning for model training
with callbacks for checkpointing and early stopping.
"""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any

from bead.active_learning.trainers.base import BaseTrainer, ModelMetadata
from bead.data.base import BeadBaseModel
from bead.data.timestamps import format_iso8601, now_iso8601

if TYPE_CHECKING:
    import pytorch_lightning as pl
    from torch.nn import Module
    from torch.utils.data import DataLoader


def create_lightning_module(
    model: Module, learning_rate: float = 2e-5
) -> pl.LightningModule:
    """Create a PyTorch Lightning module.

    Parameters
    ----------
    model
        The PyTorch model to wrap in a Lightning module.
    learning_rate
        Learning rate for the AdamW optimizer.

    Returns
    -------
    pl.LightningModule
        Lightning module wrapping the provided model with training,
        validation, and optimizer configuration.
    """
    import pytorch_lightning as pl  # noqa: PLC0415
    import torch  # noqa: PLC0415

    class _LightningModule(pl.LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.model = model
            self.learning_rate = learning_rate

        def forward(self, **inputs: Any) -> Any:
            return self.model(**inputs)

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            outputs = self(**batch)
            loss = outputs.loss
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            outputs = self(**batch)
            loss = outputs.loss
            self.log("val_loss", loss)
            return loss

        def configure_optimizers(self) -> Any:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
            return optimizer

    return _LightningModule()


class PyTorchLightningTrainer(BaseTrainer):
    """Trainer using PyTorch Lightning.

    Trains models using PyTorch Lightning with callbacks for checkpointing
    and early stopping.

    Parameters
    ----------
    config
        Training configuration as a dict or config object with the following
        fields:

        - model_name: str, base model name or path
        - num_labels: int, number of output labels
        - num_epochs: int, number of training epochs
        - learning_rate: float, learning rate for optimizer
        - output_dir: Path, directory for outputs and checkpoints
        - logging_dir: Path or None, optional TensorBoard logging directory

    Attributes
    ----------
    config : dict[str, int | str | float | bool | Path] | BeadBaseModel
        Training configuration.
    lightning_module : pl.LightningModule | None
        The Lightning module wrapper, set after training.

    Examples
    --------
    >>> from pathlib import Path
    >>> config = {
    ...     "model_name": "bert-base-uncased",
    ...     "num_labels": 2,
    ...     "num_epochs": 3,
    ...     "learning_rate": 2e-5,
    ...     "output_dir": Path("output"),
    ...     "logging_dir": None
    ... }
    >>> trainer = PyTorchLightningTrainer(config)
    >>> trainer.lightning_module is None
    True
    """

    def __init__(
        self, config: dict[str, int | str | float | bool | Path] | BeadBaseModel
    ) -> None:
        super().__init__(config)
        self.lightning_module: pl.LightningModule | None = None

    def _get_config_value(
        self, key: str, default: int | str | float | bool | Path | None = None
    ) -> int | str | float | bool | Path | None:
        """Get configuration value with fallback to default.

        Parameters
        ----------
        key
            Configuration key to retrieve.
        default
            Default value if key is not found.

        Returns
        -------
        int | str | float | bool | Path | None
            Configuration value for the given key, or default if not found.
        """
        if hasattr(self.config, key):
            return getattr(self.config, key)
        if isinstance(self.config, dict):
            return self.config.get(key, default)
        return default

    def train(
        self, train_data: DataLoader, eval_data: DataLoader | None = None
    ) -> ModelMetadata:
        """Train a model using PyTorch Lightning.

        Loads a pretrained model, wraps it in a Lightning module, and trains
        with checkpointing and early stopping callbacks.

        Parameters
        ----------
        train_data
            Training dataloader providing batches for training.
        eval_data
            Optional evaluation dataloader for validation during training.

        Returns
        -------
        ModelMetadata
            Metadata containing model name, framework, training config,
            metrics, checkpoint path, and training time.

        Examples
        --------
        >>> config = {"model_name": "bert-base-uncased"}  # doctest: +SKIP
        >>> trainer = PyTorchLightningTrainer(config)  # doctest: +SKIP
        >>> metadata = trainer.train(train_loader)  # doctest: +SKIP
        >>> metadata.framework  # doctest: +SKIP
        'pytorch_lightning'
        """
        import pytorch_lightning as pl  # noqa: PLC0415
        from transformers import AutoModelForSequenceClassification  # noqa: PLC0415

        start_time = time.time()

        # get config values
        model_name = self._get_config_value("model_name", "bert-base-uncased")
        num_labels = self._get_config_value("num_labels", 2)
        num_epochs = self._get_config_value("num_epochs", 3)
        learning_rate = self._get_config_value("learning_rate", 2e-5)
        output_dir = self._get_config_value("output_dir", Path("output"))
        logging_dir = self._get_config_value("logging_dir", None)

        # load model
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )

        # create lightning module
        self.lightning_module = create_lightning_module(model, learning_rate)

        # create callbacks
        callbacks = [
            pl.callbacks.ModelCheckpoint(
                monitor="val_loss",
                dirpath=output_dir,
                filename="best-{epoch:02d}-{val_loss:.2f}",
            ),
            pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
        ]

        # create logger
        logger = None
        if logging_dir:
            logger = pl.loggers.TensorBoardLogger(str(logging_dir))

        # create trainer
        trainer = pl.Trainer(
            max_epochs=num_epochs,
            accelerator="auto",
            devices="auto",
            logger=logger,
            callbacks=callbacks,
        )

        # train
        trainer.fit(
            self.lightning_module,
            train_dataloaders=train_data,
            val_dataloaders=eval_data,
        )

        # evaluate
        metrics: dict[str, float] = {}
        if eval_data is not None:
            eval_results = trainer.validate(
                self.lightning_module, dataloaders=eval_data
            )
            if eval_results:
                metrics = {k: float(v) for k, v in eval_results[0].items()}

        training_time = time.time() - start_time

        # get best checkpoint path
        best_checkpoint = None
        if hasattr(trainer.checkpoint_callback, "best_model_path"):
            best_checkpoint_str = trainer.checkpoint_callback.best_model_path
            if best_checkpoint_str:
                best_checkpoint = Path(best_checkpoint_str)

        # create metadata
        config_dict = (
            self.config
            if isinstance(self.config, dict)
            else (
                self.config.model_dump() if hasattr(self.config, "model_dump") else {}
            )
        )

        metadata = ModelMetadata(
            model_name=model_name,
            framework="pytorch_lightning",
            training_config=config_dict,
            training_data_path=Path("train.json"),
            eval_data_path=Path("eval.json") if eval_data else None,
            metrics=metrics,
            best_checkpoint=best_checkpoint,
            training_time=training_time,
            training_timestamp=format_iso8601(now_iso8601()),
        )

        return metadata

    def save_model(self, output_dir: Path, metadata: ModelMetadata) -> None:
        """Save model and metadata to disk.

        Saves the Lightning module state dict and training metadata as JSON.

        Parameters
        ----------
        output_dir
            Directory to save model checkpoint and metadata JSON file.
        metadata
            Training metadata to save alongside the model.

        Examples
        --------
        >>> trainer = PyTorchLightningTrainer({})  # doctest: +SKIP
        >>> trainer.save_model(Path("output"), metadata)  # doctest: +SKIP
        """
        import torch  # noqa: PLC0415

        output_dir.mkdir(parents=True, exist_ok=True)

        # save lightning checkpoint
        if self.lightning_module is not None:
            torch.save(
                self.lightning_module.state_dict(),
                output_dir / "lightning_model.pt",
            )

        # save metadata
        with open(output_dir / "metadata.json", "w") as f:
            metadata_dict = metadata.model_dump()
            json.dump(metadata_dict, f, indent=2, default=str)

    def load_model(self, model_dir: Path) -> pl.LightningModule | None:
        """Load a saved model from disk.

        Parameters
        ----------
        model_dir
            Directory containing the saved Lightning model state dict.

        Returns
        -------
        pl.LightningModule | None
            The Lightning module with loaded weights, or None if no module
            has been initialized.

        Examples
        --------
        >>> trainer = PyTorchLightningTrainer({})  # doctest: +SKIP
        >>> model = trainer.load_model(Path("saved_model"))  # doctest: +SKIP
        """
        import torch  # noqa: PLC0415

        if self.lightning_module is not None:
            self.lightning_module.load_state_dict(
                torch.load(model_dir / "lightning_model.pt")
            )
        return self.lightning_module
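PyTorchLightningTrainer.train() takes torch DataLoaders whose batches can be unpacked as keyword arguments of the wrapped Hugging Face model (dicts with input_ids, attention_mask, and labels). Below is a minimal sketch of building such loaders with DataCollatorWithPadding and running the trainer; the data, paths, and hyperparameters are illustrative assumptions, not values documented by the package.

# Usage sketch (illustrative, not part of the package).
from pathlib import Path

from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

from bead.active_learning.trainers.lightning import PyTorchLightningTrainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
raw = Dataset.from_dict({"text": ["great item", "broken item"], "labels": [1, 0]})
encoded = raw.map(
    lambda ex: tokenizer(ex["text"], truncation=True), remove_columns=["text"]
)

# Collate examples into padded tensor batches the wrapped model accepts as **kwargs.
collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(encoded, batch_size=2, collate_fn=collator)

config = {
    "model_name": "bert-base-uncased",
    "num_labels": 2,
    "num_epochs": 1,
    "learning_rate": 2e-5,
    "output_dir": Path("output"),
    "logging_dir": None,
}

trainer = PyTorchLightningTrainer(config)
# An eval loader is passed so the val_loss-based checkpoint and early-stopping
# callbacks have a metric to monitor.
metadata = trainer.train(train_loader, eval_data=train_loader)
trainer.save_model(Path("output/lightning"), metadata)

Note that save_model() stores only the Lightning module's state dict (lightning_model.pt) plus metadata.json, so load_model() restores weights only into an already constructed module on the same trainer instance.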