omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
@@ -22,21 +22,21 @@ from transformers.tokenization_utils_base import BatchEncoding
22
22
  class OmniPooling(torch.nn.Module):
23
23
  """
24
24
  A flexible pooling layer for OmniGenome models that handles different input formats.
25
-
25
+
26
26
  This class provides a unified interface for pooling operations across different
27
27
  model architectures, supporting both causal language models and encoder-based models.
28
28
  It can handle various input formats including tuples, dictionaries, BatchEncoding
29
29
  objects, and tensors.
30
-
30
+
31
31
  Attributes:
32
32
  config: Model configuration object containing architecture and tokenizer settings
33
33
  pooler: BertPooler instance for non-causal models, None for causal models
34
34
  """
35
-
35
+
36
36
  def __init__(self, config, *args, **kwargs):
37
37
  """
38
38
  Initialize the OmniPooling layer.
39
-
39
+
40
40
  Args:
41
41
  config: Model configuration object containing architecture information
42
42
  *args: Additional positional arguments
@@ -49,18 +49,18 @@ class OmniPooling(torch.nn.Module):
49
49
  def forward(self, inputs, last_hidden_state):
50
50
  """
51
51
  Perform pooling operation on the last hidden state.
52
-
52
+
53
53
  This method handles different input formats and applies appropriate pooling:
54
54
  - For causal language models: Uses the last non-padded token
55
55
  - For encoder models: Uses the BertPooler
56
-
56
+
57
57
  Args:
58
58
  inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
59
59
  last_hidden_state (torch.Tensor): Hidden states from the model [batch_size, seq_len, hidden_size]
60
-
60
+
61
61
  Returns:
62
62
  torch.Tensor: Pooled representation [batch_size, hidden_size]
63
-
63
+
64
64
  Raises:
65
65
  ValueError: If input format is not supported or cannot be parsed
66
66
  """
@@ -110,9 +110,9 @@ class OmniPooling(torch.nn.Module):
110
110
  def _is_causal_lm(self):
111
111
  """
112
112
  Check if the model is a causal language model.
113
-
113
+
114
114
  Determines if the model architecture is causal based on the configuration.
115
-
115
+
116
116
  Returns:
117
117
  bool: True if the model is a causal language model, False otherwise
118
118
  """
@@ -175,25 +175,25 @@ class OmniPooling(torch.nn.Module):
175
175
  class InteractingAttention(nn.Module):
176
176
  """
177
177
  An interacting attention mechanism for sequence modeling.
178
-
178
+
179
179
  This class implements a multi-head attention mechanism with residual connections
180
180
  and layer normalization. It's designed for processing sequences where different
181
181
  parts of the sequence need to interact with each other.
182
-
182
+
183
183
  Attributes:
184
184
  attention: Multi-head attention layer
185
185
  layer_norm: Layer normalization for residual connections
186
186
  fc_out: Output projection layer
187
187
  """
188
-
188
+
189
189
  def __init__(self, embed_size, num_heads=24):
190
190
  """
191
191
  Initialize the InteractingAttention module.
192
-
192
+
193
193
  Args:
194
194
  embed_size (int): Size of the embedding dimension
195
195
  num_heads (int): Number of attention heads (default: 24)
196
-
196
+
197
197
  Raises:
198
198
  AssertionError: If embed_size is not divisible by num_heads
199
199
  """
@@ -213,12 +213,12 @@ class InteractingAttention(nn.Module):
213
213
  def forward(self, query, keys, values):
214
214
  """
215
215
  Forward pass through the interacting attention mechanism.
216
-
216
+
217
217
  Args:
218
218
  query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
219
219
  keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
220
220
  values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]
221
-
221
+
222
222
  Returns:
223
223
  torch.Tensor: Output tensor with same shape as query
224
224
  """
@@ -9,4 +9,3 @@
9
9
  """
10
10
  This package contains modules for regression models.
11
11
  """
12
-
@@ -23,21 +23,21 @@ from ..module_utils import OmniPooling
23
23
  class OmniModelForTokenRegression(OmniModel):
24
24
  """
25
25
  Token-level regression model for genomic sequences.
26
-
26
+
27
27
  This model performs regression at the token level, predicting continuous values
28
28
  for each token in the input sequence. It's useful for tasks like predicting
29
29
  binding affinities, expression levels, or other continuous properties at each
30
30
  position in a genomic sequence.
31
-
31
+
32
32
  Attributes:
33
33
  classifier: Linear layer for regression output
34
34
  loss_fn: Mean squared error loss function
35
35
  """
36
-
36
+
37
37
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
38
38
  """
39
39
  Initialize the token regression model.
40
-
40
+
41
41
  Args:
42
42
  config_or_model: Model configuration or pre-trained model
43
43
  tokenizer: Tokenizer for processing input sequences
@@ -55,10 +55,10 @@ class OmniModelForTokenRegression(OmniModel):
55
55
  def forward(self, **inputs):
56
56
  """
57
57
  Forward pass for token-level regression.
58
-
58
+
59
59
  Args:
60
60
  **inputs: Input tensors including input_ids, attention_mask, and labels
61
-
61
+
62
62
  Returns:
63
63
  dict: Dictionary containing logits, last_hidden_state, and labels
64
64
  """
@@ -77,11 +77,11 @@ class OmniModelForTokenRegression(OmniModel):
77
77
  def predict(self, sequence_or_inputs, **kwargs):
78
78
  """
79
79
  Generate predictions for token-level regression.
80
-
80
+
81
81
  Args:
82
82
  sequence_or_inputs: Input sequences or pre-processed inputs
83
83
  **kwargs: Additional keyword arguments
84
-
84
+
85
85
  Returns:
86
86
  dict: Dictionary containing predictions, logits, and last_hidden_state
87
87
  """
@@ -109,11 +109,11 @@ class OmniModelForTokenRegression(OmniModel):
109
109
  def inference(self, sequence_or_inputs, **kwargs):
110
110
  """
111
111
  Perform inference for token-level regression, excluding special tokens.
112
-
112
+
113
113
  Args:
114
114
  sequence_or_inputs: Input sequences or pre-processed inputs
115
115
  **kwargs: Additional keyword arguments
116
-
116
+
117
117
  Returns:
118
118
  dict: Dictionary containing predictions, logits, and last_hidden_state
119
119
  """
@@ -148,11 +148,11 @@ class OmniModelForTokenRegression(OmniModel):
148
148
  def loss_function(self, logits, labels):
149
149
  """
150
150
  Compute the loss for token-level regression.
151
-
151
+
152
152
  Args:
153
153
  logits (torch.Tensor): Model predictions
154
154
  labels (torch.Tensor): Ground truth labels
155
-
155
+
156
156
  Returns:
157
157
  torch.Tensor: Computed loss value
158
158
  """
@@ -173,22 +173,22 @@ class OmniModelForTokenRegression(OmniModel):
173
173
  class OmniModelForSequenceRegression(OmniModel):
174
174
  """
175
175
  Sequence-level regression model for genomic sequences.
176
-
176
+
177
177
  This model performs regression at the sequence level, predicting a single
178
178
  continuous value for the entire input sequence. It's useful for tasks like
179
179
  predicting overall expression levels, binding affinities, or other sequence-level
180
180
  properties.
181
-
181
+
182
182
  Attributes:
183
183
  pooler: OmniPooling layer for sequence-level representation
184
184
  classifier: Linear layer for regression output
185
185
  loss_fn: Mean squared error loss function
186
186
  """
187
-
187
+
188
188
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
189
189
  """
190
190
  Initialize the sequence regression model.
191
-
191
+
192
192
  Args:
193
193
  config_or_model: Model configuration or pre-trained model
194
194
  tokenizer: Tokenizer for processing input sequences
@@ -207,10 +207,10 @@ class OmniModelForSequenceRegression(OmniModel):
207
207
  def forward(self, **inputs):
208
208
  """
209
209
  Forward pass for sequence-level regression.
210
-
210
+
211
211
  Args:
212
212
  **inputs: Input tensors including input_ids, attention_mask, and labels
213
-
213
+
214
214
  Returns:
215
215
  dict: Dictionary containing logits, last_hidden_state, and labels
216
216
  """
@@ -230,11 +230,11 @@ class OmniModelForSequenceRegression(OmniModel):
230
230
  def predict(self, sequence_or_inputs, **kwargs):
231
231
  """
232
232
  Generate predictions for sequence-level regression.
233
-
233
+
234
234
  Args:
235
235
  sequence_or_inputs: Input sequences or pre-processed inputs
236
236
  **kwargs: Additional keyword arguments
237
-
237
+
238
238
  Returns:
239
239
  dict: Dictionary containing predictions, logits, and last_hidden_state
240
240
  """
@@ -262,11 +262,11 @@ class OmniModelForSequenceRegression(OmniModel):
262
262
  def inference(self, sequence_or_inputs, **kwargs):
263
263
  """
264
264
  Perform inference for sequence-level regression.
265
-
265
+
266
266
  Args:
267
267
  sequence_or_inputs: Input sequences or pre-processed inputs
268
268
  **kwargs: Additional keyword arguments
269
-
269
+
270
270
  Returns:
271
271
  dict: Dictionary containing predictions, logits, and last_hidden_state
272
272
  """
@@ -297,11 +297,11 @@ class OmniModelForSequenceRegression(OmniModel):
297
297
  def loss_function(self, logits, labels):
298
298
  """
299
299
  Compute the loss for sequence-level regression.
300
-
300
+
301
301
  Args:
302
302
  logits (torch.Tensor): Model predictions
303
303
  labels (torch.Tensor): Ground truth labels
304
-
304
+
305
305
  Returns:
306
306
  torch.Tensor: Computed loss value
307
307
  """
@@ -322,20 +322,20 @@ class OmniModelForSequenceRegression(OmniModel):
322
322
  class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
323
323
  """
324
324
  Structural imputation model for genomic sequences.
325
-
325
+
326
326
  This model is specialized for imputing missing structural information in
327
327
  genomic sequences. It extends the sequence regression model with additional
328
328
  embedding capabilities for structural features.
329
-
329
+
330
330
  Attributes:
331
331
  embedding: Embedding layer for structural features
332
332
  loss_fn: Mean squared error loss function
333
333
  """
334
-
334
+
335
335
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
336
336
  """
337
337
  Initialize the structural imputation model.
338
-
338
+
339
339
  Args:
340
340
  config_or_model: Model configuration or pre-trained model
341
341
  tokenizer: Tokenizer for processing input sequences
@@ -351,10 +351,10 @@ class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
351
351
  def forward(self, **inputs):
352
352
  """
353
353
  Forward pass for structural imputation.
354
-
354
+
355
355
  Args:
356
356
  **inputs: Input tensors including input_ids, attention_mask, and labels
357
-
357
+
358
358
  Returns:
359
359
  dict: Dictionary containing logits, last_hidden_state, and labels
360
360
  """
@@ -372,21 +372,19 @@ class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
372
372
  return outputs
373
373
 
374
374
 
375
- class OmniModelForTokenRegressionWith2DStructure(
376
- OmniModelForTokenRegression
377
- ):
375
+ class OmniModelForTokenRegressionWith2DStructure(OmniModelForTokenRegression):
378
376
  """
379
377
  Token-level regression model with 2D structural information.
380
-
378
+
381
379
  This model extends the basic token regression model to incorporate
382
380
  2D structural information, useful for RNA structure prediction
383
381
  and other structural genomics tasks.
384
382
  """
385
-
383
+
386
384
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
387
385
  """
388
386
  Initialize the 2D structure-aware token regression model.
389
-
387
+
390
388
  Args:
391
389
  config_or_model: Model configuration or pre-trained model
392
390
  tokenizer: Tokenizer for processing input sequences
@@ -399,10 +397,10 @@ class OmniModelForTokenRegressionWith2DStructure(
399
397
  def forward(self, **inputs):
400
398
  """
401
399
  Forward pass for 2D structure-aware token regression.
402
-
400
+
403
401
  Args:
404
402
  **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
405
-
403
+
406
404
  Returns:
407
405
  dict: Dictionary containing logits, last_hidden_state, and labels
408
406
  """
@@ -419,21 +417,19 @@ class OmniModelForTokenRegressionWith2DStructure(
419
417
  return outputs
420
418
 
421
419
 
422
- class OmniModelForSequenceRegressionWith2DStructure(
423
- OmniModelForSequenceRegression
424
- ):
420
+ class OmniModelForSequenceRegressionWith2DStructure(OmniModelForSequenceRegression):
425
421
  """
426
422
  Sequence-level regression model with 2D structural information.
427
-
423
+
428
424
  This model extends the basic sequence regression model to incorporate
429
425
  2D structural information, useful for RNA structure prediction
430
426
  and other structural genomics tasks.
431
427
  """
432
-
428
+
433
429
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
434
430
  """
435
431
  Initialize the 2D structure-aware sequence regression model.
436
-
432
+
437
433
  Args:
438
434
  config_or_model: Model configuration or pre-trained model
439
435
  tokenizer: Tokenizer for processing input sequences
@@ -446,10 +442,10 @@ class OmniModelForSequenceRegressionWith2DStructure(
446
442
  def forward(self, **inputs):
447
443
  """
448
444
  Forward pass for 2D structure-aware sequence regression.
449
-
445
+
450
446
  Args:
451
447
  **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
452
-
448
+
453
449
  Returns:
454
450
  dict: Dictionary containing logits, last_hidden_state, and labels
455
451
  """
@@ -470,21 +466,21 @@ class OmniModelForSequenceRegressionWith2DStructure(
470
466
  class OmniModelForMatrixRegression(OmniModel):
471
467
  """
472
468
  Matrix regression model for genomic sequences.
473
-
469
+
474
470
  This model performs regression on matrix representations of genomic sequences,
475
471
  useful for tasks like contact map prediction, structure prediction, or other
476
472
  matrix-based genomic analysis tasks.
477
-
473
+
478
474
  Attributes:
479
475
  resnet: ResNet backbone for processing matrix inputs
480
476
  classifier: Linear layer for regression output
481
477
  loss_fn: Mean squared error loss function
482
478
  """
483
-
479
+
484
480
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
485
481
  """
486
482
  Initialize the matrix regression model.
487
-
483
+
488
484
  Args:
489
485
  config_or_model: Model configuration or pre-trained model
490
486
  tokenizer: Tokenizer for processing input sequences
@@ -501,22 +497,22 @@ class OmniModelForMatrixRegression(OmniModel):
501
497
  def forward(self, **inputs):
502
498
  """
503
499
  Forward pass for matrix regression.
504
-
500
+
505
501
  Args:
506
502
  **inputs: Input tensors including matrix representations and labels
507
-
503
+
508
504
  Returns:
509
505
  dict: Dictionary containing logits, last_hidden_state, and labels
510
506
  """
511
507
  labels = inputs.pop("labels", None)
512
508
  matrix_inputs = inputs.pop("matrix_inputs", None)
513
-
509
+
514
510
  if matrix_inputs is None:
515
511
  raise ValueError("matrix_inputs is required for matrix regression")
516
-
512
+
517
513
  outputs = self.resnet(matrix_inputs)
518
514
  logits = self.classifier(outputs)
519
-
515
+
520
516
  outputs = {
521
517
  "logits": logits,
522
518
  "last_hidden_state": outputs,
@@ -527,11 +523,11 @@ class OmniModelForMatrixRegression(OmniModel):
527
523
  def predict(self, sequence_or_inputs, **kwargs):
528
524
  """
529
525
  Generate predictions for matrix regression.
530
-
526
+
531
527
  Args:
532
528
  sequence_or_inputs: Input sequences or pre-processed inputs
533
529
  **kwargs: Additional keyword arguments
534
-
530
+
535
531
  Returns:
536
532
  dict: Dictionary containing predictions, logits, and last_hidden_state
537
533
  """
@@ -559,11 +555,11 @@ class OmniModelForMatrixRegression(OmniModel):
559
555
  def inference(self, sequence_or_inputs, **kwargs):
560
556
  """
561
557
  Perform inference for matrix regression.
562
-
558
+
563
559
  Args:
564
560
  sequence_or_inputs: Input sequences or pre-processed inputs
565
561
  **kwargs: Additional keyword arguments
566
-
562
+
567
563
  Returns:
568
564
  dict: Dictionary containing predictions, logits, and last_hidden_state
569
565
  """
@@ -594,11 +590,11 @@ class OmniModelForMatrixRegression(OmniModel):
594
590
  def loss_function(self, logits, labels):
595
591
  """
596
592
  Compute the loss for matrix regression.
597
-
593
+
598
594
  Args:
599
595
  logits (torch.Tensor): Model predictions
600
596
  labels (torch.Tensor): Ground truth labels
601
-
597
+
602
598
  Returns:
603
599
  torch.Tensor: Computed loss value
604
600
  """
@@ -619,21 +615,21 @@ class OmniModelForMatrixRegression(OmniModel):
619
615
  class OmniModelForMatrixClassification(OmniModel):
620
616
  """
621
617
  Matrix classification model for genomic sequences.
622
-
618
+
623
619
  This model performs classification on matrix representations of genomic sequences,
624
620
  useful for tasks like structure classification, contact map classification, or other
625
621
  matrix-based genomic analysis tasks.
626
-
622
+
627
623
  Attributes:
628
624
  resnet: ResNet backbone for processing matrix inputs
629
625
  classifier: Linear layer for classification output
630
626
  loss_fn: Cross-entropy loss function
631
627
  """
632
-
628
+
633
629
  def __init__(self, config_or_model, tokenizer, *args, **kwargs):
634
630
  """
635
631
  Initialize the matrix classification model.
636
-
632
+
637
633
  Args:
638
634
  config_or_model: Model configuration or pre-trained model
639
635
  tokenizer: Tokenizer for processing input sequences
@@ -650,26 +646,25 @@ class OmniModelForMatrixClassification(OmniModel):
650
646
  self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)
651
647
  self.model_info()
652
648
 
653
-
654
649
  def forward(self, **inputs):
655
650
  """
656
651
  Forward pass for matrix classification.
657
-
652
+
658
653
  Args:
659
654
  **inputs: Input tensors including matrix representations and labels
660
-
655
+
661
656
  Returns:
662
657
  dict: Dictionary containing logits, last_hidden_state, and labels
663
658
  """
664
659
  labels = inputs.pop("labels", None)
665
660
  matrix_inputs = inputs.pop("matrix_inputs", None)
666
-
661
+
667
662
  if matrix_inputs is None:
668
663
  raise ValueError("matrix_inputs is required for matrix classification")
669
-
664
+
670
665
  outputs = self.resnet(matrix_inputs)
671
666
  logits = self.classifier(outputs)
672
-
667
+
673
668
  outputs = {
674
669
  "logits": logits,
675
670
  "last_hidden_state": outputs,
@@ -680,11 +675,11 @@ class OmniModelForMatrixClassification(OmniModel):
680
675
  def predict(self, sequence_or_inputs, **kwargs):
681
676
  """
682
677
  Generate predictions for matrix classification.
683
-
678
+
684
679
  Args:
685
680
  sequence_or_inputs: Input sequences or pre-processed inputs
686
681
  **kwargs: Additional keyword arguments
687
-
682
+
688
683
  Returns:
689
684
  dict: Dictionary containing predictions, logits, and last_hidden_state
690
685
  """
@@ -713,11 +708,11 @@ class OmniModelForMatrixClassification(OmniModel):
713
708
  def inference(self, sequence_or_inputs, **kwargs):
714
709
  """
715
710
  Perform inference for matrix classification.
716
-
711
+
717
712
  Args:
718
713
  sequence_or_inputs: Input sequences or pre-processed inputs
719
714
  **kwargs: Additional keyword arguments
720
-
715
+
721
716
  Returns:
722
717
  dict: Dictionary containing predictions, logits, and last_hidden_state
723
718
  """
@@ -756,11 +751,11 @@ class OmniModelForMatrixClassification(OmniModel):
756
751
  def loss_function(self, logits, labels):
757
752
  """
758
753
  Compute the loss for matrix classification.
759
-
754
+
760
755
  Args:
761
756
  logits (torch.Tensor): Model predictions
762
757
  labels (torch.Tensor): Ground truth labels
763
-
758
+
764
759
  Returns:
765
760
  torch.Tensor: Computed loss value
766
761
  """