omnigenome 0.3.0a1__py3-none-any.whl → 0.3.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic.
- omnigenome/__init__.py +252 -258
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +10 -10
- omnigenome-0.3.3a0.dist-info/RECORD +7 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -12
- omnigenome/auto/auto_bench/auto_bench.py +0 -484
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -13
- omnigenome/auto/auto_train/auto_train.py +0 -430
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -12
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -13
- omnigenome/cli/commands/__init__.py +0 -13
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -13
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -13
- omnigenome/cli/commands/rna/rna_design.py +0 -178
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -12
- omnigenome/src/abc/__init__.py +0 -12
- omnigenome/src/abc/abstract_dataset.py +0 -622
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -689
- omnigenome/src/abc/abstract_tokenizer.py +0 -267
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -435
- omnigenome/src/lora/__init__.py +0 -13
- omnigenome/src/lora/lora_model.py +0 -294
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -499
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -12
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -12
- omnigenome/src/model/classification/model.py +0 -642
- omnigenome/src/model/embedding/__init__.py +0 -12
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -12
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -12
- omnigenome/src/model/regression/model.py +0 -786
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -12
- omnigenome/src/model/rna_design/model.py +0 -469
- omnigenome/src/model/seq2seq/__init__.py +0 -12
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -739
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -579
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -13
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -12
- omnigenome/utility/model_hub/model_hub.py +0 -231
- omnigenome/utility/pipeline_hub/__init__.py +0 -12
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0
omnigenome/src/model/regression/model.py
@@ -1,786 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: model.py
-# time: 18:36 06/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-Regression models for OmniGenome framework.
-
-This module provides various regression model implementations for genomic sequence analysis,
-including token-level regression, sequence-level regression, structural imputation,
-and matrix regression/classification tasks.
-"""
-import torch
-
-from .resnet import resnet_b16
-from ...abc.abstract_model import OmniModel
-from ..module_utils import OmniPooling
-
-
-class OmniModelForTokenRegression(OmniModel):
-    """
-    Token-level regression model for genomic sequences.
-
-    This model performs regression at the token level, predicting continuous values
-    for each token in the input sequence. It's useful for tasks like predicting
-    binding affinities, expression levels, or other continuous properties at each
-    position in a genomic sequence.
-
-    Attributes:
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the token regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.classifier = torch.nn.Linear(
-            self.config.hidden_size, self.config.num_labels
-        )
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for token-level regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for token-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for token-level regression, excluding special tokens.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        inputs = raw_outputs["inputs"]
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                1:-1
-            ]
-            predictions.append(i_logit.detach().cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for token-level regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForSequenceRegression(OmniModel):
-    """
-    Sequence-level regression model for genomic sequences.
-
-    This model performs regression at the sequence level, predicting a single
-    continuous value for the entire input sequence. It's useful for tasks like
-    predicting overall expression levels, binding affinities, or other sequence-level
-    properties.
-
-    Attributes:
-        pooler: OmniPooling layer for sequence-level representation
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the sequence regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.pooler = OmniPooling(self.config)
-        self.classifier = torch.nn.Linear(
-            self.config.hidden_size, self.config.num_labels
-        )
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for sequence-level regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for sequence-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for sequence-level regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for sequence-level regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForStructuralImputation(OmniModelForSequenceRegression):
-    """
-    Structural imputation model for genomic sequences.
-
-    This model is specialized for imputing missing structural information in
-    genomic sequences. It extends the sequence regression model with additional
-    embedding capabilities for structural features.
-
-    Attributes:
-        embedding: Embedding layer for structural features
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the structural imputation model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.loss_fn = torch.nn.MSELoss()
-        self.embedding = torch.nn.Embedding(1, self.config.hidden_size)
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for structural imputation.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForTokenRegressionWith2DStructure(
-    OmniModelForTokenRegression
-):
-    """
-    Token-level regression model with 2D structural information.
-
-    This model extends the basic token regression model to incorporate
-    2D structural information, useful for RNA structure prediction
-    and other structural genomics tasks.
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the 2D structure-aware token regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-
-    def forward(self, **inputs):
-        """
-        Forward pass for 2D structure-aware token regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForSequenceRegressionWith2DStructure(
-    OmniModelForSequenceRegression
-):
-    """
-    Sequence-level regression model with 2D structural information.
-
-    This model extends the basic sequence regression model to incorporate
-    2D structural information, useful for RNA structure prediction
-    and other structural genomics tasks.
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the 2D structure-aware sequence regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-
-    def forward(self, **inputs):
-        """
-        Forward pass for 2D structure-aware sequence regression.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, labels, and structural info
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        last_hidden_state = self.last_hidden_state_forward(**inputs)
-        last_hidden_state = self.dropout(last_hidden_state)
-        last_hidden_state = self.activation(last_hidden_state)
-        last_hidden_state = self.pooler(inputs, last_hidden_state)
-        logits = self.classifier(last_hidden_state)
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-            "labels": labels,
-        }
-        return outputs
-
-
-class OmniModelForMatrixRegression(OmniModel):
-    """
-    Matrix regression model for genomic sequences.
-
-    This model performs regression on matrix representations of genomic sequences,
-    useful for tasks like contact map prediction, structure prediction, or other
-    matrix-based genomic analysis tasks.
-
-    Attributes:
-        resnet: ResNet backbone for processing matrix inputs
-        classifier: Linear layer for regression output
-        loss_fn: Mean squared error loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the matrix regression model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        self.resnet = resnet_b16(channels=128, bbn=16)
-        self.classifier = torch.nn.Linear(1, self.config.num_labels)
-        self.loss_fn = torch.nn.MSELoss()
-        self.model_info()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for matrix regression.
-
-        Args:
-            **inputs: Input tensors including matrix representations and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        matrix_inputs = inputs.pop("matrix_inputs", None)
-
-        if matrix_inputs is None:
-            raise ValueError("matrix_inputs is required for matrix regression")
-
-        outputs = self.resnet(matrix_inputs)
-        logits = self.classifier(outputs)
-
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": outputs,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for matrix regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for matrix regression.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for matrix regression.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss
-
-
-class OmniModelForMatrixClassification(OmniModel):
-    """
-    Matrix classification model for genomic sequences.
-
-    This model performs classification on matrix representations of genomic sequences,
-    useful for tasks like structure classification, contact map classification, or other
-    matrix-based genomic analysis tasks.
-
-    Attributes:
-        resnet: ResNet backbone for processing matrix inputs
-        classifier: Linear layer for classification output
-        loss_fn: Cross-entropy loss function
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the matrix classification model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        # For binary classification, output size is 1
-        self.classifier = torch.nn.Linear(self.config.hidden_size, 1)
-        self.sigmoid = torch.nn.Sigmoid()
-        # Change to BCEWithLogitsLoss for binary classification
-        self.loss_fn = torch.nn.BCEWithLogitsLoss()
-        self.cnn = resnet_b16(channels=self.config.hidden_size, bbn=16)
-        self.model_info()
-
-
-    def forward(self, **inputs):
-        """
-        Forward pass for matrix classification.
-
-        Args:
-            **inputs: Input tensors including matrix representations and labels
-
-        Returns:
-            dict: Dictionary containing logits, last_hidden_state, and labels
-        """
-        labels = inputs.pop("labels", None)
-        matrix_inputs = inputs.pop("matrix_inputs", None)
-
-        if matrix_inputs is None:
-            raise ValueError("matrix_inputs is required for matrix classification")
-
-        outputs = self.resnet(matrix_inputs)
-        logits = self.classifier(outputs)
-
-        outputs = {
-            "logits": logits,
-            "last_hidden_state": outputs,
-            "labels": labels,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for matrix classification.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            # Apply sigmoid for binary classification
-            pred_class = (logits[i] > 0.5).float()
-            predictions.append(pred_class.cpu())
-        outputs = {
-            "predictions": (
-                torch.vstack(predictions).to(self.model.device)
-                if predictions[0].shape
-                else torch.tensor(predictions).to(self.model.device)
-            ),
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for matrix classification.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-        inputs = raw_outputs["inputs"]
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        probabilities = []
-        for i in range(logits.shape[0]):
-            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                1:-1
-            ]
-            probs = i_logit
-            # For binary classification, threshold at 0.5
-            pred_class = (probs > 0.5).float()
-            predictions.append(pred_class.detach().cpu())
-            probabilities.append(probs.detach().cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for matrix classification.
-
-        Args:
-            logits (torch.Tensor): Model predictions
-            labels (torch.Tensor): Ground truth labels
-
-        Returns:
-            torch.Tensor: Computed loss value
-        """
-        padding_value = (
-            self.config.ignore_y if hasattr(self.config, "ignore_y") else -100
-        )
-        logits = logits.view(-1, self.config.num_labels)
-        labels = labels.view(-1)
-        mask = torch.where(labels != padding_value)
-
-        # Filter out padding
-        filtered_logits = logits[mask]
-        filtered_targets = labels[mask]
-
-        # Reshape for binary classification
-        filtered_logits = filtered_logits.view(-1)
-        filtered_targets = filtered_targets.view(
-            -1
-        ).float()  # Convert to float for BCEWithLogitsLoss
-
-        # Apply BCEWithLogitsLoss
-        loss = self.loss_fn(filtered_logits, filtered_targets)
-        return loss