omnigenome-0.3.0a0-py3-none-any.whl → omnigenome-0.3.1a0-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/model/augmentation/model.py

@@ -24,12 +24,12 @@ import autocuda
 class OmniModelForAugmentation(torch.nn.Module):
     """
     Data augmentation model for genomic sequences using masked language modeling.
-
+
     This model uses a pre-trained masked language model to generate augmented
     versions of genomic sequences by randomly masking tokens and predicting
     replacements. It's useful for expanding training datasets and improving
     model generalization.
-
+
     Attributes:
         tokenizer: Tokenizer for processing genomic sequences
         model: Pre-trained masked language model
@@ -38,7 +38,7 @@ class OmniModelForAugmentation(torch.nn.Module):
         max_length: Maximum sequence length for tokenization
         k: Number of augmented instances to generate per sequence
     """
-
+
     def __init__(
         self,
         model_name_or_path=None,
@@ -50,7 +50,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     ):
         """
         Initialize the augmentation model.
-
+
         Args:
             model_name_or_path (str): Path or model name for loading the pre-trained model
             noise_ratio (float): The proportion of tokens to mask in each sequence for augmentation (default: 0.15)
@@ -82,10 +82,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def load_sequences_from_file(self, input_file):
         """
         Load sequences from a JSON file.
-
+
         Args:
             input_file (str): Path to the input JSON file containing sequences
-
+
         Returns:
             list: List of sequences loaded from the file
         """
@@ -98,10 +98,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def apply_noise_to_sequence(self, seq):
         """
         Apply noise to a single sequence by randomly masking tokens.
-
+
         Args:
             seq (str): Input genomic sequence
-
+
         Returns:
             str: Sequence with randomly masked tokens
         """
@@ -114,10 +114,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequence(self, seq):
         """
         Perform augmentation on a single sequence by predicting masked tokens.
-
+
         Args:
             seq (str): Input genomic sequence with masked tokens
-
+
         Returns:
             str: Augmented sequence with predicted tokens replacing masked tokens
         """
@@ -145,11 +145,11 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment(self, seq, k=None):
         """
         Generate multiple augmented instances for a single sequence.
-
+
         Args:
             seq (str): Input genomic sequence
             k (int, optional): Number of augmented instances to generate (default: None, uses self.k)
-
+
         Returns:
             list: List of augmented sequences
         """
@@ -163,10 +163,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequences(self, sequences):
         """
         Augment a list of sequences by applying noise and performing MLM-based predictions.
-
+
         Args:
             sequences (list): List of genomic sequences to augment
-
+
         Returns:
             list: List of all augmented sequences
         """
@@ -179,7 +179,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     def save_augmented_sequences(self, augmented_sequences, output_file):
         """
         Save augmented sequences to a JSON file.
-
+
         Args:
             augmented_sequences (list): List of augmented sequences to save
             output_file (str): Path to the output JSON file
@@ -191,10 +191,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_from_file(self, input_file, output_file):
         """
         Main function to handle the augmentation process from a file input to a file output.
-
+
         This method loads sequences from an input file, augments them using the MLM model,
         and saves the augmented sequences to an output file.
-
+
         Args:
             input_file (str): Path to the input file containing sequences
             output_file (str): Path to the output file where augmented sequences will be saved
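For orientation, a minimal usage sketch of the augmentation API documented in the docstrings above. The import path, checkpoint name, and file paths are illustrative assumptions, not values taken from this diff.

```python
from omnigenome import OmniModelForAugmentation  # import path assumed

# Checkpoint name is illustrative; any MLM-capable genomic checkpoint should work.
aug = OmniModelForAugmentation(
    model_name_or_path="anonymous8/OmniGenome-186M",
    noise_ratio=0.15,  # fraction of tokens masked per sequence (docstring default)
)

# Generate several augmented variants of a single sequence.
variants = aug.augment("ATCGGCTAGCTAGCTA", k=3)

# Or run the end-to-end JSON-file workflow described by augment_from_file().
aug.augment_from_file("sequences.json", "augmented_sequences.json")
```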
omnigenome/src/model/classification/model.py

@@ -16,16 +16,16 @@ from ..module_utils import OmniPooling
 class OmniModelForTokenClassification(OmniModel):
     """
     Model for token classification tasks in genomics.
-
+
     This model is designed for token-level classification tasks such as
     sequence labeling, where each token in the input sequence needs to be
     classified into different categories. It extends the base OmniModel
     with token-level classification capabilities.
-
+
     The model adds a classification head on top of the base model's hidden
     states and applies softmax to produce probability distributions over
     the label classes for each token.
-
+
     Attributes:
         softmax (torch.nn.Softmax): Softmax layer for probability computation.
         classifier (torch.nn.Linear): Linear classification head.
@@ -57,7 +57,7 @@ class OmniModelForTokenClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for token classification.
-
+
         This method performs the forward pass through the model, computing
         logits for each token in the input sequence and applying softmax
         to produce probability distributions.
@@ -95,13 +95,13 @@ class OmniModelForTokenClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         token-level predictions. It processes the inputs through the model
         and returns the predicted class for each token.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -115,7 +115,7 @@ class OmniModelForTokenClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions'].shape) # (seq_len,)
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -142,12 +142,12 @@ class OmniModelForTokenClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level inference with human-readable output.
-
+
         This method provides processed, human-readable token-level predictions.
         It converts logits to class labels and handles special tokens appropriately.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -200,7 +200,7 @@ class OmniModelForTokenClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for token classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels, ignoring padding tokens.

@@ -221,11 +221,11 @@ class OmniModelForTokenClassification(OmniModel):
 class OmniModelForSequenceClassification(OmniModel):
     """
     Model for sequence classification tasks in genomics.
-
+
     This model is designed for sequence-level classification tasks where
     the entire input sequence is classified into one of several categories.
     It extends the base OmniModel with sequence-level classification capabilities.
-
+
     The model uses a pooling mechanism to aggregate token-level representations
     into a sequence-level representation, which is then classified using a
     linear classifier.
@@ -263,7 +263,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for sequence classification.
-
+
         This method performs the forward pass through the model, computing
         sequence-level logits and applying softmax to produce probability
         distributions over the label classes.
@@ -302,13 +302,13 @@ class OmniModelForSequenceClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         sequence-level predictions. It processes the inputs through the model
         and returns the predicted class for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -322,7 +322,7 @@ class OmniModelForSequenceClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions']) # tensor([0])
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -350,12 +350,12 @@ class OmniModelForSequenceClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level inference with human-readable output.
-
+
         This method provides processed, human-readable sequence-level predictions.
         It converts logits to class labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -403,7 +403,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for sequence classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels.

@@ -421,16 +421,14 @@ class OmniModelForSequenceClassification(OmniModel):
         return loss


-class OmniModelForMultiLabelSequenceClassification(
-    OmniModelForSequenceClassification
-):
+class OmniModelForMultiLabelSequenceClassification(OmniModelForSequenceClassification):
     """
     Model for multi-label sequence classification tasks in genomics.
-
+
     This model is designed for multi-label classification tasks where
     a single sequence can be assigned multiple labels simultaneously.
     It extends the sequence classification model with multi-label capabilities.
-
+
     The model uses sigmoid activation instead of softmax to allow multiple
     labels per sequence and uses binary cross-entropy loss for training.

@@ -461,7 +459,7 @@ class OmniModelForMultiLabelSequenceClassification(
     def loss_function(self, logits, labels):
         """
         Calculates the binary cross-entropy loss for multi-label classification.
-
+
         This method computes the binary cross-entropy loss between the predicted
         probabilities and the ground truth multi-label targets.

@@ -481,13 +479,13 @@ class OmniModelForMultiLabelSequenceClassification(
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         multi-label predictions. It applies a threshold to determine
         which labels are active for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -527,12 +525,12 @@ class OmniModelForMultiLabelSequenceClassification(
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label inference with human-readable output.
-
+
         This method provides processed, human-readable multi-label predictions.
         It converts logits to binary labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -551,9 +549,7 @@ class OmniModelForMultiLabelSequenceClassification(
         return self.predict(sequence_or_inputs, **kwargs)


-class OmniModelForTokenClassificationWith2DStructure(
-    OmniModelForTokenClassification
-):
+class OmniModelForTokenClassificationWith2DStructure(OmniModelForTokenClassification):
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         super().__init__(config_or_model, tokenizer, *args, **kwargs)
         self.metadata["model_name"] = self.__class__.__name__
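For orientation, a minimal sketch of the sequence-classification flow described by the docstrings above. The construction step (how `config_or_model` and the tokenizer are obtained) is not shown in this diff, so the loading via `AutoConfig`/`AutoTokenizer`, the label count, and the checkpoint name are assumptions; only the `predict()`/`inference()` calls follow the documented API.

```python
from transformers import AutoConfig, AutoTokenizer
from omnigenome import OmniModelForSequenceClassification  # import path assumed

base = "anonymous8/OmniGenome-186M"  # illustrative checkpoint
config = AutoConfig.from_pretrained(base, num_labels=2)  # assumes the label count is read from the config
tokenizer = AutoTokenizer.from_pretrained(base)
model = OmniModelForSequenceClassification(config, tokenizer)  # (config_or_model, tokenizer) per the diff

# predict() accepts a raw sequence, a list of sequences, or pre-tokenized inputs (per the docstring).
outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
print(outputs["predictions"])  # one class index per sequence, e.g. tensor([0]) for a single input

# inference() returns human-readable labels with confidence scores.
readable = model.inference("ATCGATCG")
```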
omnigenome/src/model/embedding/model.py

@@ -16,16 +16,16 @@ from omnigenome.src.misc.utils import fprint
 class OmniModelForEmbedding(torch.nn.Module):
     """
     A wrapper class for generating embeddings from pre-trained models.
-
+
     This class provides a unified interface for loading pre-trained models and
     generating embeddings from genomic sequences. It supports various aggregation
     methods and batch processing for efficient embedding generation.
-
+
     Attributes:
         tokenizer: The tokenizer for processing input sequences
         model: The pre-trained model for generating embeddings
         _device: The device (CPU/GPU) where the model is loaded
-
+
     Example:
         >>> from omnigenome import OmniModelForEmbedding
         >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
@@ -34,11 +34,11 @@ class OmniModelForEmbedding(torch.nn.Module):
         >>> print(f"Embeddings shape: {embeddings.shape}")
         torch.Size([2, 768])
     """
-
+
     def __init__(self, model_name_or_path, *args, **kwargs):
         """
         Initialize the embedding model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model to load
             *args: Additional positional arguments passed to AutoModel.from_pretrained
@@ -51,25 +51,25 @@ class OmniModelForEmbedding(torch.nn.Module):
         self.model.to(self._device)
         self.model.eval()  # Set model to evaluation mode

-    def batch_encode(self, sequences, batch_size=8, max_length=512, agg=
+    def batch_encode(self, sequences, batch_size=8, max_length=512, agg="head"):
         """
         Encode a list of sequences to their corresponding embeddings.
-
+
         This method processes sequences in batches for memory efficiency and
         supports different aggregation methods for the final embeddings.
-
+
         Args:
             sequences (list of str): List of input sequences to encode
             batch_size (int, optional): Batch size for processing. Defaults to 8
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method for embeddings. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-
+
         Returns:
             torch.Tensor: Embeddings for the input sequences with shape (n_sequences, embedding_dim)
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
            >>> embeddings = model.batch_encode(sequences, batch_size=2, agg='mean')
@@ -79,7 +79,7 @@ class OmniModelForEmbedding(torch.nn.Module):
         embeddings = []

         for i in range(0, len(sequences), batch_size):
-            batch_sequences = sequences[i: i + batch_size]
+            batch_sequences = sequences[i : i + batch_size]
             inputs = self.tokenizer(
                 batch_sequences,
                 return_tensors="pt",
@@ -94,19 +94,19 @@ class OmniModelForEmbedding(torch.nn.Module):

             batch_embeddings = outputs.last_hidden_state.cpu()

-            if agg ==
+            if agg == "head":
                 emb = batch_embeddings[:, 0, :]
-            elif agg ==
+            elif agg == "mean":
                 attention_mask = inputs["attention_mask"].cpu()
                 masked_embeddings = batch_embeddings * attention_mask.unsqueeze(-1)
                 lengths = attention_mask.sum(dim=1).unsqueeze(1)
                 emb = masked_embeddings.sum(dim=1) / lengths
-            elif agg ==
+            elif agg == "tail":
                 attention_mask = inputs["attention_mask"]
                 lengths = attention_mask.sum(dim=1) - 1
-                emb = torch.stack(
-                    batch_embeddings[i, l.item()] for i, l in enumerate(lengths)
-                )
+                emb = torch.stack(
+                    [batch_embeddings[i, l.item()] for i, l in enumerate(lengths)]
+                )
             else:
                 raise ValueError(f"Unsupported aggregation method: {agg}")

@@ -116,22 +116,22 @@ class OmniModelForEmbedding(torch.nn.Module):
         fprint(f"Generated embeddings for {len(sequences)} sequences.")
         return embeddings

-    def encode(self, sequence, max_length=512, agg=
+    def encode(self, sequence, max_length=512, agg="head", keep_dim=False):
         """
         Encode a single sequence to its corresponding embedding.
-
+
         Args:
             sequence (str): Input sequence to encode
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method. Options are 'head', 'mean', 'tail'. Defaults to 'head'
             keep_dim (bool, optional): Whether to retain the batch dimension. Defaults to False
-
+
         Returns:
             torch.Tensor: Embedding for the input sequence
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequence = "ATCGGCTA"
             >>> embedding = model.encode(sequence, agg='mean')
@@ -152,15 +152,15 @@ class OmniModelForEmbedding(torch.nn.Module):

         last_hidden = outputs.last_hidden_state.cpu()

-        if agg ==
+        if agg == "head":
             emb = last_hidden[0, 0]
-        elif agg ==
+        elif agg == "mean":
             attention_mask = inputs["attention_mask"].cpu()
             masked_embeddings = last_hidden * attention_mask.unsqueeze(-1)
             lengths = attention_mask.sum(dim=1).unsqueeze(1)
             emb = masked_embeddings.sum(dim=1) / lengths
             emb = emb.squeeze(0)
-        elif agg ==
+        elif agg == "tail":
             attention_mask = inputs["attention_mask"]
             lengths = attention_mask.sum(dim=1) - 1
             emb = last_hidden[0, lengths[0].item()]
@@ -172,11 +172,11 @@ class OmniModelForEmbedding(torch.nn.Module):
     def save_embeddings(self, embeddings, output_path):
         """
         Save the generated embeddings to a file.
-
+
         Args:
             embeddings (torch.Tensor): The embeddings to save
             output_path (str): Path to save the embeddings
-
+
         Example:
             >>> embeddings = model.batch_encode(sequences)
             >>> model.save_embeddings(embeddings, "embeddings.pt")
@@ -188,13 +188,13 @@ class OmniModelForEmbedding(torch.nn.Module):
     def load_embeddings(self, embedding_path):
         """
         Load embeddings from a file.
-
+
         Args:
             embedding_path (str): Path to the saved embeddings
-
+
         Returns:
             torch.Tensor: The loaded embeddings
-
+
         Example:
             >>> embeddings = model.load_embeddings("embeddings.pt")
             >>> print(f"Loaded embeddings shape: {embeddings.shape}")
@@ -207,15 +207,15 @@ class OmniModelForEmbedding(torch.nn.Module):
     def compute_similarity(self, embedding1, embedding2, dim=0):
         """
         Compute cosine similarity between two embeddings.
-
+
         Args:
             embedding1 (torch.Tensor): The first embedding
             embedding2 (torch.Tensor): The second embedding
             dim (int, optional): Dimension along which to compute cosine similarity. Defaults to 0
-
+
         Returns:
             float: Cosine similarity score between -1 and 1
-
+
         Example:
             >>> emb1 = model.encode("ATCGGCTA")
             >>> emb2 = model.encode("GGCTAGCTA")
@@ -232,7 +232,7 @@ class OmniModelForEmbedding(torch.nn.Module):
     def device(self):
         """
         Get the current device ('cuda' or 'cpu').
-
+
         Returns:
             torch.device: The device where the model is loaded
         """
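For orientation, a minimal usage sketch assembled from the docstring examples shown above; the checkpoint name comes from the class-level Example block, while the specific sequences and output path are illustrative.

```python
from omnigenome import OmniModelForEmbedding

model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")

sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
embeddings = model.batch_encode(sequences, batch_size=2, agg="mean")  # (n_sequences, embedding_dim)

# Single-sequence encoding and cosine similarity, per compute_similarity().
emb1 = model.encode("ATCGGCTA", agg="head")
emb2 = model.encode("GGCTAGCTA", agg="head")
similarity = model.compute_similarity(emb1, emb2)  # score in [-1, 1]

# Persist and reload embeddings.
model.save_embeddings(embeddings, "embeddings.pt")
reloaded = model.load_embeddings("embeddings.pt")
```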
omnigenome/src/model/mlm/model.py

@@ -23,26 +23,26 @@ from ...abc.abstract_model import OmniModel
 class OmniModelForMLM(OmniModel):
     """
     Masked Language Model for genomic sequences.
-
+
     This model implements masked language modeling for genomic sequences, where
     tokens are randomly masked and the model learns to predict the original tokens.
     It's useful for pre-training genomic language models and understanding sequence
     patterns and dependencies.
-
+
     Attributes:
         loss_fn: Cross-entropy loss function for masked language modeling
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the MLM model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
             *args: Additional positional arguments
             **kwargs: Additional keyword arguments
-
+
         Raises:
             ValueError: If the model doesn't support masked language modeling
         """
@@ -59,10 +59,10 @@ class OmniModelForMLM(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for masked language modeling.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing loss, logits, and last_hidden_state
         """
@@ -85,11 +85,11 @@ class OmniModelForMLM(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for masked language modeling.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -124,11 +124,11 @@ class OmniModelForMLM(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for masked language modeling, decoding predictions to sequences.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing decoded predictions, logits, and last_hidden_state
         """
@@ -164,11 +164,11 @@ class OmniModelForMLM(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for masked language modeling.
-
+
         Args:
             logits (torch.Tensor): Model predictions [batch_size, seq_len, vocab_size]
             labels (torch.Tensor): Ground truth labels [batch_size, seq_len]
-
+
         Returns:
             torch.Tensor: Computed cross-entropy loss value
         """
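For orientation, a minimal sketch of the MLM inference flow based on the signatures and return values documented above. The import path, checkpoint name, mask-token handling, and the assumption that `config_or_model` also accepts a checkpoint path are all illustrative, not confirmed by this diff.

```python
from transformers import AutoTokenizer
from omnigenome import OmniModelForMLM  # import path assumed

base = "anonymous8/OmniGenome-186M"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)
model = OmniModelForMLM(base, tokenizer)  # assumes a checkpoint path is accepted for config_or_model

# inference() decodes the model's predictions back into sequences (per the docstring).
masked_seq = "ATCG" + (tokenizer.mask_token or "<mask>") + "TCG"  # mask token depends on the tokenizer
outputs = model.inference(masked_seq)
print(outputs["predictions"])  # decoded sequence(s); exact key name inferred from the docstring wording
```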