omnigenome-0.3.0a1-py3-none-any.whl → omnigenome-0.3.1a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
@@ -16,16 +16,16 @@ from ..module_utils import OmniPooling
 class OmniModelForTokenClassification(OmniModel):
     """
     Model for token classification tasks in genomics.
-
+
     This model is designed for token-level classification tasks such as
     sequence labeling, where each token in the input sequence needs to be
     classified into different categories. It extends the base OmniModel
     with token-level classification capabilities.
-
+
     The model adds a classification head on top of the base model's hidden
     states and applies softmax to produce probability distributions over
     the label classes for each token.
-
+
     Attributes:
         softmax (torch.nn.Softmax): Softmax layer for probability computation.
         classifier (torch.nn.Linear): Linear classification head.
@@ -57,7 +57,7 @@ class OmniModelForTokenClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for token classification.
-
+
         This method performs the forward pass through the model, computing
         logits for each token in the input sequence and applying softmax
         to produce probability distributions.
@@ -95,13 +95,13 @@ class OmniModelForTokenClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         token-level predictions. It processes the inputs through the model
         and returns the predicted class for each token.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -115,7 +115,7 @@ class OmniModelForTokenClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions'].shape) # (seq_len,)
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -142,12 +142,12 @@ class OmniModelForTokenClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level inference with human-readable output.
-
+
         This method provides processed, human-readable token-level predictions.
         It converts logits to class labels and handles special tokens appropriately.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -200,7 +200,7 @@ class OmniModelForTokenClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for token classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels, ignoring padding tokens.

@@ -221,11 +221,11 @@ class OmniModelForTokenClassification(OmniModel):
 class OmniModelForSequenceClassification(OmniModel):
     """
     Model for sequence classification tasks in genomics.
-
+
     This model is designed for sequence-level classification tasks where
     the entire input sequence is classified into one of several categories.
     It extends the base OmniModel with sequence-level classification capabilities.
-
+
     The model uses a pooling mechanism to aggregate token-level representations
     into a sequence-level representation, which is then classified using a
     linear classifier.
@@ -263,7 +263,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for sequence classification.
-
+
         This method performs the forward pass through the model, computing
         sequence-level logits and applying softmax to produce probability
         distributions over the label classes.
@@ -302,13 +302,13 @@ class OmniModelForSequenceClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         sequence-level predictions. It processes the inputs through the model
         and returns the predicted class for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -322,7 +322,7 @@ class OmniModelForSequenceClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions']) # tensor([0])
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -350,12 +350,12 @@ class OmniModelForSequenceClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level inference with human-readable output.
-
+
         This method provides processed, human-readable sequence-level predictions.
         It converts logits to class labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -403,7 +403,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for sequence classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels.

@@ -421,16 +421,14 @@ class OmniModelForSequenceClassification(OmniModel):
         return loss


-class OmniModelForMultiLabelSequenceClassification(
-    OmniModelForSequenceClassification
-):
+class OmniModelForMultiLabelSequenceClassification(OmniModelForSequenceClassification):
     """
     Model for multi-label sequence classification tasks in genomics.
-
+
     This model is designed for multi-label classification tasks where
     a single sequence can be assigned multiple labels simultaneously.
     It extends the sequence classification model with multi-label capabilities.
-
+
     The model uses sigmoid activation instead of softmax to allow multiple
     labels per sequence and uses binary cross-entropy loss for training.

@@ -461,7 +459,7 @@ class OmniModelForMultiLabelSequenceClassification(
     def loss_function(self, logits, labels):
         """
         Calculates the binary cross-entropy loss for multi-label classification.
-
+
         This method computes the binary cross-entropy loss between the predicted
         probabilities and the ground truth multi-label targets.

@@ -481,13 +479,13 @@ class OmniModelForMultiLabelSequenceClassification(
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         multi-label predictions. It applies a threshold to determine
         which labels are active for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -527,12 +525,12 @@ class OmniModelForMultiLabelSequenceClassification(
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label inference with human-readable output.
-
+
         This method provides processed, human-readable multi-label predictions.
         It converts logits to binary labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -551,9 +549,7 @@ class OmniModelForMultiLabelSequenceClassification(
         return self.predict(sequence_or_inputs, **kwargs)


-class OmniModelForTokenClassificationWith2DStructure(
-    OmniModelForTokenClassification
-):
+class OmniModelForTokenClassificationWith2DStructure(OmniModelForTokenClassification):
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         super().__init__(config_or_model, tokenizer, *args, **kwargs)
         self.metadata["model_name"] = self.__class__.__name__
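The multi-label docstrings in the hunks above contrast sigmoid activation with binary cross-entropy against the softmax and cross-entropy used for single-label classification. The following standalone PyTorch sketch is illustrative only (hypothetical tensors, not code from this package) and shows that distinction:

    import torch

    # Hypothetical shapes: a batch of 2 sequences scored against 4 label classes.
    logits = torch.randn(2, 4)

    # Single-label classification: one class index per sequence, softmax + cross-entropy.
    single_label_targets = torch.tensor([1, 3])
    ce_loss = torch.nn.functional.cross_entropy(logits, single_label_targets)

    # Multi-label classification: several labels may be active per sequence,
    # so each class gets its own sigmoid and binary cross-entropy term.
    multi_label_targets = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                                        [0.0, 1.0, 1.0, 1.0]])
    bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, multi_label_targets)

    # Active labels are recovered by thresholding the per-class probabilities.
    predicted_labels = (torch.sigmoid(logits) > 0.5).int()
    print(ce_loss.item(), bce_loss.item(), predicted_labels)

Thresholding the sigmoid outputs, as the multi-label predict docstring describes, is what turns per-class probabilities into the set of active labels.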
@@ -16,16 +16,16 @@ from omnigenome.src.misc.utils import fprint
 class OmniModelForEmbedding(torch.nn.Module):
     """
     A wrapper class for generating embeddings from pre-trained models.
-
+
     This class provides a unified interface for loading pre-trained models and
     generating embeddings from genomic sequences. It supports various aggregation
     methods and batch processing for efficient embedding generation.
-
+
     Attributes:
         tokenizer: The tokenizer for processing input sequences
         model: The pre-trained model for generating embeddings
        _device: The device (CPU/GPU) where the model is loaded
-
+
     Example:
         >>> from omnigenome import OmniModelForEmbedding
         >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
@@ -34,11 +34,11 @@ class OmniModelForEmbedding(torch.nn.Module):
         >>> print(f"Embeddings shape: {embeddings.shape}")
         torch.Size([2, 768])
     """
-
+
     def __init__(self, model_name_or_path, *args, **kwargs):
         """
         Initialize the embedding model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model to load
             *args: Additional positional arguments passed to AutoModel.from_pretrained
@@ -51,25 +51,25 @@ class OmniModelForEmbedding(torch.nn.Module):
         self.model.to(self._device)
         self.model.eval() # Set model to evaluation mode

-    def batch_encode(self, sequences, batch_size=8, max_length=512, agg=
+    def batch_encode(self, sequences, batch_size=8, max_length=512, agg="head"):
         """
         Encode a list of sequences to their corresponding embeddings.
-
+
         This method processes sequences in batches for memory efficiency and
         supports different aggregation methods for the final embeddings.
-
+
         Args:
             sequences (list of str): List of input sequences to encode
             batch_size (int, optional): Batch size for processing. Defaults to 8
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method for embeddings. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-
+
         Returns:
             torch.Tensor: Embeddings for the input sequences with shape (n_sequences, embedding_dim)
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
             >>> embeddings = model.batch_encode(sequences, batch_size=2, agg='mean')
@@ -79,7 +79,7 @@ class OmniModelForEmbedding(torch.nn.Module):
         embeddings = []

         for i in range(0, len(sequences), batch_size):
-            batch_sequences = sequences[i: i + batch_size]
+            batch_sequences = sequences[i : i + batch_size]
             inputs = self.tokenizer(
                 batch_sequences,
                 return_tensors="pt",
@@ -94,19 +94,19 @@ class OmniModelForEmbedding(torch.nn.Module):

             batch_embeddings = outputs.last_hidden_state.cpu()

-            if agg ==
+            if agg == "head":
                 emb = batch_embeddings[:, 0, :]
-            elif agg ==
+            elif agg == "mean":
                 attention_mask = inputs["attention_mask"].cpu()
                 masked_embeddings = batch_embeddings * attention_mask.unsqueeze(-1)
                 lengths = attention_mask.sum(dim=1).unsqueeze(1)
                 emb = masked_embeddings.sum(dim=1) / lengths
-            elif agg ==
+            elif agg == "tail":
                 attention_mask = inputs["attention_mask"]
                 lengths = attention_mask.sum(dim=1) - 1
-                emb = torch.stack(
-                    batch_embeddings[i, l.item()] for i, l in enumerate(lengths)
-
+                emb = torch.stack(
+                    [batch_embeddings[i, l.item()] for i, l in enumerate(lengths)]
+                )
             else:
                 raise ValueError(f"Unsupported aggregation method: {agg}")

@@ -116,22 +116,22 @@ class OmniModelForEmbedding(torch.nn.Module):
         fprint(f"Generated embeddings for {len(sequences)} sequences.")
         return embeddings

-    def encode(self, sequence, max_length=512, agg=
+    def encode(self, sequence, max_length=512, agg="head", keep_dim=False):
         """
         Encode a single sequence to its corresponding embedding.
-
+
         Args:
             sequence (str): Input sequence to encode
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method. Options are 'head', 'mean', 'tail'. Defaults to 'head'
             keep_dim (bool, optional): Whether to retain the batch dimension. Defaults to False
-
+
         Returns:
             torch.Tensor: Embedding for the input sequence
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequence = "ATCGGCTA"
             >>> embedding = model.encode(sequence, agg='mean')
@@ -152,15 +152,15 @@ class OmniModelForEmbedding(torch.nn.Module):

         last_hidden = outputs.last_hidden_state.cpu()

-        if agg ==
+        if agg == "head":
             emb = last_hidden[0, 0]
-        elif agg ==
+        elif agg == "mean":
             attention_mask = inputs["attention_mask"].cpu()
             masked_embeddings = last_hidden * attention_mask.unsqueeze(-1)
             lengths = attention_mask.sum(dim=1).unsqueeze(1)
             emb = masked_embeddings.sum(dim=1) / lengths
             emb = emb.squeeze(0)
-        elif agg ==
+        elif agg == "tail":
             attention_mask = inputs["attention_mask"]
             lengths = attention_mask.sum(dim=1) - 1
             emb = last_hidden[0, lengths[0].item()]
@@ -172,11 +172,11 @@ class OmniModelForEmbedding(torch.nn.Module):
     def save_embeddings(self, embeddings, output_path):
         """
         Save the generated embeddings to a file.
-
+
         Args:
             embeddings (torch.Tensor): The embeddings to save
             output_path (str): Path to save the embeddings
-
+
         Example:
             >>> embeddings = model.batch_encode(sequences)
             >>> model.save_embeddings(embeddings, "embeddings.pt")
@@ -188,13 +188,13 @@ class OmniModelForEmbedding(torch.nn.Module):
     def load_embeddings(self, embedding_path):
         """
         Load embeddings from a file.
-
+
         Args:
             embedding_path (str): Path to the saved embeddings
-
+
         Returns:
             torch.Tensor: The loaded embeddings
-
+
         Example:
             >>> embeddings = model.load_embeddings("embeddings.pt")
             >>> print(f"Loaded embeddings shape: {embeddings.shape}")
@@ -207,15 +207,15 @@ class OmniModelForEmbedding(torch.nn.Module):
     def compute_similarity(self, embedding1, embedding2, dim=0):
         """
         Compute cosine similarity between two embeddings.
-
+
         Args:
             embedding1 (torch.Tensor): The first embedding
             embedding2 (torch.Tensor): The second embedding
             dim (int, optional): Dimension along which to compute cosine similarity. Defaults to 0
-
+
         Returns:
             float: Cosine similarity score between -1 and 1
-
+
         Example:
             >>> emb1 = model.encode("ATCGGCTA")
             >>> emb2 = model.encode("GGCTAGCTA")
@@ -232,7 +232,7 @@ class OmniModelForEmbedding(torch.nn.Module):
     def device(self):
         """
         Get the current device ('cuda' or 'cpu').
-
+
         Returns:
             torch.device: The device where the model is loaded
         """
@@ -23,26 +23,26 @@ from ...abc.abstract_model import OmniModel
 class OmniModelForMLM(OmniModel):
     """
     Masked Language Model for genomic sequences.
-
+
     This model implements masked language modeling for genomic sequences, where
     tokens are randomly masked and the model learns to predict the original tokens.
     It's useful for pre-training genomic language models and understanding sequence
     patterns and dependencies.
-
+
     Attributes:
         loss_fn: Cross-entropy loss function for masked language modeling
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the MLM model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
             *args: Additional positional arguments
             **kwargs: Additional keyword arguments
-
+
         Raises:
             ValueError: If the model doesn't support masked language modeling
         """
@@ -59,10 +59,10 @@ class OmniModelForMLM(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for masked language modeling.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing loss, logits, and last_hidden_state
         """
@@ -85,11 +85,11 @@ class OmniModelForMLM(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for masked language modeling.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -124,11 +124,11 @@ class OmniModelForMLM(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for masked language modeling, decoding predictions to sequences.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing decoded predictions, logits, and last_hidden_state
         """
@@ -164,11 +164,11 @@ class OmniModelForMLM(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for masked language modeling.
-
+
         Args:
             logits (torch.Tensor): Model predictions [batch_size, seq_len, vocab_size]
             labels (torch.Tensor): Ground truth labels [batch_size, seq_len]
-
+
         Returns:
             torch.Tensor: Computed cross-entropy loss value
         """
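The loss_function docstring above describes cross-entropy between [batch_size, seq_len, vocab_size] logits and [batch_size, seq_len] labels. A generic sketch of how such an MLM loss is commonly computed; the -100 ignore index follows the usual Hugging Face labeling convention and is an assumption, not something stated in this diff:

    import torch

    batch_size, seq_len, vocab_size = 2, 6, 10
    logits = torch.randn(batch_size, seq_len, vocab_size)

    # Labels hold the original token ids at masked positions and -100 everywhere else,
    # so only the masked tokens contribute to the loss.
    labels = torch.full((batch_size, seq_len), -100)
    labels[0, 2] = 7
    labels[1, 4] = 3

    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, vocab_size),  # [batch_size * seq_len, vocab_size]
        labels.view(-1),              # [batch_size * seq_len]
        ignore_index=-100,
    )
    print(loss.item())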
@@ -22,21 +22,21 @@ from transformers.tokenization_utils_base import BatchEncoding
 class OmniPooling(torch.nn.Module):
     """
     A flexible pooling layer for OmniGenome models that handles different input formats.
-
+
     This class provides a unified interface for pooling operations across different
     model architectures, supporting both causal language models and encoder-based models.
     It can handle various input formats including tuples, dictionaries, BatchEncoding
     objects, and tensors.
-
+
     Attributes:
         config: Model configuration object containing architecture and tokenizer settings
         pooler: BertPooler instance for non-causal models, None for causal models
     """
-
+
     def __init__(self, config, *args, **kwargs):
         """
         Initialize the OmniPooling layer.
-
+
         Args:
             config: Model configuration object containing architecture information
             *args: Additional positional arguments
@@ -49,18 +49,18 @@ class OmniPooling(torch.nn.Module):
     def forward(self, inputs, last_hidden_state):
         """
         Perform pooling operation on the last hidden state.
-
+
         This method handles different input formats and applies appropriate pooling:
         - For causal language models: Uses the last non-padded token
         - For encoder models: Uses the BertPooler
-
+
         Args:
             inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
             last_hidden_state (torch.Tensor): Hidden states from the model [batch_size, seq_len, hidden_size]
-
+
         Returns:
             torch.Tensor: Pooled representation [batch_size, hidden_size]
-
+
         Raises:
             ValueError: If input format is not supported or cannot be parsed
         """
@@ -110,9 +110,9 @@ class OmniPooling(torch.nn.Module):
     def _is_causal_lm(self):
         """
         Check if the model is a causal language model.
-
+
         Determines if the model architecture is causal based on the configuration.
-
+
         Returns:
             bool: True if the model is a causal language model, False otherwise
         """
@@ -175,25 +175,25 @@ class InteractingAttention(nn.Module):
 class InteractingAttention(nn.Module):
     """
     An interacting attention mechanism for sequence modeling.
-
+
     This class implements a multi-head attention mechanism with residual connections
     and layer normalization. It's designed for processing sequences where different
     parts of the sequence need to interact with each other.
-
+
     Attributes:
         attention: Multi-head attention layer
         layer_norm: Layer normalization for residual connections
         fc_out: Output projection layer
     """
-
+
     def __init__(self, embed_size, num_heads=24):
         """
         Initialize the InteractingAttention module.
-
+
         Args:
             embed_size (int): Size of the embedding dimension
             num_heads (int): Number of attention heads (default: 24)
-
+
         Raises:
             AssertionError: If embed_size is not divisible by num_heads
         """
@@ -213,12 +213,12 @@ class InteractingAttention(nn.Module):
     def forward(self, query, keys, values):
         """
         Forward pass through the interacting attention mechanism.
-
+
         Args:
             query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
             keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
             values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]
-
+
         Returns:
             torch.Tensor: Output tensor with same shape as query
         """
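OmniPooling's forward docstring says causal language models are pooled from the last non-padded token. A minimal sketch of that pooling rule, independent of the OmniPooling class itself (hypothetical tensors):

    import torch

    # Hypothetical decoder output: 2 sequences, length 5, hidden size 8.
    last_hidden_state = torch.randn(2, 5, 8)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # 3 real tokens
                                   [1, 1, 1, 1, 1]])  # 5 real tokens

    # Index of the last non-padded token in each sequence.
    last_token_idx = attention_mask.sum(dim=1) - 1            # tensor([2, 4])
    batch_idx = torch.arange(last_hidden_state.size(0))
    pooled = last_hidden_state[batch_idx, last_token_idx]     # [batch_size, hidden_size]
    print(pooled.shape)                                       # torch.Size([2, 8])

Encoder-style models, by contrast, are described as going through a BertPooler, which conventionally applies a dense layer and tanh to the first token's hidden state.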