omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
@@ -0,0 +1,75 @@
+ # -*- coding: utf-8 -*-
+ # file: hf_trainer.py
+ # time: 14:40 06/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ """
+ HuggingFace trainer integration for OmniGenome models.
+
+ This module provides HuggingFace trainer wrappers for OmniGenome models,
+ enabling seamless integration with the HuggingFace training ecosystem
+ while maintaining OmniGenome-specific functionality.
+ """
+
+ from transformers import Trainer
+ from transformers import TrainingArguments
+
+ from ... import __name__ as omnigenome_name
+ from ... import __version__ as omnigenome_version
+
+
+ class HFTrainer(Trainer):
+     """
+     HuggingFace trainer wrapper for OmniGenome models.
+
+     This class extends the HuggingFace Trainer to include OmniGenome-specific
+     metadata and functionality while maintaining full compatibility with the
+     HuggingFace training ecosystem.
+
+     Attributes:
+         metadata: Dictionary containing OmniGenome library information
+     """
+
+     def __init__(self, *args, **kwargs):
+         """
+         Initialize the HuggingFace trainer wrapper.
+
+         Args:
+             *args: Positional arguments passed to the parent Trainer
+             **kwargs: Keyword arguments passed to the parent Trainer
+         """
+         super(HFTrainer, self).__init__(*args, **kwargs)
+         self.metadata = {
+             "library_name": omnigenome_name,
+             "omnigenome_version": omnigenome_version,
+         }
+
+
+ class HFTrainingArguments(TrainingArguments):
+     """
+     HuggingFace training arguments wrapper for OmniGenome models.
+
+     This class extends the HuggingFace TrainingArguments to include
+     OmniGenome-specific metadata while maintaining full compatibility
+     with the HuggingFace training ecosystem.
+
+     Attributes:
+         metadata: Dictionary containing OmniGenome library information
+     """
+
+     def __init__(self, *args, **kwargs):
+         """
+         Initialize the HuggingFace training arguments wrapper.
+
+         Args:
+             *args: Positional arguments passed to the parent TrainingArguments
+             **kwargs: Keyword arguments passed to the parent TrainingArguments
+         """
+         super(HFTrainingArguments, self).__init__(*args, **kwargs)
+         self.metadata = {
+             "library_name": omnigenome_name,
+             "omnigenome_version": omnigenome_version,
+         }
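
For orientation, here is a minimal usage sketch of the two wrappers above. It assumes the module path mirrors the file layout listed earlier, and that `model`, `train_ds`, and `eval_ds` are a HuggingFace-compatible model and tokenized datasets supplied by the caller (hypothetical names, not part of the package):

# Hedged sketch: HFTrainer/HFTrainingArguments are used as drop-in replacements
# for the stock HuggingFace classes; `model`, `train_ds`, and `eval_ds` are
# placeholders defined elsewhere by the caller.
from omnigenome.src.trainer.hf_trainer import HFTrainer, HFTrainingArguments

training_args = HFTrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=8,
)
trainer = HFTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
trainer.train()
# Both wrappers also carry the OmniGenome metadata attached in __init__:
print(trainer.metadata)  # e.g. {'library_name': 'omnigenome', 'omnigenome_version': '0.3.0a0'}
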
@@ -0,0 +1,579 @@
+ # -*- coding: utf-8 -*-
+ # file: trainer.py
+ # time: 14:40 06/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ """
+ Training utilities for OmniGenome models.
+
+ This module provides a comprehensive training framework for OmniGenome models,
+ including automatic mixed precision training, early stopping, metric tracking,
+ and model checkpointing.
+ """
+ import os
+ import tempfile
+ import autocuda
+ import numpy as np
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from ..misc.utils import env_meta_info, fprint, seed_everything
+
+ import torch
+ from torch.cuda.amp import GradScaler
+
+
+ def _infer_optimization_direction(metrics, prev_metrics):
+     """
+     Infer the optimization direction based on metric names and trends.
+
+     This function determines whether larger or smaller values are better for
+     the given metrics by analyzing metric names and their trends over time.
+
+     Args:
+         metrics (dict): Current metric values
+         prev_metrics (list): Previous metric values from multiple epochs
+
+     Returns:
+         str: Either "larger_is_better" or "smaller_is_better"
+     """
+     larger_is_better_metrics = [
+         "accuracy",
+         "f1",
+         "recall",
+         "precision",
+         "roc_auc",
+         "pr_auc",
+         "score",
+         # ...
+     ]
+     smaller_is_better_metrics = [
+         "loss",
+         "error",
+         "mse",
+         "mae",
+         "r2",
+         "distance",
+         # ...
+     ]
+     for metric in larger_is_better_metrics:
+         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
+             return "larger_is_better"
+     for metric in smaller_is_better_metrics:
+         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
+             return "smaller_is_better"
+
+     fprint(
+         "Cannot determine the optimisation direction. Attempting inference from the metrics."
+     )
+     is_prev_increasing = np.mean(list(prev_metrics[0].values())[0]) < np.mean(
+         list(prev_metrics[-1].values())[0]
+     )
+     is_still_increasing = np.mean(list(prev_metrics[1].values())[0]) < np.mean(
+         list(metrics.values())[0]
+     )
+     fprint(
+         "Cannot determine the optimisation direction. Attempting inference from the metrics."
+     )
+
+     if is_prev_increasing and is_still_increasing:
+         return "larger_is_better"
+
+     is_prev_decreasing = np.mean(list(prev_metrics[0].values())[0]) > np.mean(
+         list(prev_metrics[-1].values())[0]
+     )
+     is_still_decreasing = np.mean(list(prev_metrics[1].values())[0]) > np.mean(
+         list(metrics.values())[0]
+     )
+
+     if is_prev_decreasing and is_still_decreasing:
+         return "smaller_is_better"
+
+     return "larger_is_better" if is_prev_increasing else "smaller_is_better"
+
+
+ class Trainer:
+     """
+     Comprehensive trainer for OmniGenome models.
+
+     This trainer provides a complete training framework with automatic mixed precision,
+     early stopping, metric tracking, and model checkpointing. It supports various
+     training configurations and can handle different types of genomic sequence tasks.
+
+     Attributes:
+         model: The model to be trained
+         train_loader: DataLoader for training data
+         eval_loader: DataLoader for validation data
+         test_loader: DataLoader for test data
+         epochs: Number of training epochs
+         patience: Early stopping patience
+         optimizer: Optimizer for training
+         loss_fn: Loss function
+         compute_metrics: List of metric computation functions
+         device: Device to run training on
+         scaler: Gradient scaler for mixed precision training
+         metrics: Dictionary to store training metrics
+         predictions: Dictionary to store model predictions
+     """
+
+     def __init__(
+         self,
+         model,
+         train_dataset: torch.utils.data.Dataset = None,
+         eval_dataset: torch.utils.data.Dataset = None,
+         test_dataset: torch.utils.data.Dataset = None,
+         epochs: int = 3,
+         batch_size: int = 8,
+         patience: int = -1,
+         gradient_accumulation_steps: int = 1,
+         optimizer: torch.optim.Optimizer = None,
+         loss_fn: torch.nn.Module = None,
+         compute_metrics: list | str = None,
+         seed: int = 42,
+         device: torch.device | str = None,
+         autocast: str = "float16",
+         **kwargs,
+     ):
+         """
+         Initialize the trainer.
+
+         Args:
+             model: The model to be trained
+             train_dataset: Training dataset
+             eval_dataset: Validation dataset
+             test_dataset: Test dataset
+             epochs (int): Number of training epochs (default: 3)
+             batch_size (int): Batch size for training (default: 8)
+             patience (int): Early stopping patience (default: -1, no early stopping)
+             gradient_accumulation_steps (int): Gradient accumulation steps (default: 1)
+             optimizer: Optimizer for training (default: None)
+             loss_fn: Loss function (default: None)
+             compute_metrics: Metric computation functions (default: None)
+             seed (int): Random seed (default: 42)
+             device: Device to run training on (default: None, auto-detect)
+             autocast (str): Mixed precision type (default: "float16")
+             **kwargs: Additional keyword arguments
+         """
+         # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+
+         self.model = model
+
+         # DataLoaders
+         if kwargs.get("train_loader"):
+             self.train_loader = kwargs.get("train_loader", None)
+             self.eval_loader = kwargs.get("eval_loader", None)
+             self.test_loader = kwargs.get("test_loader", None)
+         else:
+             self.train_loader = DataLoader(
+                 train_dataset, batch_size=batch_size, shuffle=True
+             )
+             self.eval_loader = (
+                 DataLoader(eval_dataset, batch_size=batch_size)
+                 if eval_dataset
+                 else None
+             )
+             self.test_loader = (
+                 DataLoader(test_dataset, batch_size=batch_size)
+                 if test_dataset
+                 else None
+             )
+
+         self.epochs = epochs
+         self.patience = patience if patience > 0 else epochs
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+         self.optimizer = optimizer
+         self.loss_fn = loss_fn
+         self.compute_metrics = (
+             compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
+         )
+         self.seed = seed
+         self.device = device if device else autocuda.auto_cuda()
+         self.device = torch.device(self.device) if isinstance(self.device, str) else self.device
+
+         self.fast_dtype = {
+             "float32": torch.float32,
+             "fp32": torch.float32,
+             "float16": torch.float16,
+             "fp16": torch.float16,
+             "bfloat16": torch.bfloat16,
+             "bf16": torch.bfloat16,
+         }.get(autocast, torch.float16)
+         self.scaler = GradScaler()
+         if self.loss_fn is not None:
+             self.model.set_loss_fn(self.loss_fn)
+
+         self.model.to(self.device)
+
+         self.metadata = env_meta_info()
+         self.metrics = {}
+
+         self._optimization_direction = None
+         self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)
+
+         self.predictions = {}
+
+     def _is_metric_better(self, metrics, stage="valid"):
+         """
+         Check if the current metrics are better than the best metrics so far.
+
+         Args:
+             metrics (dict): Current metric values
+             stage (str): Stage name ("valid" or "test")
+
+         Returns:
+             bool: True if current metrics are better than best metrics
+         """
+         assert stage in [
+             "valid",
+             "test",
+         ], "The metrics stage should be either 'valid' or 'test'."
+
+         prev_metrics = self.metrics.get(stage, None)
+         if stage not in self.metrics:
+             self.metrics.update({f"{stage}": [metrics]})
+         else:
+             self.metrics[f"{stage}"].append(metrics)
+
+         if "best_valid" not in self.metrics:
+             self.metrics.update({"best_valid": metrics})
+             return True
+
+         if prev_metrics is None:
+             return False
+
+         self._optimization_direction = (
+             _infer_optimization_direction(metrics, prev_metrics)
+             if self._optimization_direction is None
+             else self._optimization_direction
+         )
+
+         if self._optimization_direction == "larger_is_better":
+             if np.mean(list(metrics.values())[0]) > np.mean(
+                 list(self.metrics["best_valid"].values())[0]
+             ):
+                 self.metrics.update({"best_valid": metrics})
+                 return True
+         elif self._optimization_direction == "smaller_is_better":
+             if np.mean(list(metrics.values())[0]) < np.mean(
+                 list(self.metrics["best_valid"].values())[0]
+             ):
+                 self.metrics.update({"best_valid": metrics})
+                 return True
+
+         return False
+
+     def train(self, path_to_save=None, **kwargs):
+         """
+         Train the model.
+
+         Args:
+             path_to_save (str, optional): Path to save the best model
+             **kwargs: Additional keyword arguments
+
+         Returns:
+             dict: Training metrics and results
+         """
+         seed_everything(self.seed)
+         patience = 0
+
+         if self.eval_loader is not None and len(self.eval_loader) > 0:
+             valid_metrics = self.evaluate()
+         else:
+             valid_metrics = self.test()
+         if self._is_metric_better(valid_metrics, stage="valid"):
+             self._save_state_dict()
+             patience = 0
+
+         for epoch in range(self.epochs):
+             self.model.train()
+             train_loss = []
+             train_it = tqdm(
+                 self.train_loader, desc=f"Epoch {epoch + 1}/{self.epochs} Loss"
+             )
+             for step, batch in enumerate(train_it):
+                 batch = batch.to(self.device)
+
+                 if step % self.gradient_accumulation_steps == 0:
+                     self.optimizer.zero_grad()
+
+                 if self.fast_dtype:
+                     with torch.autocast(device_type=self.device.type, dtype=self.fast_dtype):
+                         outputs = self.model(**batch)
+                 else:
+                     outputs = self.model(**batch)
+                 if "loss" not in outputs:
+                     # Generally, the model should return a loss in the outputs via OmniGenBench
+                     # For the Lora models, the loss is computed separately
+                     if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
+                         loss = self.model.loss_function(outputs['logits'], outputs["labels"])
+                     elif (hasattr(self.model, "model")
+                           and hasattr(self.model.model, "loss_function")
+                           and callable(self.model.model.loss_function)):
+                         loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
+                     else:
+                         raise ValueError(
+                             "The model does not have a loss function defined. "
+                             "Please provide a loss function or ensure the model has one."
+                         )
+                 else:
+                     # If the model returns a loss directly
+                     loss = outputs["loss"]
+
+                 loss = loss / self.gradient_accumulation_steps
+
+                 if self.fast_dtype:
+                     self.scaler.scale(loss).backward()
+                 else:
+                     loss.backward()
+
+                 if (step + 1) % self.gradient_accumulation_steps == 0 or (
+                     step + 1
+                 ) == len(self.train_loader):
+                     if self.fast_dtype:
+                         self.scaler.step(self.optimizer)
+                         self.scaler.update()
+                     else:
+                         self.optimizer.step()
+
+                 train_loss.append(loss.item() * self.gradient_accumulation_steps)
+                 train_it.set_description(
+                     f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
+                 )
+
+             if self.eval_loader is not None and len(self.eval_loader) > 0:
+                 valid_metrics = self.evaluate()
+             else:
+                 valid_metrics = self.test()
+
+             if self._is_metric_better(valid_metrics, stage="valid"):
+                 self._save_state_dict()
+                 patience = 0
+             else:
+                 patience += 1
+                 if patience >= self.patience:
+                     fprint(f"Early stopping at epoch {epoch + 1}.")
+                     break
+
+             if path_to_save:
+                 _path_to_save = path_to_save + "_epoch_" + str(epoch + 1)
+
+                 if valid_metrics:
+                     for key, value in valid_metrics.items():
+                         _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
+
+                 self.save_model(_path_to_save, **kwargs)
+
+         if self.test_loader is not None and len(self.test_loader) > 0:
+             self._load_state_dict()
+             test_metrics = self.test()
+             self._is_metric_better(test_metrics, stage="test")
+
+         if path_to_save:
+             _path_to_save = path_to_save + "_final"
+             if self.metrics["test"]:
+                 for key, value in self.metrics["test"][-1].items():
+                     _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
+
+             self.save_model(_path_to_save, **kwargs)
+
+         self._remove_state_dict()
+
+         return self.metrics
+
+     def evaluate(self):
+         """
+         Evaluate the model on the validation set.
+
+         Returns:
+             dict: Evaluation metrics
+         """
+         with torch.no_grad():
+             self.model.eval()
+             val_truth = []
+             val_preds = []
+             it = tqdm(self.eval_loader, desc="Evaluating")
+             for batch in it:
+                 batch.to(self.device)
+                 labels = batch["labels"]
+                 batch.pop("labels")
+                 if self.fast_dtype:
+                     with torch.autocast(device_type=self.device.type, dtype=self.fast_dtype):
+                         predictions = self.model.predict(batch)["predictions"]
+                 else:
+                     predictions = self.model.predict(batch)["predictions"]
+                 val_truth.append(labels.float().cpu().numpy(force=True))
+                 val_preds.append(predictions.float().cpu().numpy(force=True))
+             val_truth = (
+                 np.vstack(val_truth) if labels.ndim > 1 else np.hstack(val_truth)
+             )
+             val_preds = (
+                 np.vstack(val_preds) if predictions.ndim > 1 else np.hstack(val_preds)
+             )
+             if not np.all(val_truth == -100):
+                 valid_metrics = {}
+                 for metric_func in self.compute_metrics:
+                     valid_metrics.update(metric_func(val_truth, val_preds))
+
+                 fprint(valid_metrics)
+             else:
+                 valid_metrics = {
+                     "Validation set labels may be NaN. No metrics calculated.": 0
+                 }
+
+             self.predictions.update({"valid": {"pred": val_preds, "true": val_truth}})
+
+             return valid_metrics
+
+     def test(self):
+         """
+         Test the model on the test set.
+
+         Returns:
+             dict: Test metrics and predictions
+         """
+         with torch.no_grad():
+             self.model.eval()
+             preds = []
+             truth = []
+             it = tqdm(self.test_loader, desc="Testing")
+             for batch in it:
+                 batch.to(self.device)
+                 labels = batch["labels"]
+                 batch.pop("labels")
+                 if self.fast_dtype:
+                     with torch.autocast(device_type=self.device.type, dtype=self.fast_dtype):
+                         predictions = self.model.predict(batch)["predictions"]
+                 else:
+                     predictions = self.model.predict(batch)["predictions"]
+                 truth.append(labels.float().cpu().numpy(force=True))
+                 preds.append(predictions.float().cpu().numpy(force=True))
+             truth = np.vstack(truth) if labels.ndim > 1 else np.hstack(truth)
+             preds = np.vstack(preds) if predictions.ndim > 1 else np.hstack(preds)
+             if not np.all(truth == -100):
+                 test_metrics = {}
+                 for metric_func in self.compute_metrics:
+                     test_metrics.update(metric_func(truth, preds))
+
+                 fprint(test_metrics)
+             else:
+                 test_metrics = {"Test set labels may be NaN. No metrics calculated.": 0}
+
+             self.predictions.update({"test": {"pred": preds, "true": truth}})
+
+             return test_metrics
+
+     def predict(self, data_loader):
+         """
+         Generate predictions using the model.
+
+         Args:
+             data_loader: DataLoader for prediction data
+
+         Returns:
+             torch.Tensor: Model predictions
+         """
+         return self.model.predict(data_loader)
+
+     def get_model(self, **kwargs):
+         """
+         Get the trained model.
+
+         Args:
+             **kwargs: Additional keyword arguments
+
+         Returns:
+             The trained model
+         """
+         return self.model
+
+     def compute_metrics(self):
+         """
+         Get the metric computation functions.
+
+         Returns:
+             list: List of metric computation functions
+         """
+         return self.compute_metrics
+
+     def unwrap_model(self, model=None):
+         """
+         Unwrap the model from any distributed training wrappers.
+
+         Args:
+             model: Model to unwrap (default: None, uses self.model)
+
+         Returns:
+             The unwrapped model
+         """
+         if model is None:
+             model = self.model
+         try:
+             return self.accelerator.unwrap_model(model)
+         except Exception:
+             try:
+                 return model.module
+             except Exception:
+                 return model
+
+     def save_model(self, path, overwrite=False, **kwargs):
+         """
+         Save the model to disk.
+
+         Args:
+             path (str): Path to save the model
+             overwrite (bool): Whether to overwrite existing files (default: False)
+             **kwargs: Additional keyword arguments
+         """
+         self.unwrap_model().save(path, overwrite, **kwargs)
+
+     def _load_state_dict(self):
+         """
+         Load the model state dictionary from the temporary checkpoint file.
+
+         The weights saved by ``_save_state_dict`` are restored in place and
+         the model is moved back to the training device.
+         """
+         if os.path.exists(self._model_state_dict_path):
+             self.unwrap_model().load_state_dict(
+                 torch.load(self._model_state_dict_path, map_location='cpu')
+             )
+             self.unwrap_model().to(self.device)
+
+     def _save_state_dict(self):
+         """
+         Save the model state dictionary to a temporary checkpoint file.
+
+         The file path is stored in ``self._model_state_dict_path`` so that the
+         best weights can be restored later by ``_load_state_dict``.
+         """
+         if not hasattr(self, "_model_state_dict_path"):
+             # Create a temporary file and close it so it can be written to later
+             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
+             self._model_state_dict_path = tmp_file.name
+             tmp_file.close()
+
+         try:
+             if os.path.exists(self._model_state_dict_path):
+                 os.remove(self._model_state_dict_path)
+         except Exception as e:
+             fprint(
+                 f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
+             )
+
+         torch.save(self.unwrap_model().state_dict(), self._model_state_dict_path)
+
+     def _remove_state_dict(self):
+         """
+         Remove the temporary state dictionary file.
+         """
+         if hasattr(self, "_model_state_dict_path"):
+             try:
+                 if os.path.exists(self._model_state_dict_path):
+                     os.remove(self._model_state_dict_path)
+             except Exception as e:
+                 fprint(
+                     f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
+                 )
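
And a hedged sketch of how this native Trainer is typically driven. `model`, the datasets, and `compute_accuracy` are placeholder names supplied by the caller, not part of the package; only arguments that appear in the constructor above are used:

# Hedged sketch of the Trainer loop defined above; all names outside the
# omnigenome import are assumptions defined elsewhere by the caller.
import numpy as np
import torch
from omnigenome.src.trainer.trainer import Trainer

def compute_accuracy(y_true, y_pred):
    # A compute_metrics entry receives (truth, predictions) as numpy arrays and
    # returns a dict of named scores; reduce logits to labels if necessary.
    y_pred = y_pred.argmax(-1) if y_pred.ndim > y_true.ndim else y_pred
    return {"accuracy": float(np.mean(y_true == y_pred))}

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
trainer = Trainer(
    model=model,                       # an OmniGenome model wrapper (placeholder)
    train_dataset=train_ds,            # torch.utils.data.Dataset instances (placeholders)
    eval_dataset=valid_ds,
    test_dataset=test_ds,
    epochs=10,
    batch_size=8,
    patience=3,                        # stop after 3 epochs without improvement
    optimizer=optimizer,
    compute_metrics=[compute_accuracy],
    autocast="bf16",                   # mixed precision; resolved to torch.bfloat16
    seed=42,
)
metrics = trainer.train(path_to_save="./ckpt/omni_model")
print(metrics["best_valid"])           # best validation scores tracked during training
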
@@ -0,0 +1,3 @@
+ """
+ This package contains utility modules for interacting with the hub, datasets, and pipelines.
+ """
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+ # File: __init__.py
+ # Time: 02:22 20/06/2025
+ # Author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # Website: https://yangheng95.github.io
+ # GitHub: https://github.com/yangheng95
+ # HuggingFace: https://huggingface.co/yangheng
+ # Google Scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2025. All rights reserved.
+ """
+ This package contains modules for the dataset hub.
+ """
+