omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/tokenizer/kmers_tokenizer.py

The hunks in this file are whitespace-only: trailing spaces are stripped from blank lines inside docstrings. Representative hunk:

@@ -16,18 +16,18 @@ warnings.filterwarnings("once")
 class OmniKmersTokenizer(OmniTokenizer):
     """
     A k-mer based tokenizer for genomic sequences.
-    
+
     This tokenizer breaks genomic sequences into overlapping k-mers and uses
     a base tokenizer to convert them into token IDs. It supports various
     k-mer sizes and overlap configurations for different genomic applications.
-    
+
     Attributes:
         base_tokenizer: The underlying tokenizer for converting k-mers to IDs
         k: Size of k-mers
         overlap: Number of overlapping positions between consecutive k-mers
         max_length: Maximum sequence length for tokenization
         metadata: Dictionary containing tokenizer metadata
-    
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniKmersTokenizer
         >>> from transformers import AutoTokenizer

The remaining hunks (@@ -42,7 +42,7 @@, @@ -59,18 +59,18 @@, @@ -126,14 +126,14 @@, @@ -149,17 +149,17 @@, @@ -184,11 +184,11 @@, @@ -197,11 +197,11 @@ and @@ -210,13 +210,13 @@) make the same change in the docstrings of __init__ (base_tokenizer, k=3, overlap=0, max_length=512), __call__ (splits the input into k-mers, applies U/T conversion, adds special tokens, and returns a dict with 'input_ids' and 'attention_mask'), from_pretrained, tokenize (returns a list of k-mer lists, one per input sequence), encode and decode (which delegate to the base tokenizer), and encode_plus (not yet implemented; it raises NotImplementedError).
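For readers unfamiliar with the k-mer scheme those docstrings describe, the sketch below illustrates overlapping k-mer splitting. It is not code from the package: the function name is invented, and the assumption that the step size equals k minus overlap is inferred from the documented k and overlap attributes.

def split_into_kmers(sequence: str, k: int = 3, overlap: int = 0) -> list:
    """Illustrative k-mer splitter (assumed step = k - overlap)."""
    step = max(k - overlap, 1)
    return [sequence[i:i + k] for i in range(0, len(sequence) - k + 1, step)]

# Non-overlapping 3-mers:
# split_into_kmers("ACGUAGGUA")          -> ['ACG', 'UAG', 'GUA']
# Overlapping 3-mers (sliding window):
# split_into_kmers("ACGUAG", overlap=2)  -> ['ACG', 'CGU', 'GUA', 'UAG']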
omnigenome/src/tokenizer/single_nucleotide_tokenizer.py

The hunks here are likewise whitespace-only. Representative hunk:

@@ -19,16 +19,16 @@ warnings.filterwarnings("once")
 class OmniSingleNucleotideTokenizer(OmniTokenizer):
     """
     Tokenizer for single nucleotide tokenization in genomics.
-    
+
     This tokenizer converts genomic sequences into individual nucleotide tokens,
     where each nucleotide (A, T, C, G, U) becomes a separate token. It's designed
     for genomic sequence processing where fine-grained nucleotide-level analysis
     is required.
-    
+
     The tokenizer supports various preprocessing options including U/T conversion
     and whitespace addition between nucleotides. It also handles special tokens
     like BOS (beginning of sequence) and EOS (end of sequence) tokens.
-    
+
     Attributes:
         u2t (bool): Whether to convert 'U' to 'T'.
         t2u (bool): Whether to convert 'T' to 'U'.

The remaining hunks (@@ -54,7 +54,7 @@, @@ -76,7 +76,7 @@, @@ -134,7 +134,7 @@, @@ -156,7 +156,7 @@, @@ -172,7 +172,7 @@, @@ -191,7 +191,7 @@, @@ -211,7 +211,7 @@ and @@ -231,7 +231,7 @@) make the same change in the docstrings of __call__ (tokenizes single sequences or lists of sequences with padding/truncation), from_pretrained (wraps a Hugging Face tokenizer loaded from a pre-trained model), tokenize (e.g. "ATCGATCG" -> [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']]), encode, decode, and encode_plus (adds attention masks and token type IDs).
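As a companion to those docstrings, here is a minimal sketch of single-nucleotide tokenization with the documented U/T conversion options; the function is hypothetical and does not reproduce the package's actual implementation.

def to_nucleotide_tokens(sequence: str, u2t: bool = False, t2u: bool = False) -> list:
    """Illustrative only: one token per nucleotide, with optional U<->T conversion."""
    if u2t:
        sequence = sequence.replace("U", "T")
    if t2u:
        sequence = sequence.replace("T", "U")
    return list(sequence)

# to_nucleotide_tokens("ATCGATCG")        -> ['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']
# to_nucleotide_tokens("ACGU", u2t=True)  -> ['A', 'C', 'G', 'T']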
omnigenome/src/trainer/accelerate_trainer.py

Most hunks are the same whitespace-only docstring cleanup, touching _infer_optimization_direction (returns 'larger_is_better' or 'smaller_is_better' based on metric trends), the AccelerateTrainer class docstring (distributed training via HuggingFace Accelerate with automatic mixed precision, gradient accumulation and early stopping), and the docstrings of __init__, evaluate, test, train, _is_metric_better, predict, get_model, compute_metrics and save_model (@@ -21,15 +21,15 @@, @@ -91,11 +91,11 @@, @@ -110,7 +110,7 @@, @@ -143,7 +143,7 @@, @@ -293,14 +293,14 @@, @@ -364,14 +364,14 @@, @@ -431,18 +431,18 @@, @@ -585,11 +593,11 @@, @@ -643,10 +651,10 @@, @@ -655,10 +663,10 @@, @@ -667,10 +675,10 @@, @@ -682,7 +690,7 @@).

The one functional change adds a fallback so that wrapped models (e.g. LoRA models, whose loss is computed separately) can still provide a loss function:

@@ -489,12 +489,20 @@
                 if "loss" not in outputs:
                     # Generally, the model should return a loss in the outputs via OmniGenBench
                     # For the Lora models, the loss is computed separately
-                    if hasattr(self.model, "loss_function") and callable(
-
-
-
-
-
+                    if hasattr(self.model, "loss_function") and callable(
+                        self.model.loss_function
+                    ):
+                        loss = self.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
+                    elif (
+                        hasattr(self.model, "model")
+                        and hasattr(self.model.model, "loss_function")
+                        and callable(self.model.model.loss_function)
+                    ):
+                        loss = self.model.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
                 else:
                     raise ValueError(
                         "The model does not have a loss function defined. "

(The bodies of the removed lines 493-497 are not visible in the source view.)
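Read in isolation, the pattern added by the @@ -489,12 +489,20 @@ hunk looks like the sketch below. This is a paraphrase for illustration, not the package's code; the standalone function name is invented.

import torch

def resolve_loss(model: torch.nn.Module, outputs: dict) -> torch.Tensor:
    """Sketch of the fallback: prefer model.loss_function, then fall back to the
    wrapped base model's loss_function (e.g. a LoRA wrapper exposing .model)."""
    if hasattr(model, "loss_function") and callable(model.loss_function):
        return model.loss_function(outputs["logits"], outputs["labels"])
    if (
        hasattr(model, "model")
        and hasattr(model.model, "loss_function")
        and callable(model.model.loss_function)
    ):
        return model.model.loss_function(outputs["logits"], outputs["labels"])
    raise ValueError("The model does not have a loss function defined.")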
omnigenome/src/trainer/hf_trainer.py

Two whitespace-only hunks (@@ -24,19 +24,19 @@ and @@ -51,19 +51,19 @@) clean up the docstrings of HFTrainer and HFTrainingArguments, the thin wrappers that extend the HuggingFace Trainer and TrainingArguments with OmniGenome metadata while remaining fully compatible with the HuggingFace training ecosystem; both __init__ methods simply forward *args and **kwargs to their parent classes.
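The wrapper pattern those docstrings describe looks roughly like the sketch below; the class name and the metadata keys are assumptions for illustration, not the package's actual fields.

from transformers import Trainer

class MetadataTrainer(Trainer):
    """Illustrative wrapper: defer everything to transformers.Trainer and
    attach library metadata to the instance (field names are assumed)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metadata = {"library_name": "omnigenome", "library_version": "0.3.1a0"}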
omnigenome/src/trainer/trainer.py

The whitespace-only hunks (@@ -29,14 +29,14 @@, @@ -98,11 +98,11 @@, @@ -118,7 +118,7 @@, @@ -139,7 +139,7 @@, @@ -218,11 +220,11 @@, @@ -268,11 +270,11 @@, @@ -480,10 +492,10 @@, @@ -492,7 +504,7 @@ and @@ -501,10 +513,10 @@) clean up the docstrings of _infer_optimization_direction, the Trainer class (a complete training framework with automatic mixed precision, early stopping, metric tracking and checkpointing), __init__, _is_metric_better, train, get_model, compute_metrics and unwrap_model.

Three hunks change code. The device passed to the trainer is normalized to a torch.device:

@@ -191,7 +191,9 @@
         )
         self.seed = seed
         self.device = device if device else autocuda.auto_cuda()
-        self.device =
+        self.device = (
+            torch.device(self.device) if isinstance(self.device, str) else self.device
+        )

         self.fast_dtype = {
             "float32": torch.float32,

(The tail of the removed line is not visible in the source view.)

The forward pass is wrapped in torch.autocast when a fast dtype is configured, and the same LoRA loss-function fallback as in accelerate_trainer.py is added:

@@ -300,19 +302,29 @@
                 self.optimizer.zero_grad()

                 if self.fast_dtype:
-                    with torch.autocast(
+                    with torch.autocast(
+                        device_type=self.device.type, dtype=self.fast_dtype
+                    ):
                         outputs = self.model(**batch)
                 else:
                     outputs = self.model(**batch)
                 if "loss" not in outputs:
                     # Generally, the model should return a loss in the outputs via OmniGenBench
                     # For the Lora models, the loss is computed separately
-                    if hasattr(self.model, "loss_function") and callable(
-
-
-
-
-
+                    if hasattr(self.model, "loss_function") and callable(
+                        self.model.loss_function
+                    ):
+                        loss = self.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
+                    elif (
+                        hasattr(self.model, "model")
+                        and hasattr(self.model.model, "loss_function")
+                        and callable(self.model.model.loss_function)
+                    ):
+                        loss = self.model.model.loss_function(
+                            outputs["logits"], outputs["labels"]
+                        )
                 else:
                     raise ValueError(
                         "The model does not have a loss function defined. "

Finally, saved model weights are reloaded onto the CPU before being moved back to the training device:

@@ -538,7 +550,7 @@
         """
         if os.path.exists(self._model_state_dict_path):
             self.unwrap_model().load_state_dict(
-                torch.load(self._model_state_dict_path, map_location=
+                torch.load(self._model_state_dict_path, map_location="cpu")
             )
             self.unwrap_model().to(self.device)
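Taken together, these three hunks normalize a string device to torch.device (so its .type can be passed to torch.autocast), run the forward pass under autocast when a fast dtype is set, and load checkpoints onto the CPU before moving the model back to the training device. Below is a standalone sketch of the same pattern, with placeholder model, batch and path arguments; it is illustrative only and is not the package's Trainer code.

import torch

def forward_with_amp(model, batch, device="cuda:0", fast_dtype=torch.float16):
    # Normalize a string device so device.type is available for torch.autocast.
    device = torch.device(device) if isinstance(device, str) else device
    model.to(device)
    batch = {k: v.to(device) for k, v in batch.items()}
    if fast_dtype:
        with torch.autocast(device_type=device.type, dtype=fast_dtype):
            return model(**batch)
    return model(**batch)

def reload_best_checkpoint(model, state_dict_path, device):
    # Load weights on CPU first (mirrors map_location="cpu"), then move to device.
    model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
    model.to(device)
    return model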