omnigenome 0.3.1a0__py3-none-any.whl → 0.3.4a0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic.

Files changed (79)
  1. omnigenome/__init__.py +252 -266
  2. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/METADATA +9 -9
  3. omnigenome-0.3.4a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/top_level.txt +0 -0
omnigenome/src/trainer/hf_trainer.py
@@ -1,75 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: hf_trainer.py
- # time: 14:40 06/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- HuggingFace trainer integration for OmniGenome models.
-
- This module provides HuggingFace trainer wrappers for OmniGenome models,
- enabling seamless integration with the HuggingFace training ecosystem
- while maintaining OmniGenome-specific functionality.
- """
-
- from transformers import Trainer
- from transformers import TrainingArguments
-
- from ... import __name__ as omnigenome_name
- from ... import __version__ as omnigenome_version
-
-
- class HFTrainer(Trainer):
-     """
-     HuggingFace trainer wrapper for OmniGenome models.
-
-     This class extends the HuggingFace Trainer to include OmniGenome-specific
-     metadata and functionality while maintaining full compatibility with the
-     HuggingFace training ecosystem.
-
-     Attributes:
-         metadata: Dictionary containing OmniGenome library information
-     """
-
-     def __init__(self, *args, **kwargs):
-         """
-         Initialize the HuggingFace trainer wrapper.
-
-         Args:
-             *args: Positional arguments passed to the parent Trainer
-             **kwargs: Keyword arguments passed to the parent Trainer
-         """
-         super(HFTrainer, self).__init__(*args, **kwargs)
-         self.metadata = {
-             "library_name": omnigenome_name,
-             "omnigenome_version": omnigenome_version,
-         }
-
-
- class HFTrainingArguments(TrainingArguments):
-     """
-     HuggingFace training arguments wrapper for OmniGenome models.
-
-     This class extends the HuggingFace TrainingArguments to include
-     OmniGenome-specific metadata while maintaining full compatibility
-     with the HuggingFace training ecosystem.
-
-     Attributes:
-         metadata: Dictionary containing OmniGenome library information
-     """
-
-     def __init__(self, *args, **kwargs):
-         """
-         Initialize the HuggingFace training arguments wrapper.
-
-         Args:
-             *args: Positional arguments passed to the parent TrainingArguments
-             **kwargs: Keyword arguments passed to the parent TrainingArguments
-         """
-         super(HFTrainingArguments, self).__init__(*args, **kwargs)
-         self.metadata = {
-             "library_name": omnigenome_name,
-             "omnigenome_version": omnigenome_version,
-         }
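For reference, the removed wrappers are thin drop-in replacements for the stock HuggingFace classes: they only attach OmniGenome metadata. A minimal usage sketch, assuming the classes were importable from the package's top level; the model, datasets, and output directory below are placeholders:

from omnigenome import HFTrainer, HFTrainingArguments  # import path assumed

# Standard HuggingFace arguments work unchanged; only the metadata attribute is added.
args = HFTrainingArguments(
    output_dir="./checkpoints",          # placeholder output directory
    num_train_epochs=3,
    per_device_train_batch_size=8,
)
trainer = HFTrainer(
    model=model,                         # placeholder: any OmniGenome/transformers model
    args=args,
    train_dataset=train_dataset,         # placeholder tokenized datasets
    eval_dataset=eval_dataset,
)
trainer.train()                          # the usual HuggingFace training loop
print(trainer.metadata)                  # {"library_name": ..., "omnigenome_version": ...}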
omnigenome/src/trainer/trainer.py
@@ -1,591 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: trainer.py
- # time: 14:40 06/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- Training utilities for OmniGenome models.
-
- This module provides a comprehensive training framework for OmniGenome models,
- including automatic mixed precision training, early stopping, metric tracking,
- and model checkpointing.
- """
- import os
- import tempfile
- import autocuda
- import numpy as np
- from torch.utils.data import DataLoader
- from tqdm import tqdm
-
- from ..misc.utils import env_meta_info, fprint, seed_everything
-
- import torch
- from torch.cuda.amp import GradScaler
-
-
- def _infer_optimization_direction(metrics, prev_metrics):
-     """
-     Infer the optimization direction based on metric names and trends.
-
-     This function determines whether larger or smaller values are better for
-     the given metrics by analyzing metric names and their trends over time.
-
-     Args:
-         metrics (dict): Current metric values
-         prev_metrics (list): Previous metric values from multiple epochs
-
-     Returns:
-         str: Either "larger_is_better" or "smaller_is_better"
-     """
-     larger_is_better_metrics = [
-         "accuracy",
-         "f1",
-         "recall",
-         "precision",
-         "roc_auc",
-         "pr_auc",
-         "score",
-         # ...
-     ]
-     smaller_is_better_metrics = [
-         "loss",
-         "error",
-         "mse",
-         "mae",
-         "r2",
-         "distance",
-         # ...
-     ]
-     for metric in larger_is_better_metrics:
-         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
-             return "larger_is_better"
-     for metric in smaller_is_better_metrics:
-         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
-             return "smaller_is_better"
-
-     fprint(
-         "Cannot determine the optimisation direction. Attempting inference from the metrics."
-     )
-     is_prev_increasing = np.mean(list(prev_metrics[0].values())[0]) < np.mean(
-         list(prev_metrics[-1].values())[0]
-     )
-     is_still_increasing = np.mean(list(prev_metrics[1].values())[0]) < np.mean(
-         list(metrics.values())[0]
-     )
-     fprint(
-         "Cannot determine the optimisation direction. Attempting inference from the metrics."
-     )
-
-     if is_prev_increasing and is_still_increasing:
-         return "larger_is_better"
-
-     is_prev_decreasing = np.mean(list(prev_metrics[0].values())[0]) > np.mean(
-         list(prev_metrics[-1].values())[0]
-     )
-     is_still_decreasing = np.mean(list(prev_metrics[1].values())[0]) > np.mean(
-         list(metrics.values())
-     )
-
-     if is_prev_decreasing and is_still_decreasing:
-         return "smaller_is_better"
-
-     return "larger_is_better" if is_prev_increasing else "smaller_is_better"
-
-
- class Trainer:
-     """
-     Comprehensive trainer for OmniGenome models.
-
-     This trainer provides a complete training framework with automatic mixed precision,
-     early stopping, metric tracking, and model checkpointing. It supports various
-     training configurations and can handle different types of genomic sequence tasks.
-
-     Attributes:
-         model: The model to be trained
-         train_loader: DataLoader for training data
-         eval_loader: DataLoader for validation data
-         test_loader: DataLoader for test data
-         epochs: Number of training epochs
-         patience: Early stopping patience
-         optimizer: Optimizer for training
-         loss_fn: Loss function
-         compute_metrics: List of metric computation functions
-         device: Device to run training on
-         scaler: Gradient scaler for mixed precision training
-         metrics: Dictionary to store training metrics
-         predictions: Dictionary to store model predictions
-     """
-
-     def __init__(
-         self,
-         model,
-         train_dataset: torch.utils.data.Dataset = None,
-         eval_dataset: torch.utils.data.Dataset = None,
-         test_dataset: torch.utils.data.Dataset = None,
-         epochs: int = 3,
-         batch_size: int = 8,
-         patience: int = -1,
-         gradient_accumulation_steps: int = 1,
-         optimizer: torch.optim.Optimizer = None,
-         loss_fn: torch.nn.Module = None,
-         compute_metrics: list | str = None,
-         seed: int = 42,
-         device: [torch.device | str] = None,
-         autocast: str = "float16",
-         **kwargs,
-     ):
-         """
-         Initialize the trainer.
-
-         Args:
-             model: The model to be trained
-             train_dataset: Training dataset
-             eval_dataset: Validation dataset
-             test_dataset: Test dataset
-             epochs (int): Number of training epochs (default: 3)
-             batch_size (int): Batch size for training (default: 8)
-             patience (int): Early stopping patience (default: -1, no early stopping)
-             gradient_accumulation_steps (int): Gradient accumulation steps (default: 1)
-             optimizer: Optimizer for training (default: None)
-             loss_fn: Loss function (default: None)
-             compute_metrics: Metric computation functions (default: None)
-             seed (int): Random seed (default: 42)
-             device: Device to run training on (default: None, auto-detect)
-             autocast (str): Mixed precision type (default: "float16")
-             **kwargs: Additional keyword arguments
-         """
-         # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
-
-         self.model = model
-
-         # DataLoaders
-         if kwargs.get("train_loader"):
-             self.train_loader = kwargs.get("train_loader", None)
-             self.eval_loader = kwargs.get("eval_loader", None)
-             self.test_loader = kwargs.get("test_loader", None)
-         else:
-             self.train_loader = DataLoader(
-                 train_dataset, batch_size=batch_size, shuffle=True
-             )
-             self.eval_loader = (
-                 DataLoader(eval_dataset, batch_size=batch_size)
-                 if eval_dataset
-                 else None
-             )
-             self.test_loader = (
-                 DataLoader(test_dataset, batch_size=batch_size)
-                 if test_dataset
-                 else None
-             )
-
-         self.epochs = epochs
-         self.patience = patience if patience > 0 else epochs
-         self.gradient_accumulation_steps = gradient_accumulation_steps
-         self.optimizer = optimizer
-         self.loss_fn = loss_fn
-         self.compute_metrics = (
-             compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
-         )
-         self.seed = seed
-         self.device = device if device else autocuda.auto_cuda()
-         self.device = (
-             torch.device(self.device) if isinstance(self.device, str) else self.device
-         )
-
-         self.fast_dtype = {
-             "float32": torch.float32,
-             "fp32": torch.float32,
-             "float16": torch.float16,
-             "fp16": torch.float16,
-             "bfloat16": torch.bfloat16,
-             "bf16": torch.bfloat16,
-         }.get(autocast, torch.float16)
-         self.scaler = GradScaler()
-         if self.loss_fn is not None:
-             self.model.set_loss_fn(self.loss_fn)
-
-         self.model.to(self.device)
-
-         self.metadata = env_meta_info()
-         self.metrics = {}
-
-         self._optimization_direction = None
-         self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)
-
-         self.predictions = {}
-
-     def _is_metric_better(self, metrics, stage="valid"):
-         """
-         Check if the current metrics are better than the best metrics so far.
-
-         Args:
-             metrics (dict): Current metric values
-             stage (str): Stage name ("valid" or "test")
-
-         Returns:
-             bool: True if current metrics are better than best metrics
-         """
-         assert stage in [
-             "valid",
-             "test",
-         ], "The metrics stage should be either 'valid' or 'test'."
-
-         prev_metrics = self.metrics.get(stage, None)
-         if stage not in self.metrics:
-             self.metrics.update({f"{stage}": [metrics]})
-         else:
-             self.metrics[f"{stage}"].append(metrics)
-
-         if "best_valid" not in self.metrics:
-             self.metrics.update({"best_valid": metrics})
-             return True
-
-         if prev_metrics is None:
-             return False
-
-         self._optimization_direction = (
-             _infer_optimization_direction(metrics, prev_metrics)
-             if self._optimization_direction is None
-             else self._optimization_direction
-         )
-
-         if self._optimization_direction == "larger_is_better":
-             if np.mean(list(metrics.values())[0]) > np.mean(
-                 list(self.metrics["best_valid"].values())[0]
-             ):
-                 self.metrics.update({"best_valid": metrics})
-                 return True
-         elif self._optimization_direction == "smaller_is_better":
-             if np.mean(list(metrics.values())[0]) < np.mean(
-                 list(self.metrics["best_valid"].values())[0]
-             ):
-                 self.metrics.update({"best_valid": metrics})
-                 return True
-
-         return False
-
-     def train(self, path_to_save=None, **kwargs):
-         """
-         Train the model.
-
-         Args:
-             path_to_save (str, optional): Path to save the best model
-             **kwargs: Additional keyword arguments
-
-         Returns:
-             dict: Training metrics and results
-         """
-         seed_everything(self.seed)
-         patience = 0
-
-         if self.eval_loader is not None and len(self.eval_loader) > 0:
-             valid_metrics = self.evaluate()
-         else:
-             valid_metrics = self.test()
-         if self._is_metric_better(valid_metrics, stage="valid"):
-             self._save_state_dict()
-             patience = 0
-
-         for epoch in range(self.epochs):
-             self.model.train()
-             train_loss = []
-             train_it = tqdm(
-                 self.train_loader, desc=f"Epoch {epoch + 1}/{self.epochs} Loss"
-             )
-             for step, batch in enumerate(train_it):
-                 batch = batch.to(self.device)
-
-                 if step % self.gradient_accumulation_steps == 0:
-                     self.optimizer.zero_grad()
-
-                 if self.fast_dtype:
-                     with torch.autocast(
-                         device_type=self.device.type, dtype=self.fast_dtype
-                     ):
-                         outputs = self.model(**batch)
-                 else:
-                     outputs = self.model(**batch)
-                 if "loss" not in outputs:
-                     # Generally, the model should return a loss in the outputs via OmniGenBench
-                     # For the Lora models, the loss is computed separately
-                     if hasattr(self.model, "loss_function") and callable(
-                         self.model.loss_function
-                     ):
-                         loss = self.model.loss_function(
-                             outputs["logits"], outputs["labels"]
-                         )
-                     elif (
-                         hasattr(self.model, "model")
-                         and hasattr(self.model.model, "loss_function")
-                         and callable(self.model.model.loss_function)
-                     ):
-                         loss = self.model.model.loss_function(
-                             outputs["logits"], outputs["labels"]
-                         )
-                     else:
-                         raise ValueError(
-                             "The model does not have a loss function defined. "
-                             "Please provide a loss function or ensure the model has one."
-                         )
-                 else:
-                     # If the model returns a loss directly
-                     loss = outputs["loss"]
-
-                 loss = loss / self.gradient_accumulation_steps
-
-                 if self.fast_dtype:
-                     self.scaler.scale(loss).backward()
-                 else:
-                     loss.backward()
-
-                 if (step + 1) % self.gradient_accumulation_steps == 0 or (
-                     step + 1
-                 ) == len(self.train_loader):
-                     if self.fast_dtype:
-                         self.scaler.step(self.optimizer)
-                         self.scaler.update()
-                     else:
-                         self.optimizer.step()
-
-                 train_loss.append(loss.item() * self.gradient_accumulation_steps)
-                 train_it.set_description(
-                     f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
-                 )
-
-             if self.eval_loader is not None and len(self.eval_loader) > 0:
-                 valid_metrics = self.evaluate()
-             else:
-                 valid_metrics = self.test()
-
-             if self._is_metric_better(valid_metrics, stage="valid"):
-                 self._save_state_dict()
-                 patience = 0
-             else:
-                 patience += 1
-                 if patience >= self.patience:
-                     fprint(f"Early stopping at epoch {epoch + 1}.")
-                     break
-
-             if path_to_save:
-                 _path_to_save = path_to_save + "_epoch_" + str(epoch + 1)
-
-                 if valid_metrics:
-                     for key, value in valid_metrics.items():
-                         _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
-
-                 self.save_model(_path_to_save, **kwargs)
-
-         if self.test_loader is not None and len(self.test_loader) > 0:
-             self._load_state_dict()
-             test_metrics = self.test()
-             self._is_metric_better(test_metrics, stage="test")
-
-         if path_to_save:
-             _path_to_save = path_to_save + "_final"
-             if self.metrics["test"]:
-                 for key, value in self.metrics["test"][-1].items():
-                     _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
-
-             self.save_model(_path_to_save, **kwargs)
-
-         self._remove_state_dict()
-
-         return self.metrics
-
-     def evaluate(self):
-         """
-         Evaluate the model on the validation set.
-
-         Returns:
-             dict: Evaluation metrics
-         """
-         with torch.no_grad():
-             self.model.eval()
-             val_truth = []
-             val_preds = []
-             it = tqdm(self.eval_loader, desc="Evaluating")
-             for batch in it:
-                 batch.to(self.device)
-                 labels = batch["labels"]
-                 batch.pop("labels")
-                 if self.fast_dtype:
-                     with torch.autocast(device_type="cuda", dtype=self.fast_dtype):
-                         predictions = self.model.predict(batch)["predictions"]
-                 else:
-                     predictions = self.model.predict(batch)["predictions"]
-                 val_truth.append(labels.float().cpu().numpy(force=True))
-                 val_preds.append(predictions.float().cpu().numpy(force=True))
-             val_truth = (
-                 np.vstack(val_truth) if labels.ndim > 1 else np.hstack(val_truth)
-             )
-             val_preds = (
-                 np.vstack(val_preds) if predictions.ndim > 1 else np.hstack(val_preds)
-             )
-             if not np.all(val_truth == -100):
-                 valid_metrics = {}
-                 for metric_func in self.compute_metrics:
-                     valid_metrics.update(metric_func(val_truth, val_preds))
-
-                 fprint(valid_metrics)
-             else:
-                 valid_metrics = {
-                     "Validation set labels may be NaN. No metrics calculated.": 0
-                 }
-
-             self.predictions.update({"valid": {"pred": val_preds, "true": val_truth}})
-
-             return valid_metrics
-
-     def test(self):
-         """
-         Test the model on the test set.
-
-         Returns:
-             dict: Test metrics and predictions
-         """
-         with torch.no_grad():
-             self.model.eval()
-             preds = []
-             truth = []
-             it = tqdm(self.test_loader, desc="Testing")
-             for batch in it:
-                 batch.to(self.device)
-                 labels = batch["labels"]
-                 batch.pop("labels")
-                 if self.fast_dtype:
-                     with torch.autocast(device_type="cuda", dtype=self.fast_dtype):
-                         predictions = self.model.predict(batch)["predictions"]
-                 else:
-                     predictions = self.model.predict(batch)["predictions"]
-                 truth.append(labels.float().cpu().numpy(force=True))
-                 preds.append(predictions.float().cpu().numpy(force=True))
-             truth = np.vstack(truth) if labels.ndim > 1 else np.hstack(truth)
-             preds = np.vstack(preds) if predictions.ndim > 1 else np.hstack(preds)
-             if not np.all(truth == -100):
-                 test_metrics = {}
-                 for metric_func in self.compute_metrics:
-                     test_metrics.update(metric_func(truth, preds))
-
-                 fprint(test_metrics)
-             else:
-                 test_metrics = {"Test set labels may be NaN. No metrics calculated.": 0}
-
-             self.predictions.update({"test": {"pred": preds, "true": truth}})
-
-             return test_metrics
-
-     def predict(self, data_loader):
-         """
-         Generate predictions using the model.
-
-         Args:
-             data_loader: DataLoader for prediction data
-
-         Returns:
-             torch.Tensor: Model predictions
-         """
-         return self.model.predict(data_loader)
-
-     def get_model(self, **kwargs):
-         """
-         Get the trained model.
-
-         Args:
-             **kwargs: Additional keyword arguments
-
-         Returns:
-             The trained model
-         """
-         return self.model
-
-     def compute_metrics(self):
-         """
-         Get the metric computation functions.
-
-         Returns:
-             list: List of metric computation functions
-         """
-         return self.compute_metrics
-
-     def unwrap_model(self, model=None):
-         """
-         Unwrap the model from any distributed training wrappers.
-
-         Args:
-             model: Model to unwrap (default: None, uses self.model)
-
-         Returns:
-             The unwrapped model
-         """
-         if model is None:
-             model = self.model
-         try:
-             return self.accelerator.unwrap_model(model)
-         except:
-             try:
-                 return model.module
-             except:
-                 return model
-
-     def save_model(self, path, overwrite=False, **kwargs):
-         """
-         Save the model to disk.
-
-         Args:
-             path (str): Path to save the model
-             overwrite (bool): Whether to overwrite existing files (default: False)
-             **kwargs: Additional keyword arguments
-         """
-         self.unwrap_model().save(path, overwrite, **kwargs)
-
-     def _load_state_dict(self):
-         """
-         Load model state dictionary from temporary file.
-
-         Returns:
-             dict: Model state dictionary
-         """
-         if os.path.exists(self._model_state_dict_path):
-             self.unwrap_model().load_state_dict(
-                 torch.load(self._model_state_dict_path, map_location="cpu")
-             )
-             self.unwrap_model().to(self.device)
-
-     def _save_state_dict(self):
-         """
-         Save model state dictionary to temporary file.
-
-         Returns:
-             str: Path to temporary file
-         """
-         if not hasattr(self, "_model_state_dict_path"):
-             # Create a temporary file and close it so it can be written to later
-             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
-             self._model_state_dict_path = tmp_file.name
-             tmp_file.close()
-
-         try:
-             if os.path.exists(self._model_state_dict_path):
-                 os.remove(self._model_state_dict_path)
-         except Exception as e:
-             fprint(
-                 f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
-             )
-
-         torch.save(self.unwrap_model().state_dict(), self._model_state_dict_path)
-
-     def _remove_state_dict(self):
-         """
-         Remove temporary state dictionary file.
-         """
-         if hasattr(self, "_model_state_dict_path"):
-             try:
-                 if os.path.exists(self._model_state_dict_path):
-                     os.remove(self._model_state_dict_path)
-             except Exception as e:
-                 fprint(
-                     f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
-                 )
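Likewise, a minimal sketch of how the removed native Trainer was constructed, pieced together from the __init__ signature above; the top-level import path, the model, the datasets, and the metric callable are assumptions for illustration only:

import torch
from omnigenome import Trainer  # import path assumed

trainer = Trainer(
    model=model,                                  # placeholder OmniGenome model exposing predict()/set_loss_fn()
    train_dataset=train_set,                      # placeholder torch Datasets
    eval_dataset=valid_set,
    test_dataset=test_set,
    epochs=10,
    batch_size=16,
    patience=3,                                   # stop after 3 epochs without improvement
    optimizer=torch.optim.AdamW(model.parameters(), lr=2e-5),
    compute_metrics=[my_metric],                  # placeholder callables mapping (y_true, y_pred) -> dict
    seed=42,
    autocast="bf16",                              # mapped to torch.bfloat16 by the dtype table above
)
metrics = trainer.train(path_to_save="./ckpt")    # returns the accumulated metrics dict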
omnigenome/utility/__init__.py
@@ -1,3 +0,0 @@
- """
- This package contains utility modules for interacting with the hub, datasets, and pipelines.
- """
omnigenome/utility/dataset_hub/__init__.py
@@ -1,12 +0,0 @@
- # -*- coding: utf-8 -*-
- # File: __init__.py
- # Time: 02:22 20/06/2025
- # Author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # Website: https://yangheng95.github.io
- # GitHub: https://github.com/yangheng95
- # HuggingFace: https://huggingface.co/yangheng
- # Google Scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2025. All rights reserved.
- """
- This package contains modules for the dataset hub.
- """