omnigenome 0.3.1a0__py3-none-any.whl → 0.4.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. omnigenome/__init__.py +304 -266
  2. omnigenome-0.4.0a0.dist-info/METADATA +354 -0
  3. omnigenome-0.4.0a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/METADATA +0 -224
  76. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/WHEEL +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/entry_points.txt +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/licenses/LICENSE +0 -0
  80. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/top_level.txt +0 -0
@@ -1,781 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # file: model.py
3
- # time: 18:36 06/04/2024
4
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
- # github: https://github.com/yangheng95
6
- # huggingface: https://huggingface.co/yangheng
7
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
- # Copyright (C) 2019-2024. All Rights Reserved.
9
- """
10
- Regression models for OmniGenome framework.
11
-
12
- This module provides various regression model implementations for genomic sequence analysis,
13
- including token-level regression, sequence-level regression, structural imputation,
14
- and matrix regression/classification tasks.
15
- """
16
- import torch
17
-
18
- from .resnet import resnet_b16
19
- from ...abc.abstract_model import OmniModel
20
- from ..module_utils import OmniPooling
21
-
22
-
23
class OmniModelForTokenRegression(OmniModel):
    """
    Token-level regression model for genomic sequences.

    Applies dropout, the model activation and a linear head to the backbone's
    last hidden state, producing ``config.num_labels`` continuous values per
    position of the hidden state.

    Attributes:
        classifier: Linear layer mapping ``hidden_size`` to ``num_labels``.
        loss_fn: Mean squared error loss function.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the token regression model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer for processing input sequences.
            *args: Additional positional arguments forwarded to ``OmniModel``.
            **kwargs: Additional keyword arguments forwarded to ``OmniModel``.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.classifier = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.MSELoss()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for token-level regression.

        Args:
            **inputs: Backbone inputs (e.g. ``input_ids``, ``attention_mask``)
                plus an optional ``labels`` entry, which is removed before the
                backbone call and returned unchanged.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (after dropout and
            activation) and ``labels``.
        """
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        logits = self.classifier(last_hidden_state)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate predictions for token-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (moved to the model device), ``logits`` and
            ``last_hidden_state``.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [logits[i].cpu() for i in range(logits.shape[0])]

        # NOTE(review): torch.vstack concatenates along dim 0. If per-sample
        # logits are 2-D (e.g. (seq_len, num_labels)), the result is
        # (batch * seq_len, num_labels) rather than a stacked batch — confirm
        # this is the intended output shape.
        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for token-level regression, excluding special tokens.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs. A
                ``list`` is treated as a batch; anything else as one sample.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (per-sample CPU tensors), ``logits`` and
            ``last_hidden_state``; for a non-list input, only the first sample
            of each is returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        inputs = raw_outputs["inputs"]
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            # Keep non-padding positions, then drop the first and last of
            # those — presumably the special start/end tokens; verify against
            # the tokenizer.
            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
                1:-1
            ]
            predictions.append(i_logit.detach().cpu())

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss for token-level regression, ignoring padded labels.

        Positions whose label equals ``config.ignore_y`` (default ``-100``)
        are excluded from the loss.

        Args:
            logits (torch.Tensor): Model predictions.
            labels (torch.Tensor): Ground-truth labels with the same number of
                elements as ``logits``.

        Returns:
            torch.Tensor: Computed loss value.
        """
        padding_value = getattr(self.config, "ignore_y", -100)
        # reshape (not view) so non-contiguous tensors are accepted too.
        logits = logits.reshape(-1)
        labels = labels.reshape(-1)
        mask = labels != padding_value

        filtered_logits = logits[mask]
        # Labels are often integer tensors (e.g. padded with -100); MSELoss
        # requires targets with the same floating dtype as the predictions.
        filtered_targets = labels[mask].to(filtered_logits.dtype)

        loss = self.loss_fn(filtered_logits, filtered_targets)
        return loss
171
-
172
-
173
class OmniModelForSequenceRegression(OmniModel):
    """
    Sequence-level regression model for genomic sequences.

    Pools the backbone's last hidden state with ``OmniPooling`` and applies a
    linear head, producing ``config.num_labels`` continuous values per input
    sequence.

    Attributes:
        pooler: OmniPooling layer for the sequence-level representation.
        classifier: Linear layer mapping ``hidden_size`` to ``num_labels``.
        loss_fn: Mean squared error loss function.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the sequence regression model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer for processing input sequences.
            *args: Additional positional arguments forwarded to ``OmniModel``.
            **kwargs: Additional keyword arguments forwarded to ``OmniModel``.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.classifier = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.MSELoss()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for sequence-level regression.

        Args:
            **inputs: Backbone inputs (e.g. ``input_ids``, ``attention_mask``)
                plus an optional ``labels`` entry, which is removed before the
                backbone call and returned unchanged.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (the pooled
            representation) and ``labels``.
        """
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.pooler(inputs, last_hidden_state)
        logits = self.classifier(last_hidden_state)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate predictions for sequence-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (moved to the model device), ``logits`` and
            ``last_hidden_state``.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [logits[i].cpu() for i in range(logits.shape[0])]

        outputs = {
            # 0-d per-sample predictions cannot be vstacked; build a 1-D tensor.
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for sequence-level regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs. A
                ``list`` is treated as a batch; anything else as one sample.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (per-sample CPU tensors), ``logits`` and
            ``last_hidden_state``; for a non-list input, only the first sample
            of each is returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [logits[i].cpu() for i in range(logits.shape[0])]

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss for sequence-level regression, ignoring padding.

        Entries whose label equals ``config.ignore_y`` (default ``-100``) are
        excluded from the loss.

        Args:
            logits (torch.Tensor): Model predictions.
            labels (torch.Tensor): Ground-truth labels with the same number of
                elements as ``logits``.

        Returns:
            torch.Tensor: Computed loss value.
        """
        padding_value = getattr(self.config, "ignore_y", -100)
        # reshape (not view) so non-contiguous tensors are accepted too.
        logits = logits.reshape(-1)
        labels = labels.reshape(-1)
        mask = labels != padding_value

        filtered_logits = logits[mask]
        # MSELoss requires targets with the same floating dtype as the
        # predictions; labels may arrive as integers or float64.
        filtered_targets = labels[mask].to(filtered_logits.dtype)

        loss = self.loss_fn(filtered_logits, filtered_targets)
        return loss
320
-
321
-
322
class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
    """
    Structural imputation model for genomic sequences.

    Builds on the sequence regression model (pooling plus a linear head) and
    additionally registers a single-entry embedding table sized to the hidden
    dimension.

    Attributes:
        embedding: ``torch.nn.Embedding(1, hidden_size)``. Note that
            ``forward`` below does not use it.
        loss_fn: Mean squared error loss function.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the structural imputation model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer used to encode input sequences.
            *args: Extra positional arguments passed to the parent class.
            **kwargs: Extra keyword arguments passed to the parent class.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__
        self.loss_fn = torch.nn.MSELoss()
        self.embedding = torch.nn.Embedding(1, self.config.hidden_size)
        # Report model info again now that the embedding has been registered.
        self.model_info()

    def forward(self, **inputs):
        """
        Run the backbone, pool the hidden states and apply the regression head.

        Args:
            **inputs: Backbone inputs; an optional ``labels`` entry is removed
                first and echoed back in the result.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (pooled) and ``labels``.
        """
        labels = inputs.pop("labels", None)
        hidden = self.activation(
            self.dropout(self.last_hidden_state_forward(**inputs))
        )
        pooled = self.pooler(inputs, hidden)
        return {
            "logits": self.classifier(pooled),
            "last_hidden_state": pooled,
            "labels": labels,
        }
373
-
374
-
375
class OmniModelForTokenRegressionWith2DStructure(OmniModelForTokenRegression):
    """
    Token-level regression model intended for inputs carrying 2D structure.

    The forward computation is the same per-token pipeline as the parent:
    backbone, dropout, activation, linear head. Any structural information
    reaches the model only through whatever ``last_hidden_state_forward``
    does with the supplied inputs.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the structure-aware token regression model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer used to encode input sequences.
            *args: Extra positional arguments passed to the parent class.
            **kwargs: Extra keyword arguments passed to the parent class.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__

    def forward(self, **inputs):
        """
        Run the backbone and apply the per-token regression head.

        Args:
            **inputs: Backbone inputs; an optional ``labels`` entry is removed
                first and echoed back in the result.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (after dropout and
            activation) and ``labels``.
        """
        labels = inputs.pop("labels", None)
        hidden = self.activation(
            self.dropout(self.last_hidden_state_forward(**inputs))
        )
        return {
            "logits": self.classifier(hidden),
            "last_hidden_state": hidden,
            "labels": labels,
        }
418
-
419
-
420
class OmniModelForSequenceRegressionWith2DStructure(OmniModelForSequenceRegression):
    """
    Sequence-level regression model intended for inputs carrying 2D structure.

    The forward computation is the same pooled pipeline as the parent:
    backbone, dropout, activation, pooling, linear head. Any structural
    information reaches the model only through whatever
    ``last_hidden_state_forward`` does with the supplied inputs.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the structure-aware sequence regression model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer used to encode input sequences.
            *args: Extra positional arguments passed to the parent class.
            **kwargs: Extra keyword arguments passed to the parent class.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = type(self).__name__

    def forward(self, **inputs):
        """
        Run the backbone, pool the hidden states and apply the regression head.

        Args:
            **inputs: Backbone inputs; an optional ``labels`` entry is removed
                first and echoed back in the result.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (pooled) and ``labels``.
        """
        labels = inputs.pop("labels", None)
        hidden = self.activation(
            self.dropout(self.last_hidden_state_forward(**inputs))
        )
        pooled = self.pooler(inputs, hidden)
        return {
            "logits": self.classifier(pooled),
            "last_hidden_state": pooled,
            "labels": labels,
        }
464
-
465
-
466
class OmniModelForMatrixRegression(OmniModel):
    """
    Matrix regression model for genomic sequences.

    Runs a ResNet backbone over a ``matrix_inputs`` tensor and maps its output
    to ``config.num_labels`` continuous values with a linear head.

    Attributes:
        resnet: ``resnet_b16(channels=128, bbn=16)`` backbone for matrix inputs.
        classifier: ``Linear(1, num_labels)`` applied to the backbone output
            (assumes the backbone output's last dimension is 1 — confirm
            against ``resnet_b16``).
        loss_fn: Mean squared error loss function.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the matrix regression model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer for processing input sequences.
            *args: Additional positional arguments forwarded to ``OmniModel``.
            **kwargs: Additional keyword arguments forwarded to ``OmniModel``.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.resnet = resnet_b16(channels=128, bbn=16)
        self.classifier = torch.nn.Linear(1, self.config.num_labels)
        self.loss_fn = torch.nn.MSELoss()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for matrix regression.

        Args:
            **inputs: Must contain ``matrix_inputs``; an optional ``labels``
                entry is returned unchanged.

        Returns:
            dict: ``logits``, ``last_hidden_state`` (the raw backbone output)
            and ``labels``.

        Raises:
            ValueError: If ``matrix_inputs`` is missing or ``None``.
        """
        labels = inputs.pop("labels", None)
        matrix_inputs = inputs.pop("matrix_inputs", None)

        if matrix_inputs is None:
            raise ValueError("matrix_inputs is required for matrix regression")

        backbone_output = self.resnet(matrix_inputs)
        logits = self.classifier(backbone_output)

        outputs = {
            "logits": logits,
            "last_hidden_state": backbone_output,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate predictions for matrix regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (moved to the model device), ``logits`` and
            ``last_hidden_state``.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [logits[i].cpu() for i in range(logits.shape[0])]

        outputs = {
            # 0-d per-sample predictions cannot be vstacked; build a 1-D tensor.
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for matrix regression.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs. A
                ``list`` is treated as a batch; anything else as one sample.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (per-sample CPU tensors), ``logits`` and
            ``last_hidden_state``; for a non-list input, only the first sample
            of each is returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = [logits[i].cpu() for i in range(logits.shape[0])]

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Compute the MSE loss for matrix regression, ignoring padded labels.

        Entries whose label equals ``config.ignore_y`` (default ``-100``) are
        excluded from the loss.

        Args:
            logits (torch.Tensor): Model predictions.
            labels (torch.Tensor): Ground-truth labels with the same number of
                elements as ``logits``.

        Returns:
            torch.Tensor: Computed loss value.
        """
        padding_value = getattr(self.config, "ignore_y", -100)
        # reshape (not view) so non-contiguous tensors are accepted too.
        logits = logits.reshape(-1)
        labels = labels.reshape(-1)
        mask = labels != padding_value

        filtered_logits = logits[mask]
        # MSELoss requires targets with the same floating dtype as the
        # predictions; labels may arrive as integers or float64.
        filtered_targets = labels[mask].to(filtered_logits.dtype)

        loss = self.loss_fn(filtered_logits, filtered_targets)
        return loss
613
-
614
-
615
class OmniModelForMatrixClassification(OmniModel):
    """
    Binary matrix classification model for genomic sequences.

    Runs a ResNet backbone over a ``matrix_inputs`` tensor and produces one
    logit per output position with a linear head. The model is trained with
    ``BCEWithLogitsLoss``, so ``logits`` are raw scores; a sigmoid is applied
    before thresholding at 0.5 to obtain class predictions.

    Attributes:
        cnn: ``resnet_b16(channels=hidden_size, bbn=16)`` backbone.
        classifier: ``Linear(hidden_size, 1)`` producing one logit.
        sigmoid: Sigmoid used to turn logits into probabilities.
        loss_fn: Binary cross-entropy with logits loss function.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initialize the matrix classification model.

        Args:
            config_or_model: Model configuration or pre-trained model.
            tokenizer: Tokenizer for processing input sequences.
            *args: Additional positional arguments forwarded to ``OmniModel``.
            **kwargs: Additional keyword arguments forwarded to ``OmniModel``.
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        # Binary classification: a single logit per output position.
        self.classifier = torch.nn.Linear(self.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for matrix classification.

        Args:
            **inputs: Must contain ``matrix_inputs``; an optional ``labels``
                entry is returned unchanged.

        Returns:
            dict: ``logits`` (raw, pre-sigmoid), ``last_hidden_state`` (the
            raw backbone output) and ``labels``.

        Raises:
            ValueError: If ``matrix_inputs`` is missing or ``None``.
        """
        labels = inputs.pop("labels", None)
        matrix_inputs = inputs.pop("matrix_inputs", None)

        if matrix_inputs is None:
            raise ValueError("matrix_inputs is required for matrix classification")

        # The backbone is registered as ``self.cnn`` in __init__; this class
        # has no ``resnet`` attribute.
        backbone_output = self.cnn(matrix_inputs)
        logits = self.classifier(backbone_output)

        outputs = {
            "logits": logits,
            "last_hidden_state": backbone_output,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Generate binary predictions for matrix classification.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (0/1 floats on the model device), ``logits``
            and ``last_hidden_state``.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            # Logits are raw scores (the loss is BCEWithLogits), so convert
            # to probabilities before thresholding at 0.5.
            pred_class = (self.sigmoid(logits[i]) > 0.5).float()
            predictions.append(pred_class.cpu())
        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Perform inference for matrix classification.

        Args:
            sequence_or_inputs: Input sequences or pre-processed inputs. A
                ``list`` is treated as a batch; anything else as one sample.
            **kwargs: Additional keyword arguments forwarded to
                ``_forward_from_raw_input``.

        Returns:
            dict: ``predictions`` (per-sample 0/1 CPU tensors), ``logits`` and
            ``last_hidden_state``; for a non-list input, only the first sample
            of each is returned.
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        inputs = raw_outputs["inputs"]
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            # Keep non-padding positions, then drop the first and last of
            # those (presumably special tokens). NOTE(review): assumes logits
            # are indexed per input token — confirm for matrix inputs.
            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
                1:-1
            ]
            # Logits are raw scores; threshold the sigmoid probability at 0.5.
            probs = self.sigmoid(i_logit)
            pred_class = (probs > 0.5).float()
            predictions.append(pred_class.detach().cpu())

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Compute the binary cross-entropy loss, ignoring padded labels.

        Entries whose label equals ``config.ignore_y`` (default ``-100``) are
        excluded from the loss.

        Args:
            logits (torch.Tensor): Raw model scores, one per labelled position.
            labels (torch.Tensor): 0/1 ground-truth labels with the same
                number of elements as ``logits``.

        Returns:
            torch.Tensor: Computed loss value.
        """
        padding_value = getattr(self.config, "ignore_y", -100)
        # The classifier emits exactly one logit per position, so flatten to
        # 1-D to line up element-wise with the flattened labels. Reshaping to
        # (-1, num_labels) would misalign rows whenever num_labels != 1.
        logits = logits.reshape(-1)
        labels = labels.reshape(-1)
        mask = labels != padding_value

        # Filter out padding.
        filtered_logits = logits[mask]
        # BCEWithLogitsLoss needs floating targets matching the logits dtype.
        filtered_targets = labels[mask].to(filtered_logits.dtype)

        loss = self.loss_fn(filtered_logits, filtered_targets)
        return loss