omnigenome-0.3.0a1-py3-none-any.whl → omnigenome-0.3.3a0-py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of omnigenome might be problematic.

Files changed (79)
  1. omnigenome/__init__.py +252 -258
  2. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +10 -10
  3. omnigenome-0.3.3a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -12
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -484
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -13
  11. omnigenome/auto/auto_train/auto_train.py +0 -430
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -12
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -13
  16. omnigenome/cli/commands/__init__.py +0 -13
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -13
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -13
  21. omnigenome/cli/commands/rna/rna_design.py +0 -178
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -12
  24. omnigenome/src/abc/__init__.py +0 -12
  25. omnigenome/src/abc/abstract_dataset.py +0 -622
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -689
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -267
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -435
  31. omnigenome/src/lora/__init__.py +0 -13
  32. omnigenome/src/lora/lora_model.py +0 -294
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -499
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -12
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -12
  44. omnigenome/src/model/classification/model.py +0 -642
  45. omnigenome/src/model/embedding/__init__.py +0 -12
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -12
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -12
  51. omnigenome/src/model/regression/model.py +0 -786
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -12
  54. omnigenome/src/model/rna_design/model.py +0 -469
  55. omnigenome/src/model/seq2seq/__init__.py +0 -12
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -739
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -579
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -13
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -12
  71. omnigenome/utility/model_hub/model_hub.py +0 -231
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -12
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0
omnigenome/src/model/regression/model.py
@@ -1,786 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: model.py
-# time: 18:36 06/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-Regression models for OmniGenome framework.
-
-This module provides various regression model implementations for genomic sequence analysis,
-including token-level regression, sequence-level regression, structural imputation,
-and matrix regression/classification tasks.
-"""
-import torch
-
-from .resnet import resnet_b16
-from ...abc.abstract_model import OmniModel
-from ..module_utils import OmniPooling
-
-
-class OmniModelForTokenRegression(OmniModel):
-    """
-    Token-level regression model for genomic sequences.
-
-    This model performs regression at the token level, predicting continuous values
-    for each token in the input sequence. It's useful for tasks like predicting
-    binding affinities, expression levels, or other continuous properties at each
-    position in a genomic sequence.
-
-    Attributes:
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the token regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.classifier = torch.nn.Linear(
-            self.config.hidden_size, self.config.num_labels
-        )
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for token-level regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for token-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for token-level regression, excluding special tokens.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        inputs = raw_outputs["inputs"]
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                1:-1
-            ]
-            predictions.append(i_logit.detach().cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for token-level regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForSequenceRegression(OmniModel):
-    """
-    Sequence-level regression model for genomic sequences.
-
-    This model performs regression at the sequence level, predicting a single
-    continuous value for the entire input sequence. It's useful for tasks like
-    predicting overall expression levels, binding affinities, or other sequence-level
-    properties.
-
-    Attributes:
-        pooler: OmniPooling layer for sequence-level representation
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the sequence regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.pooler = OmniPooling(self.config)
-        self.classifier = torch.nn.Linear(
-            self.config.hidden_size, self.config.num_labels
-        )
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for sequence-level regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for sequence-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for sequence-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for sequence-level regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
-    """
-    Structural imputation model for genomic sequences.
-
-    This model is specialized for imputing missing structural information in
-    genomic sequences. It extends the sequence regression model with additional
-    embedding capabilities for structural features.
-
-    Attributes:
-        embedding: Embedding layer for structural features
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the structural imputation model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.loss_fn = torch.nn.MSELoss()
-        self.embedding = torch.nn.Embedding(1, self.config.hidden_size)
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for structural imputation.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForTokenRegressionWith2DStructure(
-    OmniModelForTokenRegression
-):
-    """
-    Token-level regression model with 2D structural information.
-
-    This model extends the basic token regression model to incorporate
-    2D structural information, useful for RNA structure prediction
-    and other structural genomics tasks.
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the 2D structure-aware token regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-
-    def forward(self, **inputs):
-        """
-        Forward pass for 2D structure-aware token regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForSequenceRegressionWith2DStructure(
-    OmniModelForSequenceRegression
-):
-    """
-    Sequence-level regression model with 2D structural information.
-
-    This model extends the basic sequence regression model to incorporate
-    2D structural information, useful for RNA structure prediction
-    and other structural genomics tasks.
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the 2D structure-aware sequence regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-
-    def forward(self, **inputs):
-        """
-        Forward pass for 2D structure-aware sequence regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForMatrixRegression(OmniModel):
-    """
-    Matrix regression model for genomic sequences.
-
-    This model performs regression on matrix representations of genomic sequences,
-    useful for tasks like contact map prediction, structure prediction, or other
-    matrix-based genomic analysis tasks.
-
-    Attributes:
-        resnet: ResNet backbone for processing matrix inputs
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the matrix regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.resnet = resnet_b16(channels=128, bbn=16)
-        self.classifier = torch.nn.Linear(1, self.config.num_labels)
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for matrix regression.
-
-        Args:
-            **inputs: Input tensors including matrix representations and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        matrix_inputs = inputs.pop("matrix_inputs", None)
-
-        if matrix_inputs is None:
-            raise ValueError("matrix_inputs is required for matrix regression")
-
-        outputs = self.resnet(matrix_inputs)
-        logits = self.classifier(outputs)
-
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": outputs,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for matrix regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for matrix regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for matrix regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForMatrixClassification(OmniModel):
-    """
-    Matrix classification model for genomic sequences.
-
-    This model performs classification on matrix representations of genomic sequences,
-    useful for tasks like structure classification, contact map classification, or other
-    matrix-based genomic analysis tasks.
-
-    Attributes:
-        resnet: ResNet backbone for processing matrix inputs
-        classifier: Linear layer for classification output
-        loss_fn: Cross-entropy loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the matrix classification model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        # For binary classification, output size is 1
-        self.classifier = torch.nn.Linear(self.config.hidden_size, 1)
-        self.sigmoid = torch.nn.Sigmoid()
-        # Change to BCEWithLogitsLoss for binary classification
-        self.loss_fn = torch.nn.BCEWithLogitsLoss()
-        self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)
-        self.model_info()
-
-
-    def forward(self, **inputs):
-        """
-        Forward pass for matrix classification.
-
-        Args:
-            **inputs: Input tensors including matrix representations and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        matrix_inputs = inputs.pop("matrix_inputs", None)
-
-        if matrix_inputs is None:
-            raise ValueError("matrix_inputs is required for matrix classification")
-
-        outputs = self.resnet(matrix_inputs)
-        logits = self.classifier(outputs)
-
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": outputs,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for matrix classification.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            # Apply sigmoid for binary classification
-            pred_class = (logits[i] > 0.5).float()
-            predictions.append(pred_class.cpu())
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for matrix classification.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-        inputs = raw_outputs["inputs"]
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        probabilities = []
-        for i in range(logits.shape[0]):
-            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                1:-1
-            ]
-            probs = i_logit
-            # For binary classification, threshold at 0.5
-            pred_class = (probs > 0.5).float()
-            predictions.append(pred_class.detach().cpu())
-            probabilities.append(probs.detach().cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for matrix classification.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1, self.config.num_labels)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        # Filter out padding
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        # Reshape for binary classification
-        filtered_logits = filtered_logits.view(-1)
-        filtered_targets = filtered_targets.view(
-            -1
-        ).float()  # Convert to float for BCEWithLogitsLoss
-
-        # Apply BCEWithLogitsLoss
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
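
For context, here is a minimal usage sketch of the API removed above. It is based only on the signatures visible in the deleted model.py; the checkpoint name is a placeholder, and the top-level import path is an assumption about how 0.3.0a1 re-exported these classes, which this diff does not confirm.

import torch
from transformers import AutoConfig, AutoTokenizer

# Assumed import path; in 0.3.0a1 the class was defined in
# omnigenome/src/model/regression/model.py (deleted above).
from omnigenome import OmniModelForSequenceRegression

base = "example/genomic-encoder"  # placeholder checkpoint, not a real model id
config = AutoConfig.from_pretrained(base)
config.num_labels = 1  # one regression target per sequence
tokenizer = AutoTokenizer.from_pretrained(base)

# __init__ accepted either a config or a pre-trained model, plus a tokenizer.
model = OmniModelForSequenceRegression(config, tokenizer)

# inference() on a single (non-list) input returned unbatched tensors;
# passing a list returned per-sequence lists instead.
out = model.inference("ACGUAGGUAUCGUAGA")
print(out["predictions"])  # predicted regression value(s) for the sequence

# During training, loss_function() masked out label positions equal to
# config.ignore_y (default -100) before applying MSELoss.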