bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/trainers/huggingface.py
@@ -0,0 +1,304 @@
+ """HuggingFace Transformers trainer implementation.
+
+ This module provides a trainer that uses the HuggingFace Transformers library
+ for model training with integrated TensorBoard logging.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ from bead.active_learning.trainers.base import BaseTrainer, ModelMetadata
+ from bead.data.base import BeadBaseModel
+ from bead.data.timestamps import format_iso8601, now_iso8601
+
+ if TYPE_CHECKING:
+     from datasets import Dataset
+     from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+ class HuggingFaceTrainer(BaseTrainer):
+     """Trainer using HuggingFace Transformers.
+
+     This trainer uses the HuggingFace Transformers library to train models
+     for sequence classification and other NLP tasks. It supports TensorBoard
+     logging and checkpoint management.
+
+     Parameters
+     ----------
+     config : dict[str, int | str | float | bool | Path] | BeadBaseModel
+         Training configuration with the following expected fields:
+         - model_name: str - Base model name/path
+         - task_type: str - Task type (classification, regression, etc.)
+         - num_labels: int | None - Number of labels for classification
+         - output_dir: Path - Directory for outputs
+         - num_epochs: int - Number of training epochs
+         - batch_size: int - Training batch size
+         - learning_rate: float - Learning rate
+         - weight_decay: float - Weight decay
+         - warmup_steps: int - Warmup steps
+         - evaluation_strategy: str - Evaluation strategy (epoch, steps, no)
+         - save_strategy: str - Save strategy (epoch, steps, no)
+         - load_best_model_at_end: bool - Load best model at end
+         - logging_dir: Path | None - Logging directory
+         - fp16: bool - Use mixed precision
+
+     Attributes
+     ----------
+     config : dict[str, int | str | float | bool | Path] | BeadBaseModel
+         Training configuration.
+     model : PreTrainedModel | None
+         The trained model.
+     tokenizer : PreTrainedTokenizer | None
+         The tokenizer.
+
+     Examples
+     --------
+     >>> from pathlib import Path
+     >>> config = {
+     ...     "model_name": "bert-base-uncased",
+     ...     "task_type": "classification",
+     ...     "num_labels": 2,
+     ...     "output_dir": Path("output"),
+     ...     "num_epochs": 3,
+     ...     "batch_size": 16,
+     ...     "learning_rate": 2e-5,
+     ...     "weight_decay": 0.01,
+     ...     "warmup_steps": 0,
+     ...     "evaluation_strategy": "epoch",
+     ...     "save_strategy": "epoch",
+     ...     "load_best_model_at_end": True,
+     ...     "logging_dir": None,
+     ...     "fp16": False
+     ... }
+     >>> trainer = HuggingFaceTrainer(config)
+     >>> trainer.model is None
+     True
+     """
+
+     def __init__(
+         self, config: dict[str, int | str | float | bool | Path] | BeadBaseModel
+     ) -> None:
+         super().__init__(config)
+         self.model: PreTrainedModel | None = None
+         self.tokenizer: PreTrainedTokenizer | None = None
+
+     def _get_config_value(
+         self, key: str, default: int | str | float | bool | Path | None = None
+     ) -> int | str | float | bool | Path | None:
+         """Get configuration value with fallback to default.
+
+         Parameters
+         ----------
+         key : str
+             Configuration key.
+         default : int | str | float | bool | Path | None
+             Default value if key not found.
+
+         Returns
+         -------
+         int | str | float | bool | Path | None
+             Configuration value.
+         """
+         if hasattr(self.config, key):
+             return getattr(self.config, key)
+         if isinstance(self.config, dict):
+             return self.config.get(key, default)
+         return default
+
+     def train(
+         self, train_data: Dataset, eval_data: Dataset | None = None
+     ) -> ModelMetadata:
+         """Train model using HuggingFace Trainer.
+
+         Parameters
+         ----------
+         train_data : Dataset
+             HuggingFace Dataset for training.
+         eval_data : Dataset | None
+             HuggingFace Dataset for evaluation.
+
+         Returns
+         -------
+         ModelMetadata
+             Training metadata.
+
+         Raises
+         ------
+         ValueError
+             If task type is not supported.
+
+         Examples
+         --------
+         >>> config = {"model_name": "bert-base-uncased"} # doctest: +SKIP
+         >>> trainer = HuggingFaceTrainer(config) # doctest: +SKIP
+         >>> metadata = trainer.train(train_dataset) # doctest: +SKIP
+         >>> metadata.framework # doctest: +SKIP
+         'huggingface'
+         """
+         from transformers import ( # noqa: PLC0415
+             AutoModelForSequenceClassification,
+             AutoTokenizer,
+             DataCollatorWithPadding,
+             Trainer,
+             TrainingArguments,
+         )
+
+         start_time = time.time()
+
+         # Get config values
+         model_name = self._get_config_value("model_name", "bert-base-uncased")
+         task_type = self._get_config_value("task_type", "classification")
+         num_labels = self._get_config_value("num_labels", 2)
+         output_dir = self._get_config_value("output_dir", Path("output"))
+         num_epochs = self._get_config_value("num_epochs", 3)
+         batch_size = self._get_config_value("batch_size", 16)
+         learning_rate = self._get_config_value("learning_rate", 2e-5)
+         weight_decay = self._get_config_value("weight_decay", 0.01)
+         warmup_steps = self._get_config_value("warmup_steps", 0)
+         evaluation_strategy = self._get_config_value("evaluation_strategy", "epoch")
+         save_strategy = self._get_config_value("save_strategy", "epoch")
+         load_best = self._get_config_value("load_best_model_at_end", True)
+         logging_dir = self._get_config_value("logging_dir", None)
+         fp16 = self._get_config_value("fp16", False)
+
+         # Load model and tokenizer
+         if task_type == "classification":
+             self.model = AutoModelForSequenceClassification.from_pretrained(
+                 model_name, num_labels=num_labels
+             )
+         else:
+             msg = f"Task type not supported: {task_type}"
+             raise ValueError(msg)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+         # Create training arguments
+         training_args = TrainingArguments(
+             output_dir=str(output_dir),
+             num_train_epochs=num_epochs,
+             per_device_train_batch_size=batch_size,
+             per_device_eval_batch_size=batch_size,
+             learning_rate=learning_rate,
+             weight_decay=weight_decay,
+             warmup_steps=warmup_steps,
+             eval_strategy=evaluation_strategy, # type: ignore
+             save_strategy=save_strategy,
+             load_best_model_at_end=load_best,
+             logging_dir=str(logging_dir) if logging_dir else None,
+             fp16=fp16,
+             report_to=["tensorboard"] if logging_dir else [],
+         )
+
+         # Create data collator
+         data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
+
+         # Create trainer
+         trainer = Trainer(
+             model=self.model,
+             args=training_args,
+             train_dataset=train_data,
+             eval_dataset=eval_data,
+             data_collator=data_collator,
+         )
+
+         # Train
+         trainer.train()
+
+         # Evaluate
+         metrics = {}
+         if eval_data is not None:
+             eval_results = trainer.evaluate()
+             metrics = {k: float(v) for k, v in eval_results.items()}
+
+         training_time = time.time() - start_time
+
+         # Get best checkpoint path
+         best_checkpoint = None
+         if trainer.state.best_model_checkpoint:
+             best_checkpoint = Path(trainer.state.best_model_checkpoint)
+
+         # Create metadata
+         config_dict = (
+             self.config
+             if isinstance(self.config, dict)
+             else (
+                 self.config.model_dump() if hasattr(self.config, "model_dump") else {}
+             )
+         )
+
+         metadata = ModelMetadata(
+             model_name=model_name,
+             framework="huggingface",
+             training_config=config_dict,
+             training_data_path=Path("train.json"),
+             eval_data_path=Path("eval.json") if eval_data else None,
+             metrics=metrics,
+             best_checkpoint=best_checkpoint,
+             training_time=training_time,
+             training_timestamp=format_iso8601(now_iso8601()),
+         )
+
+         return metadata
+
+     def save_model(self, output_dir: Path, metadata: ModelMetadata) -> None:
+         """Save model and metadata.
+
+         Parameters
+         ----------
+         output_dir : Path
+             Directory to save model and metadata.
+         metadata : ModelMetadata
+             Training metadata to save.
+
+         Examples
+         --------
+         >>> trainer = HuggingFaceTrainer({}) # doctest: +SKIP
+         >>> trainer.save_model(Path("output"), metadata) # doctest: +SKIP
+         """
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # Save model
+         if self.model is not None:
+             self.model.save_pretrained(output_dir / "model")
+         if self.tokenizer is not None:
+             self.tokenizer.save_pretrained(output_dir / "model")
+
+         # Save metadata
+         with open(output_dir / "metadata.json", "w") as f:
+             # Convert Path objects to strings for JSON serialization
+             metadata_dict = metadata.model_dump()
+             json.dump(metadata_dict, f, indent=2, default=str)
+
+     def load_model(self, model_dir: Path) -> PreTrainedModel:
+         """Load model.
+
+         Parameters
+         ----------
+         model_dir : Path
+             Directory containing saved model.
+
+         Returns
+         -------
+         PreTrainedModel
+             Loaded model.
+
+         Examples
+         --------
+         >>> trainer = HuggingFaceTrainer({}) # doctest: +SKIP
+         >>> model = trainer.load_model(Path("saved_model")) # doctest: +SKIP
+         """
+         from transformers import ( # noqa: PLC0415
+             AutoModelForSequenceClassification,
+             AutoTokenizer,
+         )
+
+         self.model = AutoModelForSequenceClassification.from_pretrained(
+             model_dir / "model"
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(model_dir / "model")
+
+         return self.model
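For orientation, a minimal usage sketch of the trainer added above. The dataset construction and hyperparameter values are illustrative assumptions; only the HuggingFaceTrainer API (the config keys, train, and save_model) comes from the file itself.

# Sketch only: assumes a small pre-tokenized datasets.Dataset with
# "input_ids", "attention_mask", and "labels" columns (illustrative data).
from pathlib import Path

from datasets import Dataset
from transformers import AutoTokenizer

from bead.active_learning.trainers.huggingface import HuggingFaceTrainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = Dataset.from_dict({"text": ["good", "bad"], "labels": [1, 0]}).map(
    lambda ex: tokenizer(ex["text"], truncation=True), remove_columns=["text"]
)

config = {
    "model_name": "bert-base-uncased",
    "task_type": "classification",
    "num_labels": 2,
    "output_dir": Path("output"),
    "num_epochs": 1,
    "batch_size": 2,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "warmup_steps": 0,
    "evaluation_strategy": "no",  # no eval set in this sketch
    "save_strategy": "no",
    "load_best_model_at_end": False,
    "logging_dir": None,
    "fp16": False,
}

trainer = HuggingFaceTrainer(config)
metadata = trainer.train(dataset)             # returns ModelMetadata with framework "huggingface"
trainer.save_model(Path("output"), metadata)  # writes model/ and metadata.json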
bead/active_learning/trainers/lightning.py
@@ -0,0 +1,324 @@
+ """PyTorch Lightning trainer implementation.
+
+ This module provides a trainer that uses PyTorch Lightning for model training
+ with callbacks for checkpointing and early stopping.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any
+
+ from bead.active_learning.trainers.base import BaseTrainer, ModelMetadata
+ from bead.data.base import BeadBaseModel
+ from bead.data.timestamps import format_iso8601, now_iso8601
+
+ if TYPE_CHECKING:
+     import pytorch_lightning as pl
+     from torch.nn import Module
+     from torch.utils.data import DataLoader
+
+
+ def create_lightning_module(
+     model: Module, learning_rate: float = 2e-5
+ ) -> pl.LightningModule:
+     """Create a PyTorch Lightning module.
+
+     Parameters
+     ----------
+     model
+         The PyTorch model to wrap in a Lightning module.
+     learning_rate
+         Learning rate for the AdamW optimizer.
+
+     Returns
+     -------
+     pl.LightningModule
+         Lightning module wrapping the provided model with training,
+         validation, and optimizer configuration.
+     """
+     import pytorch_lightning as pl # noqa: PLC0415
+     import torch # noqa: PLC0415
+
+     class _LightningModule(pl.LightningModule):
+         def __init__(self) -> None:
+             super().__init__()
+             self.model = model
+             self.learning_rate = learning_rate
+
+         def forward(self, **inputs: Any) -> Any:
+             return self.model(**inputs)
+
+         def training_step(self, batch: Any, batch_idx: int) -> Any:
+             outputs = self(**batch)
+             loss = outputs.loss
+             self.log("train_loss", loss)
+             return loss
+
+         def validation_step(self, batch: Any, batch_idx: int) -> Any:
+             outputs = self(**batch)
+             loss = outputs.loss
+             self.log("val_loss", loss)
+             return loss
+
+         def configure_optimizers(self) -> Any:
+             optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
+             return optimizer
+
+     return _LightningModule()
+
+
+ class PyTorchLightningTrainer(BaseTrainer):
+     """Trainer using PyTorch Lightning.
+
+     Trains models using PyTorch Lightning with callbacks for checkpointing
+     and early stopping.
+
+     Parameters
+     ----------
+     config
+         Training configuration as a dict or config object with the following
+         fields:
+
+         - model_name: str, base model name or path
+         - num_labels: int, number of output labels
+         - num_epochs: int, number of training epochs
+         - learning_rate: float, learning rate for optimizer
+         - output_dir: Path, directory for outputs and checkpoints
+         - logging_dir: Path or None, optional TensorBoard logging directory
+
+     Attributes
+     ----------
+     config : dict[str, int | str | float | bool | Path] | BeadBaseModel
+         Training configuration.
+     lightning_module : pl.LightningModule | None
+         The Lightning module wrapper, set after training.
+
+     Examples
+     --------
+     >>> from pathlib import Path
+     >>> config = {
+     ...     "model_name": "bert-base-uncased",
+     ...     "num_labels": 2,
+     ...     "num_epochs": 3,
+     ...     "learning_rate": 2e-5,
+     ...     "output_dir": Path("output"),
+     ...     "logging_dir": None
+     ... }
+     >>> trainer = PyTorchLightningTrainer(config)
+     >>> trainer.lightning_module is None
+     True
+     """
+
+     def __init__(
+         self, config: dict[str, int | str | float | bool | Path] | BeadBaseModel
+     ) -> None:
+         super().__init__(config)
+         self.lightning_module: pl.LightningModule | None = None
+
+     def _get_config_value(
+         self, key: str, default: int | str | float | bool | Path | None = None
+     ) -> int | str | float | bool | Path | None:
+         """Get configuration value with fallback to default.
+
+         Parameters
+         ----------
+         key
+             Configuration key to retrieve.
+         default
+             Default value if key is not found.
+
+         Returns
+         -------
+         int | str | float | bool | Path | None
+             Configuration value for the given key, or default if not found.
+         """
+         if hasattr(self.config, key):
+             return getattr(self.config, key)
+         if isinstance(self.config, dict):
+             return self.config.get(key, default)
+         return default
+
+     def train(
+         self, train_data: DataLoader, eval_data: DataLoader | None = None
+     ) -> ModelMetadata:
+         """Train a model using PyTorch Lightning.
+
+         Loads a pretrained model, wraps it in a Lightning module, and trains
+         with checkpointing and early stopping callbacks.
+
+         Parameters
+         ----------
+         train_data
+             Training dataloader providing batches for training.
+         eval_data
+             Optional evaluation dataloader for validation during training.
+
+         Returns
+         -------
+         ModelMetadata
+             Metadata containing model name, framework, training config,
+             metrics, checkpoint path, and training time.
+
+         Examples
+         --------
+         >>> config = {"model_name": "bert-base-uncased"} # doctest: +SKIP
+         >>> trainer = PyTorchLightningTrainer(config) # doctest: +SKIP
+         >>> metadata = trainer.train(train_loader) # doctest: +SKIP
+         >>> metadata.framework # doctest: +SKIP
+         'pytorch_lightning'
+         """
+         import pytorch_lightning as pl # noqa: PLC0415
+         from transformers import AutoModelForSequenceClassification # noqa: PLC0415
+
+         start_time = time.time()
+
+         # get config values
+         model_name = self._get_config_value("model_name", "bert-base-uncased")
+         num_labels = self._get_config_value("num_labels", 2)
+         num_epochs = self._get_config_value("num_epochs", 3)
+         learning_rate = self._get_config_value("learning_rate", 2e-5)
+         output_dir = self._get_config_value("output_dir", Path("output"))
+         logging_dir = self._get_config_value("logging_dir", None)
+
+         # load model
+         model = AutoModelForSequenceClassification.from_pretrained(
+             model_name, num_labels=num_labels
+         )
+
+         # create lightning module
+         self.lightning_module = create_lightning_module(model, learning_rate)
+
+         # create callbacks
+         callbacks = [
+             pl.callbacks.ModelCheckpoint(
+                 monitor="val_loss",
+                 dirpath=output_dir,
+                 filename="best-{epoch:02d}-{val_loss:.2f}",
+             ),
+             pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
+         ]
+
+         # create logger
+         logger = None
+         if logging_dir:
+             logger = pl.loggers.TensorBoardLogger(str(logging_dir))
+
+         # create trainer
+         trainer = pl.Trainer(
+             max_epochs=num_epochs,
+             accelerator="auto",
+             devices="auto",
+             logger=logger,
+             callbacks=callbacks,
+         )
+
+         # train
+         trainer.fit(
+             self.lightning_module,
+             train_dataloaders=train_data,
+             val_dataloaders=eval_data,
+         )
+
+         # evaluate
+         metrics: dict[str, float] = {}
+         if eval_data is not None:
+             eval_results = trainer.validate(
+                 self.lightning_module, dataloaders=eval_data
+             )
+             if eval_results:
+                 metrics = {k: float(v) for k, v in eval_results[0].items()}
+
+         training_time = time.time() - start_time
+
+         # get best checkpoint path
+         best_checkpoint = None
+         if hasattr(trainer.checkpoint_callback, "best_model_path"):
+             best_checkpoint_str = trainer.checkpoint_callback.best_model_path
+             if best_checkpoint_str:
+                 best_checkpoint = Path(best_checkpoint_str)
+
+         # create metadata
+         config_dict = (
+             self.config
+             if isinstance(self.config, dict)
+             else (
+                 self.config.model_dump() if hasattr(self.config, "model_dump") else {}
+             )
+         )
+
+         metadata = ModelMetadata(
+             model_name=model_name,
+             framework="pytorch_lightning",
+             training_config=config_dict,
+             training_data_path=Path("train.json"),
+             eval_data_path=Path("eval.json") if eval_data else None,
+             metrics=metrics,
+             best_checkpoint=best_checkpoint,
+             training_time=training_time,
+             training_timestamp=format_iso8601(now_iso8601()),
+         )
+
+         return metadata
+
+     def save_model(self, output_dir: Path, metadata: ModelMetadata) -> None:
+         """Save model and metadata to disk.
+
+         Saves the Lightning module state dict and training metadata as JSON.
+
+         Parameters
+         ----------
+         output_dir
+             Directory to save model checkpoint and metadata JSON file.
+         metadata
+             Training metadata to save alongside the model.
+
+         Examples
+         --------
+         >>> trainer = PyTorchLightningTrainer({}) # doctest: +SKIP
+         >>> trainer.save_model(Path("output"), metadata) # doctest: +SKIP
+         """
+         import torch # noqa: PLC0415
+
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # save lightning checkpoint
+         if self.lightning_module is not None:
+             torch.save(
+                 self.lightning_module.state_dict(),
+                 output_dir / "lightning_model.pt",
+             )
+
+         # save metadata
+         with open(output_dir / "metadata.json", "w") as f:
+             metadata_dict = metadata.model_dump()
+             json.dump(metadata_dict, f, indent=2, default=str)
+
+     def load_model(self, model_dir: Path) -> pl.LightningModule | None:
+         """Load a saved model from disk.
+
+         Parameters
+         ----------
+         model_dir
+             Directory containing the saved Lightning model state dict.
+
+         Returns
+         -------
+         pl.LightningModule | None
+             The Lightning module with loaded weights, or None if no module
+             has been initialized.
+
+         Examples
+         --------
+         >>> trainer = PyTorchLightningTrainer({}) # doctest: +SKIP
+         >>> model = trainer.load_model(Path("saved_model")) # doctest: +SKIP
+         """
+         import torch # noqa: PLC0415
+
+         if self.lightning_module is not None:
+             self.lightning_module.load_state_dict(
+                 torch.load(model_dir / "lightning_model.pt")
+             )
+         return self.lightning_module
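A corresponding minimal sketch for the Lightning trainer above. The DataLoader construction is an assumption (the trainer expects batches it can unpack into the wrapped AutoModelForSequenceClassification); only the PyTorchLightningTrainer API comes from the file itself.

# Sketch only: batches must be dicts of tensors the wrapped model accepts
# (input_ids, attention_mask, labels); the data here is illustrative.
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from bead.active_learning.trainers.lightning import PyTorchLightningTrainer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer(["good", "bad"], padding=True, return_tensors="pt")
examples = [
    {
        "input_ids": enc["input_ids"][i],
        "attention_mask": enc["attention_mask"][i],
        "labels": torch.tensor(label),
    }
    for i, label in enumerate([1, 0])
]
train_loader = DataLoader(examples, batch_size=2)
val_loader = DataLoader(examples, batch_size=2)  # validation drives checkpointing and early stopping

config = {
    "model_name": "bert-base-uncased",
    "num_labels": 2,
    "num_epochs": 1,
    "learning_rate": 2e-5,
    "output_dir": Path("output"),
    "logging_dir": None,
}

trainer = PyTorchLightningTrainer(config)
metadata = trainer.train(train_loader, val_loader)  # framework "pytorch_lightning"
trainer.save_model(Path("output"), metadata)        # writes lightning_model.pt and metadata.json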