omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. omnigenome/__init__.py +16 -8
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +40 -36
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +65 -58
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  63. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  64. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  65. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
  66. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
In the hunks below, paired -/+ lines with no visible content are whitespace-only changes to blank lines inside docstrings.

omnigenome/src/tokenizer/bpe_tokenizer.py

@@ -17,17 +17,17 @@ warnings.filterwarnings("once")
 def is_bpe_tokenization(tokens, threshold=0.1):
     """
     Check if the tokenization is BPE-based by analyzing token characteristics.
-
+
     This function examines the tokens to determine if they follow BPE tokenization
     patterns by analyzing token length distributions and special token patterns.
-
+
     Args:
         tokens (list): List of tokens to analyze
         threshold (float, optional): Threshold for determining BPE tokenization. Defaults to 0.1
-
+
     Returns:
         bool: True if tokens appear to be BPE-based, False otherwise
-
+
     Example:
         >>> tokens = ["▁hello", "▁world", "▁how", "▁are", "▁you"]
        >>> is_bpe = is_bpe_tokenization(tokens)
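The docstring above only names the heuristic (token-length distribution plus special-token patterns compared against a threshold). A minimal illustration of that idea, which is not the package's actual implementation, might look like this:

    def looks_like_bpe(tokens, threshold=0.1):
        # Illustrative only: count tokens carrying the "▁" word-boundary marker
        # and multi-character pieces, then compare the ratio to the threshold.
        if not tokens:
            return False
        marked = sum(1 for t in tokens if t.startswith("▁"))
        multi_char = sum(1 for t in tokens if len(t.lstrip("▁")) > 1)
        ratio = max(marked, multi_char) / len(tokens)
        return ratio >= threshold

    print(looks_like_bpe(["▁hello", "▁world", "▁how", "▁are", "▁you"]))  # True
    print(looks_like_bpe(list("ACGTACGT")))                              # False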
@@ -52,15 +52,15 @@ def is_bpe_tokenization(tokens, threshold=0.1):
 class OmniBPETokenizer(OmniTokenizer):
     """
     A Byte Pair Encoding (BPE) tokenizer for genomic sequences.
-
+
     This tokenizer uses BPE tokenization for genomic sequences and provides
     validation to ensure the base tokenizer is BPE-based. It supports sequence
     preprocessing and handles various input formats.
-
+
     Attributes:
         base_tokenizer: The underlying BPE tokenizer
         metadata: Dictionary containing tokenizer metadata
-
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniBPETokenizer
         >>> from transformers import AutoTokenizer
@@ -75,7 +75,7 @@ class OmniBPETokenizer(OmniTokenizer):
     def __init__(self, base_tokenizer=None, **kwargs):
         """
         Initialize the OmniBPETokenizer.
-
+
         Args:
             base_tokenizer: The base BPE tokenizer
             **kwargs: Additional keyword arguments passed to parent class
@@ -86,21 +86,21 @@ class OmniBPETokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenize a sequence using BPE tokenization.
-
+
         This method processes the input sequence using BPE tokenization,
         handles sequence preprocessing (U/T conversion, whitespace addition),
         and validates that the tokenization is BPE-based.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments including max_length
-
+
         Returns:
             dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
+
         Raises:
             ValueError: If the tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokenized = tokenizer(sequence)
@@ -136,14 +136,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Create a BPE tokenizer from a pre-trained model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             OmniBPETokenizer: Initialized BPE tokenizer
-
+
         Example:
             >>> tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
             >>> print(type(tokenizer))
@@ -159,14 +159,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Tokenize a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of tokens
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokens = tokenizer.tokenize(sequence)
@@ -178,17 +178,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def encode(self, sequence, **kwargs):
         """
         Encode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to encode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of token IDs
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> token_ids = tokenizer.encode(sequence)
@@ -203,17 +203,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def decode(self, sequence, **kwargs):
         """
         Decode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence: Input sequence to decode (can be token IDs or tokens)
             **kwargs: Additional keyword arguments
-
+
         Returns:
             str: Decoded sequence
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> token_ids = [1, 2, 3, 4, 5]
            >>> sequence = tokenizer.decode(token_ids)
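Taken together, the docstrings above describe the intended OmniBPETokenizer workflow. A usage sketch assembled from those examples (the checkpoint name is the one quoted in the docstrings; exact outputs may differ between 0.3.0a1 and 0.3.1a0):

    from transformers import AutoTokenizer
    from omnigenome.src.tokenizer import OmniBPETokenizer

    # Wrap a BPE-based Hugging Face tokenizer, as in the docstring examples.
    base = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    tokenizer = OmniBPETokenizer(base)

    sequence = "ACGUAGGUAUCGUAGA"
    inputs = tokenizer(sequence, max_length=64)   # dict with 'input_ids' and 'attention_mask'
    tokens = tokenizer.tokenize(sequence)         # list of tokens
    ids = tokenizer.encode(sequence)              # list of token IDs
    print(tokenizer.decode(ids))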
omnigenome/src/tokenizer/kmers_tokenizer.py

@@ -16,18 +16,18 @@ warnings.filterwarnings("once")
 class OmniKmersTokenizer(OmniTokenizer):
     """
     A k-mer based tokenizer for genomic sequences.
-
+
     This tokenizer breaks genomic sequences into overlapping k-mers and uses
     a base tokenizer to convert them into token IDs. It supports various
     k-mer sizes and overlap configurations for different genomic applications.
-
+
     Attributes:
         base_tokenizer: The underlying tokenizer for converting k-mers to IDs
         k: Size of k-mers
         overlap: Number of overlapping positions between consecutive k-mers
         max_length: Maximum sequence length for tokenization
         metadata: Dictionary containing tokenizer metadata
-
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniKmersTokenizer
         >>> from transformers import AutoTokenizer
@@ -42,7 +42,7 @@ class OmniKmersTokenizer(OmniTokenizer):
     def __init__(self, base_tokenizer=None, k=3, overlap=0, max_length=512, **kwargs):
         """
         Initialize the OmniKmersTokenizer.
-
+
         Args:
             base_tokenizer: The base tokenizer for converting k-mers to token IDs
             k (int, optional): Size of k-mers. Defaults to 3
@@ -59,18 +59,18 @@ class OmniKmersTokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenize a sequence or list of sequences into tokenized inputs.
-
+
         This method processes the input sequence(s) by first converting them to k-mers,
         then using the base tokenizer to convert k-mers to token IDs. It handles
         sequence preprocessing (U/T conversion) and adds special tokens.
-
+
         Args:
             sequence (str or list): Input sequence(s) to tokenize
             **kwargs: Additional keyword arguments including max_length
-
+
         Returns:
             dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokenized = tokenizer(sequence)
@@ -126,14 +126,14 @@ class OmniKmersTokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Create a k-mers tokenizer from a pre-trained model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             OmniKmersTokenizer: Initialized k-mers tokenizer
-
+
         Example:
             >>> tokenizer = OmniKmersTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
             >>> print(type(tokenizer))
@@ -149,17 +149,17 @@ class OmniKmersTokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Convert sequence(s) into k-mers.
-
+
         This method breaks the input sequence(s) into overlapping k-mers based on
         the configured k-mer size and overlap parameters.
-
+
         Args:
             sequence (str or list): Input sequence(s) to convert to k-mers
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of k-mer lists for each input sequence
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
            >>> k_mers = tokenizer.tokenize(sequence)
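The tokenize() docstring above describes splitting a sequence into overlapping k-mers driven by the k and overlap settings. A standalone sketch of that idea, not the library's code, is:

    def to_kmers(sequence, k=3, overlap=0):
        # Step between k-mer start positions; overlap=0 gives non-overlapping windows.
        step = k - overlap
        return [sequence[i:i + k] for i in range(0, len(sequence) - k + 1, step)]

    print(to_kmers("ACGUAGGUA", k=3, overlap=0))  # ['ACG', 'UAG', 'GUA']
    print(to_kmers("ACGUAGGUA", k=3, overlap=1))  # ['ACG', 'GUA', 'AGG', 'GUA']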
@@ -184,11 +184,11 @@ class OmniKmersTokenizer(OmniTokenizer):
     def encode(self, input_ids, **kwargs):
         """
         Encode input IDs using the base tokenizer.
-
+
         Args:
             input_ids: Input IDs to encode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             Encoded input IDs
         """
@@ -197,11 +197,11 @@ class OmniKmersTokenizer(OmniTokenizer):
     def decode(self, input_ids, **kwargs):
         """
         Decode input IDs using the base tokenizer.
-
+
         Args:
             input_ids: Input IDs to decode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             Decoded sequence
         """
@@ -210,13 +210,13 @@ class OmniKmersTokenizer(OmniTokenizer):
     def encode_plus(self, sequence, **kwargs):
         """
         Encode a sequence with additional information.
-
+
         This method is not yet implemented for k-mers tokenizer.
-
+
         Args:
             sequence: Input sequence
             **kwargs: Additional keyword arguments
-
+
         Raises:
             NotImplementedError: This method is not implemented yet
         """
omnigenome/src/tokenizer/single_nucleotide_tokenizer.py

@@ -19,16 +19,16 @@ warnings.filterwarnings("once")
 class OmniSingleNucleotideTokenizer(OmniTokenizer):
     """
     Tokenizer for single nucleotide tokenization in genomics.
-
+
     This tokenizer converts genomic sequences into individual nucleotide tokens,
     where each nucleotide (A, T, C, G, U) becomes a separate token. It's designed
     for genomic sequence processing where fine-grained nucleotide-level analysis
     is required.
-
+
     The tokenizer supports various preprocessing options including U/T conversion
     and whitespace addition between nucleotides. It also handles special tokens
     like BOS (beginning of sequence) and EOS (end of sequence) tokens.
-
+
     Attributes:
         u2t (bool): Whether to convert 'U' to 'T'.
         t2u (bool): Whether to convert 'T' to 'U'.
@@ -54,7 +54,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenizes sequences using single nucleotide tokenization.
-
+
         This method converts genomic sequences into tokenized inputs suitable
         for model training and inference. It handles sequence preprocessing,
         tokenization, and padding/truncation.
@@ -76,7 +76,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
             >>> # Tokenize a single sequence
             >>> inputs = tokenizer("ATCGATCG")
             >>> print(inputs['input_ids'].shape)  # torch.Size([1, seq_len])
-
+
             >>> # Tokenize multiple sequences
             >>> inputs = tokenizer(["ATCGATCG", "GCTAGCTA"])
             >>> print(inputs['input_ids'].shape)  # torch.Size([2, seq_len])
@@ -134,7 +134,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Loads a single nucleotide tokenizer from a pre-trained model.
-
+
         This method creates a single nucleotide tokenizer wrapper around
         a Hugging Face tokenizer loaded from a pre-trained model.

@@ -156,7 +156,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Converts a sequence into a list of individual nucleotide tokens.
-
+
         This method tokenizes genomic sequences by treating each nucleotide
         as a separate token. It handles both single sequences and lists of sequences.

@@ -172,7 +172,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
             >>> # Tokenize a single sequence
             >>> tokens = tokenizer.tokenize("ATCGATCG")
             >>> print(tokens)  # [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']]
-
+
             >>> # Tokenize multiple sequences
             >>> tokens = tokenizer.tokenize(["ATCGATCG", "GCTAGCTA"])
            >>> print(tokens)  # [['A', 'T', 'C', 'G', ...], ['G', 'C', 'T', 'A', ...]]
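A usage sketch following the docstring examples above; the checkpoint name is borrowed from the BPE example earlier in this diff and is only a placeholder:

    from omnigenome.src.tokenizer import OmniSingleNucleotideTokenizer

    # Placeholder checkpoint; any compatible Hugging Face model path should work here.
    tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    inputs = tokenizer("ATCGATCG")               # torch.Size([1, seq_len])
    batch = tokenizer(["ATCGATCG", "GCTAGCTA"])  # torch.Size([2, seq_len])
    print(inputs["input_ids"].shape, batch["input_ids"].shape)

    print(tokenizer.tokenize("ATCGATCG"))        # [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']]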
@@ -191,7 +191,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def encode(self, sequence, **kwargs):
         """
         Converts a sequence into a list of token IDs.
-
+
         This method encodes genomic sequences into token IDs using the
         underlying base tokenizer.

@@ -211,7 +211,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def decode(self, sequence, **kwargs):
         """
         Converts a list of token IDs back into a sequence.
-
+
         This method decodes token IDs back into genomic sequences using
         the underlying base tokenizer.

@@ -231,7 +231,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
     def encode_plus(self, sequence, **kwargs):
         """
         Encodes a sequence with additional information.
-
+
         This method provides enhanced encoding with additional information
         like attention masks and token type IDs.

omnigenome/src/trainer/accelerate_trainer.py

@@ -21,15 +21,15 @@ from ..misc.utils import env_meta_info, fprint, seed_everything
 def _infer_optimization_direction(metrics, prev_metrics):
     """
     Infer the optimization direction based on metric values.
-
+
     This function analyzes the trend of metric values to determine whether
     larger values are better (e.g., accuracy) or smaller values are better
     (e.g., loss).
-
+
     Args:
         metrics (dict): Current metric values
         prev_metrics (list): Previous metric values
-
+
     Returns:
         str: Either 'larger_is_better' or 'smaller_is_better'
     """
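A minimal sketch of the documented behaviour (watch how a metric moves across evaluations and guess whether larger or smaller values mean improvement). The real helper works on the metrics/prev_metrics structures described above; this standalone version is only an illustration:

    def infer_direction(history):
        """history: values of one metric, oldest first."""
        if len(history) < 2:
            return "larger_is_better"  # default assumption with no trend yet
        rising = sum(b > a for a, b in zip(history, history[1:]))
        falling = sum(b < a for a, b in zip(history, history[1:]))
        return "larger_is_better" if rising >= falling else "smaller_is_better"

    print(infer_direction([0.61, 0.64, 0.70]))  # larger_is_better (accuracy-like)
    print(infer_direction([1.9, 1.2, 0.8]))     # smaller_is_better (loss-like)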
@@ -91,11 +91,11 @@ def _infer_optimization_direction(metrics, prev_metrics):
 class AccelerateTrainer:
     """
     A distributed training trainer using HuggingFace Accelerate.
-
+
     This trainer provides distributed training capabilities with automatic mixed precision,
     gradient accumulation, and early stopping. It supports both single and multi-GPU
     training with seamless integration with HuggingFace Accelerate.
-
+
     Attributes:
         model: The model to train
         train_loader: DataLoader for training data
@@ -110,7 +110,7 @@ class AccelerateTrainer:
         accelerator: HuggingFace Accelerate instance
         metrics: Dictionary to store training metrics
         predictions: Dictionary to store predictions
-
+
     Example:
         >>> from omnigenome.src.trainer import AccelerateTrainer
         >>> trainer = AccelerateTrainer(
@@ -143,7 +143,7 @@ class AccelerateTrainer:
     ):
         """
         Initialize the AccelerateTrainer.
-
+
         Args:
             model: The model to train
             train_dataset (torch.utils.data.Dataset, optional): Training dataset
@@ -293,14 +293,14 @@ class AccelerateTrainer:
     def evaluate(self):
         """
         Evaluate the model on the validation dataset.
-
+
         This method runs the model in evaluation mode and computes metrics
         on the validation dataset. It handles distributed evaluation and
         gathers results from all processes.
-
+
         Returns:
             dict: Dictionary containing evaluation metrics
-
+
         Example:
             >>> metrics = trainer.evaluate()
             >>> print(f"Validation accuracy: {metrics['accuracy']:.4f}")
@@ -364,14 +364,14 @@ class AccelerateTrainer:
     def test(self):
         """
         Test the model on the test dataset.
-
+
         This method runs the model in evaluation mode and computes metrics
         on the test dataset. It handles distributed testing and gathers
         results from all processes.
-
+
         Returns:
             dict: Dictionary containing test metrics
-
+
         Example:
             >>> metrics = trainer.test()
             >>> print(f"Test accuracy: {metrics['accuracy']:.4f}")
@@ -431,18 +431,18 @@ class AccelerateTrainer:
     def train(self, path_to_save=None, **kwargs):
         """
         Train the model using distributed training.
-
+
         This method performs the complete training loop with validation,
         early stopping, and model checkpointing. It handles distributed
         training across multiple GPUs and processes.
-
+
         Args:
             path_to_save (str, optional): Path to save the trained model
             **kwargs: Additional keyword arguments for model saving
-
+
         Returns:
             dict: Dictionary containing training metrics
-
+
         Example:
             >>> metrics = trainer.train(path_to_save="./checkpoints/model")
            >>> print(f"Best validation accuracy: {metrics['best_valid']['accuracy']:.4f}")
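The next hunk is the only code-level change shown for this file: the fallback that computes a loss when the model's forward output has no loss entry is re-wrapped across multiple lines in Black style, with behaviour unchanged. A hedged sketch of a model that this fallback path would work with; only the loss_function attribute name and the (logits, labels) call signature come from the diff, everything else is illustrative:

    import torch
    from torch import nn

    class ToyClassifier(nn.Module):
        """Illustrative model whose forward output carries no 'loss' key."""

        def __init__(self, hidden=32, num_labels=2):
            super().__init__()
            self.classifier = nn.Linear(hidden, num_labels)

        def loss_function(self, logits, labels):
            # The trainer calls model.loss_function(outputs["logits"], outputs["labels"]).
            return nn.functional.cross_entropy(logits, labels)

        def forward(self, features, labels=None):
            logits = self.classifier(features)
            return {"logits": logits, "labels": labels}  # no "loss" -> trainer uses the fallback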
@@ -489,12 +489,20 @@ class AccelerateTrainer:
                 if "loss" not in outputs:
                     # Generally, the model should return a loss in the outputs via OmniGenBench
                     # For the Lora models, the loss is computed separately
-                    if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
-                        loss = self.model.loss_function(outputs['logits'], outputs["labels"])
-                    elif (hasattr(self.model, "model")
-                          and hasattr(self.model.model, "loss_function")
-                          and callable(self.model.model.loss_function)):
-                        loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
+                    if hasattr(self.model, "loss_function") and callable(
+                        self.model.loss_function
+                    ):
+                        loss = self.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
+                    elif (
+                        hasattr(self.model, "model")
+                        and hasattr(self.model.model, "loss_function")
+                        and callable(self.model.model.loss_function)
+                    ):
+                        loss = self.model.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
                     else:
                         raise ValueError(
                             "The model does not have a loss function defined. "
@@ -585,11 +593,11 @@ class AccelerateTrainer:
     def _is_metric_better(self, metrics, stage="valid"):
         """
         Check if the current metrics are better than the best metrics so far.
-
+
         Args:
             metrics (dict): Current metrics
             stage (str): Stage of evaluation ('valid' or 'test')
-
+
         Returns:
             bool: True if current metrics are better, False otherwise
         """
@@ -643,10 +651,10 @@ class AccelerateTrainer:
     def predict(self, data_loader):
         """
         Make predictions using the trained model.
-
+
         Args:
             data_loader: DataLoader containing data to predict on
-
+
         Returns:
             dict: Dictionary containing predictions
         """
@@ -655,10 +663,10 @@ class AccelerateTrainer:
     def get_model(self, **kwargs):
         """
         Get the trained model.
-
+
         Args:
             **kwargs: Additional keyword arguments
-
+
         Returns:
             The trained model
         """
@@ -667,10 +675,10 @@ class AccelerateTrainer:
     def compute_metrics(self):
         """
         Compute metrics for evaluation.
-
+
         This method should be implemented by subclasses to provide specific
         metric computation logic.
-
+
         Raises:
             NotImplementedError: If compute_metrics method is not implemented
         """
@@ -682,7 +690,7 @@ class AccelerateTrainer:
     def save_model(self, path, overwrite=False, **kwargs):
         """
         Save the trained model.
-
+
         Args:
             path (str): Path to save the model
             overwrite (bool, optional): Whether to overwrite existing files. Defaults to False
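Pulling the AccelerateTrainer docstrings together, the intended flow looks roughly like the following doctest-style outline; keyword names not shown in the hunks above (eval_dataset, epochs) are assumptions:

    >>> from omnigenome.src.trainer import AccelerateTrainer
    >>> trainer = AccelerateTrainer(
    ...     model=model,                  # a model returning a loss or exposing loss_function
    ...     train_dataset=train_set,
    ...     eval_dataset=valid_set,       # keyword name assumed
    ...     epochs=3,                     # keyword name assumed
    ... )
    >>> metrics = trainer.train(path_to_save="./checkpoints/model")
    >>> print(f"Best validation accuracy: {metrics['best_valid']['accuracy']:.4f}")
    >>> trainer.evaluate()                # validation metrics
    >>> trainer.test()                    # test metrics
    >>> trainer.save_model("./checkpoints/model", overwrite=True)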
omnigenome/src/trainer/hf_trainer.py

@@ -24,19 +24,19 @@ from ... import __version__ as omnigenome_version
 class HFTrainer(Trainer):
     """
     HuggingFace trainer wrapper for OmniGenome models.
-
+
     This class extends the HuggingFace Trainer to include OmniGenome-specific
     metadata and functionality while maintaining full compatibility with the
     HuggingFace training ecosystem.
-
+
     Attributes:
         metadata: Dictionary containing OmniGenome library information
     """
-
+
     def __init__(self, *args, **kwargs):
         """
         Initialize the HuggingFace trainer wrapper.
-
+
         Args:
             *args: Positional arguments passed to the parent Trainer
             **kwargs: Keyword arguments passed to the parent Trainer
@@ -51,19 +51,19 @@ class HFTrainer(Trainer):
 class HFTrainingArguments(TrainingArguments):
     """
     HuggingFace training arguments wrapper for OmniGenome models.
-
+
     This class extends the HuggingFace TrainingArguments to include
     OmniGenome-specific metadata while maintaining full compatibility
     with the HuggingFace training ecosystem.
-
+
     Attributes:
         metadata: Dictionary containing OmniGenome library information
     """
-
+
     def __init__(self, *args, **kwargs):
         """
         Initialize the HuggingFace training arguments wrapper.
-
+
         Args:
             *args: Positional arguments passed to the parent TrainingArguments
            **kwargs: Keyword arguments passed to the parent TrainingArguments
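Since both wrappers only add OmniGenome metadata on top of the Hugging Face classes they extend, typical use mirrors the standard Trainer API. A sketch under those assumptions (the import path follows the file layout, and model/train_set are placeholders):

    from omnigenome.src.trainer.hf_trainer import HFTrainer, HFTrainingArguments

    args = HFTrainingArguments(
        output_dir="./hf_runs",
        num_train_epochs=3,
        per_device_train_batch_size=8,
    )
    trainer = HFTrainer(model=model, args=args, train_dataset=train_set)
    trainer.train()
    print(trainer.metadata)   # OmniGenome library information attached by the wrapper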