omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff reflects the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/tokenizer/kmers_tokenizer.py
@@ -16,18 +16,18 @@ warnings.filterwarnings("once")
 class OmniKmersTokenizer(OmniTokenizer):
     """
     A k-mer based tokenizer for genomic sequences.
-
+
     This tokenizer breaks genomic sequences into overlapping k-mers and uses
     a base tokenizer to convert them into token IDs. It supports various
     k-mer sizes and overlap configurations for different genomic applications.
-
+
     Attributes:
         base_tokenizer: The underlying tokenizer for converting k-mers to IDs
         k: Size of k-mers
         overlap: Number of overlapping positions between consecutive k-mers
         max_length: Maximum sequence length for tokenization
         metadata: Dictionary containing tokenizer metadata
-
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniKmersTokenizer
         >>> from transformers import AutoTokenizer
@@ -42,7 +42,7 @@ class OmniKmersTokenizer(OmniTokenizer):
     def __init__(self, base_tokenizer=None, k=3, overlap=0, max_length=512, **kwargs):
         """
         Initialize the OmniKmersTokenizer.
-
+
         Args:
             base_tokenizer: The base tokenizer for converting k-mers to token IDs
             k (int, optional): Size of k-mers. Defaults to 3
@@ -59,18 +59,18 @@ class OmniKmersTokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenize a sequence or list of sequences into tokenized inputs.
-
+
         This method processes the input sequence(s) by first converting them to k-mers,
         then using the base tokenizer to convert k-mers to token IDs. It handles
         sequence preprocessing (U/T conversion) and adds special tokens.
-
+
         Args:
             sequence (str or list): Input sequence(s) to tokenize
             **kwargs: Additional keyword arguments including max_length
-
+
         Returns:
             dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
            >>> tokenized = tokenizer(sequence)
@@ -126,14 +126,14 @@ class OmniKmersTokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Create a k-mers tokenizer from a pre-trained model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             OmniKmersTokenizer: Initialized k-mers tokenizer
-
+
         Example:
             >>> tokenizer = OmniKmersTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
             >>> print(type(tokenizer))
@@ -149,17 +149,17 @@ class OmniKmersTokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Convert sequence(s) into k-mers.
-
+
         This method breaks the input sequence(s) into overlapping k-mers based on
         the configured k-mer size and overlap parameters.
-
+
         Args:
             sequence (str or list): Input sequence(s) to convert to k-mers
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of k-mer lists for each input sequence
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> k_mers = tokenizer.tokenize(sequence)
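The docstring above describes overlapping k-mer windowing. For readers skimming the diff, the idea can be sketched in a few lines of plain Python; this is an illustrative stand-in rather than the package's implementation, and the step size of k - overlap between window starts is an assumption consistent with the docstring:

    def to_kmers(sequence, k=3, overlap=0):
        # Consecutive k-mers share `overlap` bases, so window starts advance by k - overlap.
        step = k - overlap
        return [sequence[i : i + k] for i in range(0, len(sequence) - k + 1, step)]

    print(to_kmers("ACGUAGGUA", k=3))             # ['ACG', 'UAG', 'GUA']
    print(to_kmers("ACGUAGGUA", k=3, overlap=2))  # ['ACG', 'CGU', 'GUA', 'UAG', 'AGG', 'GGU', 'GUA']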
@@ -184,11 +184,11 @@ class OmniKmersTokenizer(OmniTokenizer):
     def encode(self, input_ids, **kwargs):
         """
         Encode input IDs using the base tokenizer.
-
+
         Args:
             input_ids: Input IDs to encode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             Encoded input IDs
         """
@@ -197,11 +197,11 @@ class OmniKmersTokenizer(OmniTokenizer):
     def decode(self, input_ids, **kwargs):
         """
         Decode input IDs using the base tokenizer.
-
+
         Args:
             input_ids: Input IDs to decode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             Decoded sequence
         """
@@ -210,13 +210,13 @@ class OmniKmersTokenizer(OmniTokenizer):
     def encode_plus(self, sequence, **kwargs):
         """
         Encode a sequence with additional information.
-
+
         This method is not yet implemented for k-mers tokenizer.
-
+
         Args:
             sequence: Input sequence
             **kwargs: Additional keyword arguments
-
+
         Raises:
             NotImplementedError: This method is not implemented yet
         """
omnigenome/src/tokenizer/single_nucleotide_tokenizer.py
@@ -19,16 +19,16 @@ warnings.filterwarnings("once")
 class OmniSingleNucleotideTokenizer(OmniTokenizer):
     """
     Tokenizer for single nucleotide tokenization in genomics.
-
+
     This tokenizer converts genomic sequences into individual nucleotide tokens,
     where each nucleotide (A, T, C, G, U) becomes a separate token. It's designed
     for genomic sequence processing where fine-grained nucleotide-level analysis
     is required.
-
+
     The tokenizer supports various preprocessing options including U/T conversion
     and whitespace addition between nucleotides. It also handles special tokens
     like BOS (beginning of sequence) and EOS (end of sequence) tokens.
-
+
     Attributes:
         u2t (bool): Whether to convert 'U' to 'T'.
         t2u (bool): Whether to convert 'T' to 'U'.
@@ -54,7 +54,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenizes sequences using single nucleotide tokenization.
-
+
         This method converts genomic sequences into tokenized inputs suitable
         for model training and inference. It handles sequence preprocessing,
         tokenization, and padding/truncation.
@@ -76,7 +76,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
             >>> # Tokenize a single sequence
             >>> inputs = tokenizer("ATCGATCG")
             >>> print(inputs['input_ids'].shape)  # torch.Size([1, seq_len])
-
+
             >>> # Tokenize multiple sequences
             >>> inputs = tokenizer(["ATCGATCG", "GCTAGCTA"])
             >>> print(inputs['input_ids'].shape)  # torch.Size([2, seq_len])
@@ -134,7 +134,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Loads a single nucleotide tokenizer from a pre-trained model.
-
+
         This method creates a single nucleotide tokenizer wrapper around
         a Hugging Face tokenizer loaded from a pre-trained model.

@@ -156,7 +156,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Converts a sequence into a list of individual nucleotide tokens.
-
+
         This method tokenizes genomic sequences by treating each nucleotide
         as a separate token. It handles both single sequences and lists of sequences.

@@ -172,7 +172,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
             >>> # Tokenize a single sequence
             >>> tokens = tokenizer.tokenize("ATCGATCG")
             >>> print(tokens)  # [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']]
-
+
             >>> # Tokenize multiple sequences
             >>> tokens = tokenizer.tokenize(["ATCGATCG", "GCTAGCTA"])
             >>> print(tokens)  # [['A', 'T', 'C', 'G', ...], ['G', 'C', 'T', 'A', ...]]
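Single-nucleotide tokenization is simpler still: every base becomes its own token, optionally after U/T normalization. A minimal sketch of that behavior (plain Python for illustration, not the library's code; the u2t flag mirrors the attribute documented above):

    def to_nucleotide_tokens(sequence, u2t=False):
        # Optionally normalize RNA 'U' to DNA 'T' before splitting per character.
        if u2t:
            sequence = sequence.replace("U", "T")
        return list(sequence)

    print(to_nucleotide_tokens("ACGUACGU", u2t=True))  # ['A', 'C', 'G', 'T', 'A', 'C', 'G', 'T']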
@@ -191,7 +191,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def encode(self, sequence, **kwargs):
         """
         Converts a sequence into a list of token IDs.
-
+
         This method encodes genomic sequences into token IDs using the
         underlying base tokenizer.

@@ -211,7 +211,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def decode(self, sequence, **kwargs):
         """
         Converts a list of token IDs back into a sequence.
-
+
         This method decodes token IDs back into genomic sequences using
         the underlying base tokenizer.

@@ -231,7 +231,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def encode_plus(self, sequence, **kwargs):
         """
         Encodes a sequence with additional information.
-
+
         This method provides enhanced encoding with additional information
         like attention masks and token type IDs.

omnigenome/src/trainer/accelerate_trainer.py
@@ -21,15 +21,15 @@ from ..misc.utils import env_meta_info, fprint, seed_everything
 def _infer_optimization_direction(metrics, prev_metrics):
     """
     Infer the optimization direction based on metric values.
-
+
     This function analyzes the trend of metric values to determine whether
     larger values are better (e.g., accuracy) or smaller values are better
     (e.g., loss).
-
+
     Args:
         metrics (dict): Current metric values
         prev_metrics (list): Previous metric values
-
+
     Returns:
         str: Either 'larger_is_better' or 'smaller_is_better'
     """
@@ -91,11 +91,11 @@ def _infer_optimization_direction(metrics, prev_metrics):
 class AccelerateTrainer:
     """
     A distributed training trainer using HuggingFace Accelerate.
-
+
     This trainer provides distributed training capabilities with automatic mixed precision,
     gradient accumulation, and early stopping. It supports both single and multi-GPU
     training with seamless integration with HuggingFace Accelerate.
-
+
     Attributes:
         model: The model to train
         train_loader: DataLoader for training data
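A heuristic of the kind `_infer_optimization_direction` documents can be approximated by comparing a metric's trend across epochs: rising values suggest an accuracy-like metric, falling values a loss-like one. The sketch below is a hedged reconstruction from the docstring alone, not the package's actual logic:

    def infer_direction(history):
        # history holds one scalar metric value per epoch, oldest first.
        if len(history) < 2:
            return "larger_is_better"  # no trend yet; assume accuracy-like
        ups = sum(b > a for a, b in zip(history, history[1:]))
        downs = sum(b < a for a, b in zip(history, history[1:]))
        return "larger_is_better" if ups >= downs else "smaller_is_better"

    print(infer_direction([0.61, 0.68, 0.72]))  # larger_is_better (accuracy-like)
    print(infer_direction([1.90, 1.42, 1.13]))  # smaller_is_better (loss-like)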
@@ -110,7 +110,7 @@ class AccelerateTrainer:
         accelerator: HuggingFace Accelerate instance
         metrics: Dictionary to store training metrics
         predictions: Dictionary to store predictions
-
+
     Example:
         >>> from omnigenome.src.trainer import AccelerateTrainer
         >>> trainer = AccelerateTrainer(
@@ -143,7 +143,7 @@ class AccelerateTrainer:
     ):
         """
         Initialize the AccelerateTrainer.
-
+
         Args:
             model: The model to train
             train_dataset (torch.utils.data.Dataset, optional): Training dataset
@@ -293,14 +293,14 @@ class AccelerateTrainer:
     def evaluate(self):
         """
         Evaluate the model on the validation dataset.
-
+
         This method runs the model in evaluation mode and computes metrics
         on the validation dataset. It handles distributed evaluation and
         gathers results from all processes.
-
+
         Returns:
             dict: Dictionary containing evaluation metrics
-
+
         Example:
             >>> metrics = trainer.evaluate()
             >>> print(f"Validation accuracy: {metrics['accuracy']:.4f}")
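The "gathers results from all processes" phrasing refers to the standard Accelerate evaluation pattern: each rank runs its shard of the validation set, then collects everyone's outputs before computing metrics. A generic sketch of that pattern (ordinary Accelerate API usage, assumed rather than copied from this package):

    import torch
    from accelerate import Accelerator

    def distributed_evaluate(model, val_loader):
        # Each process evaluates its shard; results are gathered onto every rank.
        accelerator = Accelerator()
        model, val_loader = accelerator.prepare(model, val_loader)
        model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                logits = model(**batch)["logits"]
                # gather_for_metrics also trims samples duplicated for even sharding
                preds.append(accelerator.gather_for_metrics(logits.argmax(dim=-1)))
                labels.append(accelerator.gather_for_metrics(batch["labels"]))
        return torch.cat(preds), torch.cat(labels)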
@@ -364,14 +364,14 @@ class AccelerateTrainer:
     def test(self):
         """
         Test the model on the test dataset.
-
+
         This method runs the model in evaluation mode and computes metrics
         on the test dataset. It handles distributed testing and gathers
         results from all processes.
-
+
         Returns:
             dict: Dictionary containing test metrics
-
+
         Example:
             >>> metrics = trainer.test()
             >>> print(f"Test accuracy: {metrics['accuracy']:.4f}")
@@ -431,18 +431,18 @@ class AccelerateTrainer:
     def train(self, path_to_save=None, **kwargs):
         """
         Train the model using distributed training.
-
+
         This method performs the complete training loop with validation,
         early stopping, and model checkpointing. It handles distributed
         training across multiple GPUs and processes.
-
+
         Args:
             path_to_save (str, optional): Path to save the trained model
             **kwargs: Additional keyword arguments for model saving
-
+
         Returns:
             dict: Dictionary containing training metrics
-
+
         Example:
             >>> metrics = trainer.train(path_to_save="./checkpoints/model")
             >>> print(f"Best validation accuracy: {metrics['best_valid']['accuracy']:.4f}")
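The next hunk reflows the trainer's loss-computation fallback to Black-style line lengths; behavior is unchanged. Extracted into a standalone helper for clarity, the preserved logic reads as follows (names mirror the diff; the helper itself is illustrative):

    def resolve_loss(model, outputs):
        # Prefer a loss the forward pass already computed.
        if "loss" in outputs:
            return outputs["loss"]
        # Fall back to a loss_function on the model, then on a wrapped inner
        # model (the LoRA case the diff's comment mentions).
        for candidate in (model, getattr(model, "model", None)):
            loss_fn = getattr(candidate, "loss_function", None)
            if callable(loss_fn):
                return loss_fn(outputs["logits"], outputs["labels"])
        raise ValueError("The model does not have a loss function defined.")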
@@ -489,12 +489,20 @@ class AccelerateTrainer:
                 if "loss" not in outputs:
                     # Generally, the model should return a loss in the outputs via OmniGenBench
                     # For the Lora models, the loss is computed separately
-                    if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
-                        loss = self.model.loss_function(outputs['logits'], outputs["labels"])
-                    elif (hasattr(self.model, "model")
-                          and hasattr(self.model.model, "loss_function")
-                          and callable(self.model.model.loss_function)):
-                        loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
+                    if hasattr(self.model, "loss_function") and callable(
+                        self.model.loss_function
+                    ):
+                        loss = self.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
+                    elif (
+                        hasattr(self.model, "model")
+                        and hasattr(self.model.model, "loss_function")
+                        and callable(self.model.model.loss_function)
+                    ):
+                        loss = self.model.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
                     else:
                         raise ValueError(
                             "The model does not have a loss function defined. "
@@ -585,11 +593,11 @@ class AccelerateTrainer:
     def _is_metric_better(self, metrics, stage="valid"):
         """
         Check if the current metrics are better than the best metrics so far.
-
+
         Args:
             metrics (dict): Current metrics
             stage (str): Stage of evaluation ('valid' or 'test')
-
+
         Returns:
             bool: True if current metrics are better, False otherwise
         """
@@ -643,10 +651,10 @@ class AccelerateTrainer:
     def predict(self, data_loader):
         """
         Make predictions using the trained model.
-
+
         Args:
             data_loader: DataLoader containing data to predict on
-
+
         Returns:
             dict: Dictionary containing predictions
         """
@@ -655,10 +663,10 @@ class AccelerateTrainer:
     def get_model(self, **kwargs):
         """
         Get the trained model.
-
+
         Args:
             **kwargs: Additional keyword arguments
-
+
         Returns:
             The trained model
         """
@@ -667,10 +675,10 @@ class AccelerateTrainer:
     def compute_metrics(self):
         """
         Compute metrics for evaluation.
-
+
         This method should be implemented by subclasses to provide specific
         metric computation logic.
-
+
         Raises:
             NotImplementedError: If compute_metrics method is not implemented
         """
@@ -682,7 +690,7 @@ class AccelerateTrainer:
     def save_model(self, path, overwrite=False, **kwargs):
         """
         Save the trained model.
-
+
         Args:
             path (str): Path to save the model
             overwrite (bool, optional): Whether to overwrite existing files. Defaults to False
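The `save_model` docstring documents an overwrite guard. The conventional shape of such a guard, sketched for illustration (not the package's code):

    import os
    import torch

    def save_model_safely(model, path, overwrite=False):
        # Refuse to clobber an existing checkpoint unless explicitly allowed.
        if os.path.exists(path) and not overwrite:
            raise FileExistsError(f"{path} exists; pass overwrite=True to replace it.")
        torch.save(model.state_dict(), path)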
omnigenome/src/trainer/hf_trainer.py
@@ -24,19 +24,19 @@ from ... import __version__ as omnigenome_version
 class HFTrainer(Trainer):
     """
     HuggingFace trainer wrapper for OmniGenome models.
-
+
     This class extends the HuggingFace Trainer to include OmniGenome-specific
     metadata and functionality while maintaining full compatibility with the
     HuggingFace training ecosystem.
-
+
     Attributes:
         metadata: Dictionary containing OmniGenome library information
     """
-
+
     def __init__(self, *args, **kwargs):
         """
         Initialize the HuggingFace trainer wrapper.
-
+
         Args:
             *args: Positional arguments passed to the parent Trainer
             **kwargs: Keyword arguments passed to the parent Trainer
@@ -51,19 +51,19 @@ class HFTrainer(Trainer):
 class HFTrainingArguments(TrainingArguments):
     """
     HuggingFace training arguments wrapper for OmniGenome models.
-
+
     This class extends the HuggingFace TrainingArguments to include
     OmniGenome-specific metadata while maintaining full compatibility
     with the HuggingFace training ecosystem.
-
+
     Attributes:
         metadata: Dictionary containing OmniGenome library information
     """
-
+
     def __init__(self, *args, **kwargs):
         """
         Initialize the HuggingFace training arguments wrapper.
-
+
         Args:
             *args: Positional arguments passed to the parent TrainingArguments
             **kwargs: Keyword arguments passed to the parent TrainingArguments
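Both hunks here are whitespace-only, but the wrapper pattern they document is worth spelling out: thin subclasses that defer everything to the HuggingFace parent and attach library metadata afterwards. A hedged sketch (the real classes use the package's `env_meta_info` helper, which is stubbed here for illustration):

    from transformers import Trainer, TrainingArguments

    def env_meta_info():
        # Stand-in for omnigenome's metadata helper.
        return {"library_name": "OmniGenome", "omnigenome_version": "0.3.1a0"}

    class HFTrainer(Trainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.metadata = env_meta_info()  # OmniGenome-specific metadata

    class HFTrainingArguments(TrainingArguments):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.metadata = env_meta_info()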
omnigenome/src/trainer/trainer.py
@@ -29,14 +29,14 @@ from torch.cuda.amp import GradScaler
 def _infer_optimization_direction(metrics, prev_metrics):
     """
     Infer the optimization direction based on metric names and trends.
-
+
     This function determines whether larger or smaller values are better for
     the given metrics by analyzing metric names and their trends over time.
-
+
     Args:
         metrics (dict): Current metric values
         prev_metrics (list): Previous metric values from multiple epochs
-
+
     Returns:
         str: Either "larger_is_better" or "smaller_is_better"
     """
@@ -98,11 +98,11 @@ def _infer_optimization_direction(metrics, prev_metrics):
 class Trainer:
     """
     Comprehensive trainer for OmniGenome models.
-
+
     This trainer provides a complete training framework with automatic mixed precision,
     early stopping, metric tracking, and model checkpointing. It supports various
     training configurations and can handle different types of genomic sequence tasks.
-
+
     Attributes:
         model: The model to be trained
         train_loader: DataLoader for training data
@@ -118,7 +118,7 @@ class Trainer:
         metrics: Dictionary to store training metrics
         predictions: Dictionary to store model predictions
     """
-
+
     def __init__(
         self,
         model,
@@ -139,7 +139,7 @@ class Trainer:
     ):
         """
         Initialize the trainer.
-
+
         Args:
             model: The model to be trained
             train_dataset: Training dataset
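Two hunks below are pure reflows of this class's mixed-precision plumbing: the device normalization in `__init__` and the `torch.autocast` call in the training loop. The underlying pattern is standard PyTorch; a generic sketch with assumed names (note the module also imports `GradScaler`, which a full AMP loop would use to scale the backward pass):

    import torch

    def normalize_device(device):
        # Accept either a torch.device or a string such as "cuda:0" / "cpu".
        return torch.device(device) if isinstance(device, str) else device

    def training_step(model, batch, optimizer, device, fast_dtype=torch.float16):
        optimizer.zero_grad()
        if fast_dtype:
            # Forward pass in reduced precision on the active device type.
            with torch.autocast(device_type=device.type, dtype=fast_dtype):
                outputs = model(**batch)
        else:
            outputs = model(**batch)
        loss = outputs["loss"]  # or a loss_function fallback, as in the diff
        loss.backward()
        optimizer.step()
        return loss.item()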
@@ -191,7 +191,9 @@ class Trainer:
         )
         self.seed = seed
         self.device = device if device else autocuda.auto_cuda()
-        self.device = torch.device(self.device) if isinstance(self.device, str) else self.device
+        self.device = (
+            torch.device(self.device) if isinstance(self.device, str) else self.device
+        )

         self.fast_dtype = {
             "float32": torch.float32,
@@ -218,11 +220,11 @@ class Trainer:
     def _is_metric_better(self, metrics, stage="valid"):
         """
         Check if the current metrics are better than the best metrics so far.
-
+
         Args:
             metrics (dict): Current metric values
             stage (str): Stage name ("valid" or "test")
-
+
         Returns:
             bool: True if current metrics are better than best metrics
         """
@@ -268,11 +270,11 @@ class Trainer:
     def train(self, path_to_save=None, **kwargs):
         """
         Train the model.
-
+
         Args:
             path_to_save (str, optional): Path to save the best model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Training metrics and results
         """
@@ -300,19 +302,29 @@ class Trainer:
                 self.optimizer.zero_grad()

                 if self.fast_dtype:
-                    with torch.autocast(device_type=self.device.type, dtype=self.fast_dtype):
+                    with torch.autocast(
+                        device_type=self.device.type, dtype=self.fast_dtype
+                    ):
                         outputs = self.model(**batch)
                 else:
                     outputs = self.model(**batch)
                 if "loss" not in outputs:
                     # Generally, the model should return a loss in the outputs via OmniGenBench
                     # For the Lora models, the loss is computed separately
-                    if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
-                        loss = self.model.loss_function(outputs['logits'], outputs["labels"])
-                    elif (hasattr(self.model, "model")
-                          and hasattr(self.model.model, "loss_function")
-                          and callable(self.model.model.loss_function)):
-                        loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
+                    if hasattr(self.model, "loss_function") and callable(
+                        self.model.loss_function
+                    ):
+                        loss = self.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
+                    elif (
+                        hasattr(self.model, "model")
+                        and hasattr(self.model.model, "loss_function")
+                        and callable(self.model.model.loss_function)
+                    ):
+                        loss = self.model.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
                     else:
                         raise ValueError(
                             "The model does not have a loss function defined. "
@@ -480,10 +492,10 @@ class Trainer:
     def get_model(self, **kwargs):
         """
         Get the trained model.
-
+
         Args:
             **kwargs: Additional keyword arguments
-
+
         Returns:
             The trained model
         """
@@ -492,7 +504,7 @@ class Trainer:
     def compute_metrics(self):
         """
         Get the metric computation functions.
-
+
         Returns:
             list: List of metric computation functions
         """
@@ -501,10 +513,10 @@ class Trainer:
     def unwrap_model(self, model=None):
         """
         Unwrap the model from any distributed training wrappers.
-
+
         Args:
             model: Model to unwrap (default: None, uses self.model)
-
+
         Returns:
             The unwrapped model
         """
@@ -538,7 +550,7 @@ class Trainer:
         """
         if os.path.exists(self._model_state_dict_path):
             self.unwrap_model().load_state_dict(
-                torch.load(self._model_state_dict_path, map_location='cpu')
+                torch.load(self._model_state_dict_path, map_location="cpu")
             )
             self.unwrap_model().to(self.device)

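The final trainer hunk only normalizes quote style around map_location="cpu", but the pattern is worth noting: loading a checkpoint onto CPU first and then moving the model to its target device avoids deserializing tensors straight onto a possibly different or memory-constrained GPU. In isolation (illustrative helper, not package code):

    import torch

    def reload_checkpoint(model, path, device):
        # Materialize tensors on CPU first, then move the model to its device.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(device)
        return model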
omnigenome/utility/dataset_hub/__init__.py
@@ -10,4 +10,3 @@
 """
 This package contains modules for the dataset hub.
 """
-