omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/model/augmentation/model.py

@@ -24,12 +24,12 @@ import autocuda
 class OmniModelForAugmentation(torch.nn.Module):
     """
     Data augmentation model for genomic sequences using masked language modeling.
-
+
     This model uses a pre-trained masked language model to generate augmented
     versions of genomic sequences by randomly masking tokens and predicting
     replacements. It's useful for expanding training datasets and improving
     model generalization.
-
+
     Attributes:
         tokenizer: Tokenizer for processing genomic sequences
         model: Pre-trained masked language model
@@ -38,7 +38,7 @@ class OmniModelForAugmentation(torch.nn.Module):
         max_length: Maximum sequence length for tokenization
         k: Number of augmented instances to generate per sequence
     """
-
+
     def __init__(
         self,
         model_name_or_path=None,
@@ -50,7 +50,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     ):
         """
         Initialize the augmentation model.
-
+
         Args:
             model_name_or_path (str): Path or model name for loading the pre-trained model
            noise_ratio (float): The proportion of tokens to mask in each sequence for augmentation (default: 0.15)
@@ -82,10 +82,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def load_sequences_from_file(self, input_file):
         """
         Load sequences from a JSON file.
-
+
         Args:
             input_file (str): Path to the input JSON file containing sequences
-
+
         Returns:
             list: List of sequences loaded from the file
         """
@@ -98,10 +98,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def apply_noise_to_sequence(self, seq):
         """
         Apply noise to a single sequence by randomly masking tokens.
-
+
         Args:
             seq (str): Input genomic sequence
-
+
         Returns:
             str: Sequence with randomly masked tokens
         """
@@ -114,10 +114,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequence(self, seq):
         """
         Perform augmentation on a single sequence by predicting masked tokens.
-
+
         Args:
             seq (str): Input genomic sequence with masked tokens
-
+
         Returns:
             str: Augmented sequence with predicted tokens replacing masked tokens
         """
@@ -145,11 +145,11 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment(self, seq, k=None):
         """
         Generate multiple augmented instances for a single sequence.
-
+
         Args:
             seq (str): Input genomic sequence
             k (int, optional): Number of augmented instances to generate (default: None, uses self.k)
-
+
         Returns:
             list: List of augmented sequences
         """
@@ -163,10 +163,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequences(self, sequences):
         """
         Augment a list of sequences by applying noise and performing MLM-based predictions.
-
+
         Args:
             sequences (list): List of genomic sequences to augment
-
+
         Returns:
             list: List of all augmented sequences
         """
@@ -179,7 +179,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     def save_augmented_sequences(self, augmented_sequences, output_file):
         """
         Save augmented sequences to a JSON file.
-
+
         Args:
             augmented_sequences (list): List of augmented sequences to save
             output_file (str): Path to the output JSON file
@@ -191,10 +191,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_from_file(self, input_file, output_file):
         """
         Main function to handle the augmentation process from a file input to a file output.
-
+
         This method loads sequences from an input file, augments them using the MLM model,
         and saves the augmented sequences to an output file.
-
+
         Args:
             input_file (str): Path to the input file containing sequences
             output_file (str): Path to the output file where augmented sequences will be saved
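
Note: the hunks in this file are formatting-only, but they spell out the public surface of OmniModelForAugmentation. Below is a minimal usage sketch based solely on the docstrings above; the checkpoint name, file paths, and keyword arguments are illustrative assumptions, and the class is assumed to be re-exported from the package root the way OmniModelForEmbedding is.

# Hedged sketch; not taken from the package's own examples.
from omnigenome import OmniModelForAugmentation  # assumed top-level export

aug = OmniModelForAugmentation(
    model_name_or_path="anonymous8/OmniGenome-186M",  # illustrative checkpoint
    noise_ratio=0.15,  # fraction of tokens masked per sequence (docstring default)
    k=3,               # augmented variants generated per input sequence
)

# In-memory augmentation of a single sequence.
variants = aug.augment("ATCGGCTAATCGGCTA")
print(len(variants))  # expect k variants

# File-to-file pipeline: loads sequences from a JSON file, augments them,
# and writes the augmented sequences out (the JSON schema is not shown in this diff).
aug.augment_from_file("sequences.json", "augmented.json")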
omnigenome/src/model/classification/__init__.py

@@ -9,4 +9,3 @@
 """
 This package contains modules for classification models.
 """
-
omnigenome/src/model/classification/model.py

@@ -16,16 +16,16 @@ from ..module_utils import OmniPooling
 class OmniModelForTokenClassification(OmniModel):
     """
     Model for token classification tasks in genomics.
-
+
     This model is designed for token-level classification tasks such as
     sequence labeling, where each token in the input sequence needs to be
     classified into different categories. It extends the base OmniModel
     with token-level classification capabilities.
-
+
     The model adds a classification head on top of the base model's hidden
     states and applies softmax to produce probability distributions over
     the label classes for each token.
-
+
     Attributes:
         softmax (torch.nn.Softmax): Softmax layer for probability computation.
         classifier (torch.nn.Linear): Linear classification head.
@@ -57,7 +57,7 @@ class OmniModelForTokenClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for token classification.
-
+
         This method performs the forward pass through the model, computing
         logits for each token in the input sequence and applying softmax
         to produce probability distributions.
@@ -95,13 +95,13 @@ class OmniModelForTokenClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         token-level predictions. It processes the inputs through the model
         and returns the predicted class for each token.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -115,7 +115,7 @@ class OmniModelForTokenClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions'].shape)  # (seq_len,)
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -142,12 +142,12 @@ class OmniModelForTokenClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs token-level inference with human-readable output.
-
+
         This method provides processed, human-readable token-level predictions.
         It converts logits to class labels and handles special tokens appropriately.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -200,7 +200,7 @@ class OmniModelForTokenClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for token classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels, ignoring padding tokens.

@@ -221,11 +221,11 @@ class OmniModelForTokenClassification(OmniModel):
 class OmniModelForSequenceClassification(OmniModel):
     """
     Model for sequence classification tasks in genomics.
-
+
     This model is designed for sequence-level classification tasks where
     the entire input sequence is classified into one of several categories.
     It extends the base OmniModel with sequence-level classification capabilities.
-
+
     The model uses a pooling mechanism to aggregate token-level representations
     into a sequence-level representation, which is then classified using a
     linear classifier.
@@ -263,7 +263,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for sequence classification.
-
+
         This method performs the forward pass through the model, computing
         sequence-level logits and applying softmax to produce probability
         distributions over the label classes.
@@ -302,13 +302,13 @@ class OmniModelForSequenceClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         sequence-level predictions. It processes the inputs through the model
         and returns the predicted class for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -322,7 +322,7 @@ class OmniModelForSequenceClassification(OmniModel):
             >>> # Predict on a single sequence
             >>> outputs = model.predict("ATCGATCG")
             >>> print(outputs['predictions'])  # tensor([0])
-
+
             >>> # Predict on multiple sequences
             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
         """
@@ -350,12 +350,12 @@ class OmniModelForSequenceClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs sequence-level inference with human-readable output.
-
+
         This method provides processed, human-readable sequence-level predictions.
         It converts logits to class labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -403,7 +403,7 @@ class OmniModelForSequenceClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Calculates the cross-entropy loss for sequence classification.
-
+
         This method computes the cross-entropy loss between the predicted
         logits and the ground truth labels.

@@ -421,16 +421,14 @@ class OmniModelForSequenceClassification(OmniModel):
         return loss


-class OmniModelForMultiLabelSequenceClassification(
-    OmniModelForSequenceClassification
-):
+class OmniModelForMultiLabelSequenceClassification(OmniModelForSequenceClassification):
     """
     Model for multi-label sequence classification tasks in genomics.
-
+
     This model is designed for multi-label classification tasks where
     a single sequence can be assigned multiple labels simultaneously.
     It extends the sequence classification model with multi-label capabilities.
-
+
     The model uses sigmoid activation instead of softmax to allow multiple
     labels per sequence and uses binary cross-entropy loss for training.

@@ -461,7 +459,7 @@ class OmniModelForMultiLabelSequenceClassification(
     def loss_function(self, logits, labels):
         """
         Calculates the binary cross-entropy loss for multi-label classification.
-
+
         This method computes the binary cross-entropy loss between the predicted
         probabilities and the ground truth multi-label targets.

@@ -481,13 +479,13 @@ class OmniModelForMultiLabelSequenceClassification(
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label prediction on raw inputs.
-
+
         This method takes raw sequences or tokenized inputs and returns
         multi-label predictions. It applies a threshold to determine
         which labels are active for each sequence.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -527,12 +525,12 @@ class OmniModelForMultiLabelSequenceClassification(
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Performs multi-label inference with human-readable output.
-
+
         This method provides processed, human-readable multi-label predictions.
         It converts logits to binary labels and provides confidence scores.

         Args:
-            sequence_or_inputs: A sequence (str), list of sequences, or 
+            sequence_or_inputs: A sequence (str), list of sequences, or
                 tokenized inputs (dict/tuple).
             **kwargs: Additional arguments for tokenization and inference.

@@ -551,9 +549,7 @@ class OmniModelForMultiLabelSequenceClassification(
         return self.predict(sequence_or_inputs, **kwargs)


-class OmniModelForTokenClassificationWith2DStructure(
-    OmniModelForTokenClassification
-):
+class OmniModelForTokenClassificationWith2DStructure(OmniModelForTokenClassification):
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         super().__init__(config_or_model, tokenizer, *args, **kwargs)
         self.metadata["model_name"] = self.__class__.__name__
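
Note: the classification hunks only normalize whitespace and collapse two class declarations onto single lines, but the docstrings they touch describe the predict/inference API. Below is a rough sketch of how that API might be driven; how the label set is supplied at construction time is not shown in this diff, so the config handling is an assumption.

# Hedged sketch assuming a HF-style checkpoint and top-level re-exports.
from transformers import AutoConfig, AutoTokenizer

from omnigenome import OmniModelForSequenceClassification  # assumed export

base = "anonymous8/OmniGenome-186M"  # illustrative checkpoint
config = AutoConfig.from_pretrained(base, num_labels=2, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)

# The docstrings above give the constructor as (config_or_model, tokenizer, ...).
model = OmniModelForSequenceClassification(config, tokenizer)

# predict() accepts a raw sequence or a list of sequences and returns a dict
# with a 'predictions' entry, per the docstring examples.
outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
print(outputs["predictions"])

# inference() returns human-readable labels with confidence scores.
print(model.inference("ATCGATCG"))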
omnigenome/src/model/embedding/__init__.py

@@ -9,4 +9,3 @@
 """
 This package contains modules for embedding models.
 """
-
omnigenome/src/model/embedding/model.py

@@ -16,16 +16,16 @@ from omnigenome.src.misc.utils import fprint
 class OmniModelForEmbedding(torch.nn.Module):
     """
     A wrapper class for generating embeddings from pre-trained models.
-
+
     This class provides a unified interface for loading pre-trained models and
     generating embeddings from genomic sequences. It supports various aggregation
     methods and batch processing for efficient embedding generation.
-
+
     Attributes:
         tokenizer: The tokenizer for processing input sequences
         model: The pre-trained model for generating embeddings
         _device: The device (CPU/GPU) where the model is loaded
-
+
     Example:
         >>> from omnigenome import OmniModelForEmbedding
         >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
@@ -34,11 +34,11 @@ class OmniModelForEmbedding(torch.nn.Module):
         >>> print(f"Embeddings shape: {embeddings.shape}")
         torch.Size([2, 768])
     """
-
+
     def __init__(self, model_name_or_path, *args, **kwargs):
         """
         Initialize the embedding model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model to load
             *args: Additional positional arguments passed to AutoModel.from_pretrained
@@ -51,25 +51,25 @@ class OmniModelForEmbedding(torch.nn.Module):
         self.model.to(self._device)
         self.model.eval()  # Set model to evaluation mode

-    def batch_encode(self, sequences, batch_size=8, max_length=512, agg='head'):
+    def batch_encode(self, sequences, batch_size=8, max_length=512, agg="head"):
         """
         Encode a list of sequences to their corresponding embeddings.
-
+
         This method processes sequences in batches for memory efficiency and
         supports different aggregation methods for the final embeddings.
-
+
         Args:
             sequences (list of str): List of input sequences to encode
             batch_size (int, optional): Batch size for processing. Defaults to 8
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method for embeddings. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-
+
         Returns:
             torch.Tensor: Embeddings for the input sequences with shape (n_sequences, embedding_dim)
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
             >>> embeddings = model.batch_encode(sequences, batch_size=2, agg='mean')
@@ -79,7 +79,7 @@ class OmniModelForEmbedding(torch.nn.Module):
         embeddings = []

         for i in range(0, len(sequences), batch_size):
-            batch_sequences = sequences[i: i + batch_size]
+            batch_sequences = sequences[i : i + batch_size]
             inputs = self.tokenizer(
                 batch_sequences,
                 return_tensors="pt",
@@ -94,19 +94,19 @@ class OmniModelForEmbedding(torch.nn.Module):

             batch_embeddings = outputs.last_hidden_state.cpu()

-            if agg == 'head':
+            if agg == "head":
                 emb = batch_embeddings[:, 0, :]
-            elif agg == 'mean':
+            elif agg == "mean":
                 attention_mask = inputs["attention_mask"].cpu()
                 masked_embeddings = batch_embeddings * attention_mask.unsqueeze(-1)
                 lengths = attention_mask.sum(dim=1).unsqueeze(1)
                 emb = masked_embeddings.sum(dim=1) / lengths
-            elif agg == 'tail':
+            elif agg == "tail":
                 attention_mask = inputs["attention_mask"]
                 lengths = attention_mask.sum(dim=1) - 1
-                emb = torch.stack([
-                    batch_embeddings[i, l.item()] for i, l in enumerate(lengths)
-                ])
+                emb = torch.stack(
+                    [batch_embeddings[i, l.item()] for i, l in enumerate(lengths)]
+                )
             else:
                 raise ValueError(f"Unsupported aggregation method: {agg}")

@@ -116,22 +116,22 @@ class OmniModelForEmbedding(torch.nn.Module):
         fprint(f"Generated embeddings for {len(sequences)} sequences.")
         return embeddings

-    def encode(self, sequence, max_length=512, agg='head', keep_dim=False):
+    def encode(self, sequence, max_length=512, agg="head", keep_dim=False):
         """
         Encode a single sequence to its corresponding embedding.
-
+
         Args:
             sequence (str): Input sequence to encode
             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
             agg (str, optional): Aggregation method. Options are 'head', 'mean', 'tail'. Defaults to 'head'
             keep_dim (bool, optional): Whether to retain the batch dimension. Defaults to False
-
+
         Returns:
             torch.Tensor: Embedding for the input sequence
-
+
         Raises:
             ValueError: If unsupported aggregation method is provided
-
+
         Example:
             >>> sequence = "ATCGGCTA"
             >>> embedding = model.encode(sequence, agg='mean')
@@ -152,15 +152,15 @@ class OmniModelForEmbedding(torch.nn.Module):

         last_hidden = outputs.last_hidden_state.cpu()

-        if agg == 'head':
+        if agg == "head":
             emb = last_hidden[0, 0]
-        elif agg == 'mean':
+        elif agg == "mean":
             attention_mask = inputs["attention_mask"].cpu()
             masked_embeddings = last_hidden * attention_mask.unsqueeze(-1)
             lengths = attention_mask.sum(dim=1).unsqueeze(1)
             emb = masked_embeddings.sum(dim=1) / lengths
             emb = emb.squeeze(0)
-        elif agg == 'tail':
+        elif agg == "tail":
             attention_mask = inputs["attention_mask"]
             lengths = attention_mask.sum(dim=1) - 1
             emb = last_hidden[0, lengths[0].item()]
@@ -172,11 +172,11 @@ class OmniModelForEmbedding(torch.nn.Module):
     def save_embeddings(self, embeddings, output_path):
         """
         Save the generated embeddings to a file.
-
+
         Args:
             embeddings (torch.Tensor): The embeddings to save
             output_path (str): Path to save the embeddings
-
+
         Example:
             >>> embeddings = model.batch_encode(sequences)
             >>> model.save_embeddings(embeddings, "embeddings.pt")
@@ -188,13 +188,13 @@ class OmniModelForEmbedding(torch.nn.Module):
     def load_embeddings(self, embedding_path):
         """
         Load embeddings from a file.
-
+
         Args:
             embedding_path (str): Path to the saved embeddings
-
+
         Returns:
             torch.Tensor: The loaded embeddings
-
+
         Example:
             >>> embeddings = model.load_embeddings("embeddings.pt")
             >>> print(f"Loaded embeddings shape: {embeddings.shape}")
@@ -207,15 +207,15 @@ class OmniModelForEmbedding(torch.nn.Module):
     def compute_similarity(self, embedding1, embedding2, dim=0):
         """
         Compute cosine similarity between two embeddings.
-
+
         Args:
             embedding1 (torch.Tensor): The first embedding
             embedding2 (torch.Tensor): The second embedding
             dim (int, optional): Dimension along which to compute cosine similarity. Defaults to 0
-
+
         Returns:
             float: Cosine similarity score between -1 and 1
-
+
         Example:
             >>> emb1 = model.encode("ATCGGCTA")
             >>> emb2 = model.encode("GGCTAGCTA")
@@ -232,7 +232,7 @@ class OmniModelForEmbedding(torch.nn.Module):
     def device(self):
         """
         Get the current device ('cuda' or 'cpu').
-
+
         Returns:
             torch.device: The device where the model is loaded
         """
omnigenome/src/model/mlm/__init__.py

@@ -9,4 +9,3 @@
 """
 This package contains modules for Masked Language Models (MLM).
 """
-
omnigenome/src/model/mlm/model.py

@@ -23,26 +23,26 @@ from ...abc.abstract_model import OmniModel
 class OmniModelForMLM(OmniModel):
     """
     Masked Language Model for genomic sequences.
-
+
     This model implements masked language modeling for genomic sequences, where
     tokens are randomly masked and the model learns to predict the original tokens.
     It's useful for pre-training genomic language models and understanding sequence
     patterns and dependencies.
-
+
     Attributes:
         loss_fn: Cross-entropy loss function for masked language modeling
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the MLM model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
             *args: Additional positional arguments
             **kwargs: Additional keyword arguments
-
+
         Raises:
             ValueError: If the model doesn't support masked language modeling
         """
@@ -59,10 +59,10 @@ class OmniModelForMLM(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for masked language modeling.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing loss, logits, and last_hidden_state
         """
@@ -85,11 +85,11 @@ class OmniModelForMLM(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for masked language modeling.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -124,11 +124,11 @@ class OmniModelForMLM(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for masked language modeling, decoding predictions to sequences.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing decoded predictions, logits, and last_hidden_state
         """
@@ -164,11 +164,11 @@ class OmniModelForMLM(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for masked language modeling.
-
+
         Args:
             logits (torch.Tensor): Model predictions [batch_size, seq_len, vocab_size]
             labels (torch.Tensor): Ground truth labels [batch_size, seq_len]
-
+
         Returns:
             torch.Tensor: Computed cross-entropy loss value
         """