omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/model/module_utils.py

@@ -22,21 +22,21 @@ from transformers.tokenization_utils_base import BatchEncoding
 class OmniPooling(torch.nn.Module):
     """
     A flexible pooling layer for OmniGenome models that handles different input formats.
-
+
     This class provides a unified interface for pooling operations across different
     model architectures, supporting both causal language models and encoder-based models.
     It can handle various input formats including tuples, dictionaries, BatchEncoding
     objects, and tensors.
-
+
     Attributes:
         config: Model configuration object containing architecture and tokenizer settings
         pooler: BertPooler instance for non-causal models, None for causal models
     """
-
+
     def __init__(self, config, *args, **kwargs):
         """
         Initialize the OmniPooling layer.
-
+
         Args:
             config: Model configuration object containing architecture information
             *args: Additional positional arguments
@@ -49,18 +49,18 @@ class OmniPooling(torch.nn.Module):
     def forward(self, inputs, last_hidden_state):
         """
         Perform pooling operation on the last hidden state.
-
+
         This method handles different input formats and applies appropriate pooling:
         - For causal language models: Uses the last non-padded token
         - For encoder models: Uses the BertPooler
-
+
         Args:
             inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
             last_hidden_state (torch.Tensor): Hidden states from the model [batch_size, seq_len, hidden_size]
-
+
         Returns:
             torch.Tensor: Pooled representation [batch_size, hidden_size]
-
+
         Raises:
             ValueError: If input format is not supported or cannot be parsed
         """
@@ -110,9 +110,9 @@ class OmniPooling(torch.nn.Module):
     def _is_causal_lm(self):
         """
         Check if the model is a causal language model.
-
+
         Determines if the model architecture is causal based on the configuration.
-
+
         Returns:
             bool: True if the model is a causal language model, False otherwise
         """
@@ -175,25 +175,25 @@ class OmniPooling(torch.nn.Module):
 class InteractingAttention(nn.Module):
     """
     An interacting attention mechanism for sequence modeling.
-
+
     This class implements a multi-head attention mechanism with residual connections
     and layer normalization. It's designed for processing sequences where different
     parts of the sequence need to interact with each other.
-
+
     Attributes:
         attention: Multi-head attention layer
         layer_norm: Layer normalization for residual connections
         fc_out: Output projection layer
     """
-
+
     def __init__(self, embed_size, num_heads=24):
         """
         Initialize the InteractingAttention module.
-
+
         Args:
             embed_size (int): Size of the embedding dimension
             num_heads (int): Number of attention heads (default: 24)
-
+
         Raises:
             AssertionError: If embed_size is not divisible by num_heads
         """
@@ -213,12 +213,12 @@ class InteractingAttention(nn.Module):
     def forward(self, query, keys, values):
        """
         Forward pass through the interacting attention mechanism.
-
+
         Args:
             query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
             keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
             values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]
-
+
         Returns:
             torch.Tensor: Output tensor with same shape as query
         """
omnigenome/src/model/regression/model.py

@@ -23,21 +23,21 @@ from ..module_utils import OmniPooling
 class OmniModelForTokenRegression(OmniModel):
     """
     Token-level regression model for genomic sequences.
-
+
     This model performs regression at the token level, predicting continuous values
     for each token in the input sequence. It's useful for tasks like predicting
     binding affinities, expression levels, or other continuous properties at each
     position in a genomic sequence.
-
+
     Attributes:
         classifier: Linear layer for regression output
         loss_fn: Mean squared error loss function
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the token regression model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -55,10 +55,10 @@ class OmniModelForTokenRegression(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for token-level regression.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
@@ -77,11 +77,11 @@ class OmniModelForTokenRegression(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for token-level regression.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -109,11 +109,11 @@ class OmniModelForTokenRegression(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for token-level regression, excluding special tokens.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -148,11 +148,11 @@ class OmniModelForTokenRegression(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for token-level regression.
-
+
         Args:
             logits (torch.Tensor): Model predictions
             labels (torch.Tensor): Ground truth labels
-
+
         Returns:
             torch.Tensor: Computed loss value
         """
@@ -173,22 +173,22 @@ class OmniModelForTokenRegression(OmniModel):
 class OmniModelForSequenceRegression(OmniModel):
     """
     Sequence-level regression model for genomic sequences.
-
+
     This model performs regression at the sequence level, predicting a single
     continuous value for the entire input sequence. It's useful for tasks like
     predicting overall expression levels, binding affinities, or other sequence-level
     properties.
-
+
     Attributes:
         pooler: OmniPooling layer for sequence-level representation
         classifier: Linear layer for regression output
         loss_fn: Mean squared error loss function
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the sequence regression model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -207,10 +207,10 @@ class OmniModelForSequenceRegression(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for sequence-level regression.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
@@ -230,11 +230,11 @@ class OmniModelForSequenceRegression(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for sequence-level regression.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -262,11 +262,11 @@ class OmniModelForSequenceRegression(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for sequence-level regression.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -297,11 +297,11 @@ class OmniModelForSequenceRegression(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for sequence-level regression.
-
+
         Args:
             logits (torch.Tensor): Model predictions
             labels (torch.Tensor): Ground truth labels
-
+
         Returns:
             torch.Tensor: Computed loss value
         """
@@ -322,20 +322,20 @@ class OmniModelForSequenceRegression(OmniModel):
 class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
     """
     Structural imputation model for genomic sequences.
-
+
     This model is specialized for imputing missing structural information in
     genomic sequences. It extends the sequence regression model with additional
     embedding capabilities for structural features.
-
+
     Attributes:
         embedding: Embedding layer for structural features
         loss_fn: Mean squared error loss function
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the structural imputation model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -351,10 +351,10 @@ class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
     def forward(self, **inputs):
         """
         Forward pass for structural imputation.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, and labels
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
@@ -372,21 +372,19 @@ class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
         return outputs
 
 
-class OmniModelForTokenRegressionWith2DStructure(
-    OmniModelForTokenRegression
-):
+class OmniModelForTokenRegressionWith2DStructure(OmniModelForTokenRegression):
     """
     Token-level regression model with 2D structural information.
-
+
     This model extends the basic token regression model to incorporate
     2D structural information, useful for RNA structure prediction
     and other structural genomics tasks.
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the 2D structure-aware token regression model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -399,10 +397,10 @@ class OmniModelForTokenRegressionWith2DStructure(
     def forward(self, **inputs):
         """
         Forward pass for 2D structure-aware token regression.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
@@ -419,21 +417,19 @@ class OmniModelForTokenRegressionWith2DStructure(
         return outputs
 
 
-class OmniModelForSequenceRegressionWith2DStructure(
-    OmniModelForSequenceRegression
-):
+class OmniModelForSequenceRegressionWith2DStructure(OmniModelForSequenceRegression):
     """
     Sequence-level regression model with 2D structural information.
-
+
     This model extends the basic sequence regression model to incorporate
     2D structural information, useful for RNA structure prediction
     and other structural genomics tasks.
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the 2D structure-aware sequence regression model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -446,10 +442,10 @@ class OmniModelForSequenceRegressionWith2DStructure(
     def forward(self, **inputs):
         """
         Forward pass for 2D structure-aware sequence regression.
-
+
         Args:
             **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
@@ -470,21 +466,21 @@ class OmniModelForSequenceRegressionWith2DStructure(
 class OmniModelForMatrixRegression(OmniModel):
     """
     Matrix regression model for genomic sequences.
-
+
     This model performs regression on matrix representations of genomic sequences,
     useful for tasks like contact map prediction, structure prediction, or other
     matrix-based genomic analysis tasks.
-
+
     Attributes:
         resnet: ResNet backbone for processing matrix inputs
         classifier: Linear layer for regression output
         loss_fn: Mean squared error loss function
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the matrix regression model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -501,22 +497,22 @@ class OmniModelForMatrixRegression(OmniModel):
     def forward(self, **inputs):
         """
         Forward pass for matrix regression.
-
+
         Args:
             **inputs: Input tensors including matrix representations and labels
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
         labels = inputs.pop("labels", None)
         matrix_inputs = inputs.pop("matrix_inputs", None)
-
+
         if matrix_inputs is None:
             raise ValueError("matrix_inputs is required for matrix regression")
-
+
         outputs = self.resnet(matrix_inputs)
         logits = self.classifier(outputs)
-
+
         outputs = {
             "logits": logits,
             "last_hidden_state": outputs,
@@ -527,11 +523,11 @@ class OmniModelForMatrixRegression(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for matrix regression.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -559,11 +555,11 @@ class OmniModelForMatrixRegression(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for matrix regression.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -594,11 +590,11 @@ class OmniModelForMatrixRegression(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for matrix regression.
-
+
         Args:
             logits (torch.Tensor): Model predictions
             labels (torch.Tensor): Ground truth labels
-
+
         Returns:
             torch.Tensor: Computed loss value
         """
@@ -619,21 +615,21 @@ class OmniModelForMatrixRegression(OmniModel):
 class OmniModelForMatrixClassification(OmniModel):
     """
     Matrix classification model for genomic sequences.
-
+
     This model performs classification on matrix representations of genomic sequences,
     useful for tasks like structure classification, contact map classification, or other
     matrix-based genomic analysis tasks.
-
+
     Attributes:
         resnet: ResNet backbone for processing matrix inputs
         classifier: Linear layer for classification output
         loss_fn: Cross-entropy loss function
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the matrix classification model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
@@ -650,26 +646,25 @@ class OmniModelForMatrixClassification(OmniModel):
         self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)
         self.model_info()
 
-
     def forward(self, **inputs):
         """
         Forward pass for matrix classification.
-
+
         Args:
             **inputs: Input tensors including matrix representations and labels
-
+
         Returns:
             dict: Dictionary containing logits, last_hidden_state, and labels
         """
         labels = inputs.pop("labels", None)
         matrix_inputs = inputs.pop("matrix_inputs", None)
-
+
         if matrix_inputs is None:
             raise ValueError("matrix_inputs is required for matrix classification")
-
+
         outputs = self.resnet(matrix_inputs)
         logits = self.classifier(outputs)
-
+
         outputs = {
             "logits": logits,
             "last_hidden_state": outputs,
@@ -680,11 +675,11 @@ class OmniModelForMatrixClassification(OmniModel):
     def predict(self, sequence_or_inputs, **kwargs):
         """
         Generate predictions for matrix classification.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -713,11 +708,11 @@ class OmniModelForMatrixClassification(OmniModel):
     def inference(self, sequence_or_inputs, **kwargs):
         """
         Perform inference for matrix classification.
-
+
         Args:
             sequence_or_inputs: Input sequences or pre-processed inputs
             **kwargs: Additional keyword arguments
-
+
         Returns:
             dict: Dictionary containing predictions, logits, and last_hidden_state
         """
@@ -756,11 +751,11 @@ class OmniModelForMatrixClassification(OmniModel):
     def loss_function(self, logits, labels):
         """
         Compute the loss for matrix classification.
-
+
         Args:
             logits (torch.Tensor): Model predictions
             labels (torch.Tensor): Ground truth labels
-
+
         Returns:
             torch.Tensor: Computed loss value
         """