omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
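
The hunks below show two of the files from this listing. For orientation, here is a minimal, hypothetical import sketch based on the module paths above; it assumes the wheel is installed, and whether omnigenome/__init__.py also re-exports these names at the top level is not visible in this diff.

# Hypothetical import sketch; module paths follow the file listing above.
# Top-level re-exports from omnigenome/__init__.py are not shown in this diff.
from omnigenome.src.model.classification.model import (
    OmniModelForSequenceClassification,
    OmniModelForTokenClassification,
    OmniModelForMultiLabelSequenceClassification,
)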
omnigenome/src/model/classification/model.py
@@ -0,0 +1,642 @@
# -*- coding: utf-8 -*-
# file: model.py
# time: 18:36 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import torch

from ...abc.abstract_model import OmniModel
from ..module_utils import OmniPooling


class OmniModelForTokenClassification(OmniModel):
    """
    Model for token classification tasks in genomics.

    This model is designed for token-level classification tasks such as
    sequence labeling, where each token in the input sequence needs to be
    classified into different categories. It extends the base OmniModel
    with token-level classification capabilities.

    The model adds a classification head on top of the base model's hidden
    states and applies softmax to produce probability distributions over
    the label classes for each token.

    Attributes:
        softmax (torch.nn.Softmax): Softmax layer for probability computation.
        classifier (torch.nn.Linear): Linear classification head.
        loss_fn (torch.nn.CrossEntropyLoss): Loss function for training.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initializes the token classification model.

        Args:
            config_or_model: Model configuration, pre-trained model path, or model instance.
            tokenizer: The tokenizer associated with the model.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Example:
            >>> model = OmniModelForTokenClassification("model_path", tokenizer)
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.softmax = torch.nn.Softmax(dim=-1)
        self.classifier = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for token classification.

        This method performs the forward pass through the model, computing
        logits for each token in the input sequence and applying softmax
        to produce probability distributions.

        Args:
            **inputs: Input tensors including 'input_ids', 'attention_mask',
                and optionally 'labels'.

        Returns:
            dict: A dictionary containing:
                - logits: Token-level classification logits
                - last_hidden_state: Final hidden states from the base model
                - labels: Ground truth labels (if provided)

        Example:
            >>> outputs = model(
            ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
            ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
            ...     labels=torch.tensor([[0, 1, 0, 1]])
            ... )
        """
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        logits = self.classifier(last_hidden_state)
        logits = self.softmax(logits)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Performs token-level prediction on raw inputs.

        This method takes raw sequences or tokenized inputs and returns
        token-level predictions. It processes the inputs through the model
        and returns the predicted class for each token.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Predicted class indices for each token
                - logits: Raw logits from the model
                - last_hidden_state: Final hidden states

        Example:
            >>> # Predict on a single sequence
            >>> outputs = model.predict("ATCGATCG")
            >>> print(outputs['predictions'].shape) # (seq_len,)

            >>> # Predict on multiple sequences
            >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            predictions.append(logits[i].argmax(dim=-1).detach().cpu())

        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Performs token-level inference with human-readable output.

        This method provides processed, human-readable token-level predictions.
        It converts logits to class labels and handles special tokens appropriately.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Human-readable class labels for each token
                - logits: Raw logits from the model
                - confidence: Confidence scores for predictions
                - last_hidden_state: Final hidden states

        Example:
            >>> # Inference on a single sequence
            >>> results = model.inference("ATCGATCG")
            >>> print(results['predictions']) # ['A', 'T', 'C', 'G', ...]
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        inputs = raw_outputs["inputs"]
        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            # Note that the first and last tokens are removed,
            # and the length of outputs are calculated based on the tokenized inputs.
            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
                1:-1
            ]
            prediction = [
                self.config.id2label.get(x.item(), "") for x in i_logit.argmax(dim=-1)
            ]
            predictions.append(prediction)

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "confidence": torch.max(logits[0]),
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "confidence": torch.max(logits, dim=-1)[0],
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Calculates the cross-entropy loss for token classification.

        This method computes the cross-entropy loss between the predicted
        logits and the ground truth labels, ignoring padding tokens.

        Args:
            logits (torch.Tensor): Predicted logits from the model.
            labels (torch.Tensor): Ground truth labels.

        Returns:
            torch.Tensor: The computed loss value.

        Example:
            >>> loss = model.loss_function(logits, labels)
        """
        loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
        return loss


class OmniModelForSequenceClassification(OmniModel):
    """
    Model for sequence classification tasks in genomics.

    This model is designed for sequence-level classification tasks where
    the entire input sequence is classified into one of several categories.
    It extends the base OmniModel with sequence-level classification capabilities.

    The model uses a pooling mechanism to aggregate token-level representations
    into a sequence-level representation, which is then classified using a
    linear classifier.

    Attributes:
        pooler (OmniPooling): Pooling layer for sequence-level representation.
        softmax (torch.nn.Softmax): Softmax layer for probability computation.
        classifier (torch.nn.Linear): Linear classification head.
        loss_fn (torch.nn.CrossEntropyLoss): Loss function for training.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initializes the sequence classification model.

        Args:
            config_or_model: Model configuration, pre-trained model path, or model instance.
            tokenizer: The tokenizer associated with the model.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Example:
            >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.classifier = torch.nn.Linear(
            self.config.hidden_size, self.config.num_labels
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.model_info()

    def forward(self, **inputs):
        """
        Forward pass for sequence classification.

        This method performs the forward pass through the model, computing
        sequence-level logits and applying softmax to produce probability
        distributions over the label classes.

        Args:
            **inputs: Input tensors including 'input_ids', 'attention_mask',
                and optionally 'labels'.

        Returns:
            dict: A dictionary containing:
                - logits: Sequence-level classification logits
                - last_hidden_state: Final hidden states from the base model
                - labels: Ground truth labels (if provided)

        Example:
            >>> outputs = model(
            ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
            ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
            ...     labels=torch.tensor([0])
            ... )
        """
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.pooler(inputs, last_hidden_state)
        logits = self.classifier(last_hidden_state)
        logits = self.softmax(logits)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Performs sequence-level prediction on raw inputs.

        This method takes raw sequences or tokenized inputs and returns
        sequence-level predictions. It processes the inputs through the model
        and returns the predicted class for each sequence.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Predicted class indices for each sequence
                - logits: Raw logits from the model
                - last_hidden_state: Final hidden states

        Example:
            >>> # Predict on a single sequence
            >>> outputs = model.predict("ATCGATCG")
            >>> print(outputs['predictions']) # tensor([0])

            >>> # Predict on multiple sequences
            >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            predictions.append(logits[i].argmax(dim=-1))

        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Performs sequence-level inference with human-readable output.

        This method provides processed, human-readable sequence-level predictions.
        It converts logits to class labels and provides confidence scores.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Human-readable class labels for each sequence
                - logits: Raw logits from the model
                - confidence: Confidence scores for predictions
                - last_hidden_state: Final hidden states

        Example:
            >>> # Inference on a single sequence
            >>> results = model.inference("ATCGATCG")
            >>> print(results['predictions']) # "positive"
            >>> print(results['confidence']) # 0.95
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            predictions.append(
                self.config.id2label.get(logits[i].argmax(dim=-1).item(), "")
            )

        if not isinstance(sequence_or_inputs, list):
            outputs = {
                "predictions": predictions[0],
                "logits": logits[0],
                "confidence": torch.max(logits[0]),
                "last_hidden_state": last_hidden_state[0],
            }
        else:
            outputs = {
                "predictions": predictions,
                "logits": logits,
                "confidence": torch.max(logits, dim=-1)[0],
                "last_hidden_state": last_hidden_state,
            }

        return outputs

    def loss_function(self, logits, labels):
        """
        Calculates the cross-entropy loss for sequence classification.

        This method computes the cross-entropy loss between the predicted
        logits and the ground truth labels.

        Args:
            logits (torch.Tensor): Predicted logits from the model.
            labels (torch.Tensor): Ground truth labels.

        Returns:
            torch.Tensor: The computed loss value.

        Example:
            >>> loss = model.loss_function(logits, labels)
        """
        loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
        return loss


class OmniModelForMultiLabelSequenceClassification(
    OmniModelForSequenceClassification
):
    """
    Model for multi-label sequence classification tasks in genomics.

    This model is designed for multi-label classification tasks where
    a single sequence can be assigned multiple labels simultaneously.
    It extends the sequence classification model with multi-label capabilities.

    The model uses sigmoid activation instead of softmax to allow multiple
    labels per sequence and uses binary cross-entropy loss for training.

    Attributes:
        softmax (torch.nn.Sigmoid): Sigmoid layer for multi-label probability computation.
        loss_fn (torch.nn.BCELoss): Binary cross-entropy loss function for training.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initializes the multi-label sequence classification model.

        Args:
            config_or_model: Model configuration, pre-trained model path, or model instance.
            tokenizer: The tokenizer associated with the model.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Example:
            >>> model = OmniModelForMultiLabelSequenceClassification("model_path", tokenizer)
        """
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.softmax = torch.nn.Sigmoid()
        self.loss_fn = torch.nn.BCELoss()
        self.model_info()

    def loss_function(self, logits, labels):
        """
        Calculates the binary cross-entropy loss for multi-label classification.

        This method computes the binary cross-entropy loss between the predicted
        probabilities and the ground truth multi-label targets.

        Args:
            logits (torch.Tensor): Predicted logits from the model.
            labels (torch.Tensor): Ground truth multi-label targets.

        Returns:
            torch.Tensor: The computed loss value.

        Example:
            >>> loss = model.loss_function(logits, labels)
        """
        loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
        return loss

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Performs multi-label prediction on raw inputs.

        This method takes raw sequences or tokenized inputs and returns
        multi-label predictions. It applies a threshold to determine
        which labels are active for each sequence.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Multi-label predictions for each sequence
                - logits: Raw logits from the model
                - last_hidden_state: Final hidden states

        Example:
            >>> # Predict on a single sequence
            >>> outputs = model.predict("ATCGATCG")
            >>> print(outputs['predictions']) # tensor([1, 0, 1, 0])
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            prediction = logits[i].ge(0.5).to(torch.int).cpu()
            predictions.append(prediction)

        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Performs multi-label inference with human-readable output.

        This method provides processed, human-readable multi-label predictions.
        It converts logits to binary labels and provides confidence scores.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/tuple).
            **kwargs: Additional arguments for tokenization and inference.

        Returns:
            dict: A dictionary containing:
                - predictions: Human-readable binary labels for each sequence
                - logits: Raw logits from the model
                - confidence: Confidence scores for predictions
                - last_hidden_state: Final hidden states

        Example:
            >>> # Inference on a single sequence
            >>> results = model.inference("ATCGATCG")
            >>> print(results['predictions']) # tensor([1, 0, 1, 0])
        """
        return self.predict(sequence_or_inputs, **kwargs)


class OmniModelForTokenClassificationWith2DStructure(
    OmniModelForTokenClassification
):
    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.model_info()

    def forward(self, **inputs):
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        logits = self.classifier(last_hidden_state)
        logits = self.softmax(logits)
        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs


class OmniModelForSequenceClassificationWith2DStructure(
    OmniModelForSequenceClassification
):
    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.pooler = OmniPooling(self.config)
        self.model_info()

    def forward(self, **inputs):
        labels = inputs.pop("labels", None)
        last_hidden_state = self.last_hidden_state_forward(**inputs)
        last_hidden_state = self.dropout(last_hidden_state)
        last_hidden_state = self.activation(last_hidden_state)
        last_hidden_state = self.pooler(inputs, last_hidden_state)
        logits = self.classifier(last_hidden_state)
        logits = self.softmax(logits)

        outputs = {
            "logits": logits,
            "last_hidden_state": last_hidden_state,
            "labels": labels,
        }
        return outputs


class OmniModelForMultiLabelSequenceClassificationWith2DStructure(
    OmniModelForSequenceClassificationWith2DStructure
):
    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        super().__init__(config_or_model, tokenizer, *args, **kwargs)
        self.metadata["model_name"] = self.__class__.__name__
        self.softmax = torch.nn.Sigmoid()
        self.loss_fn = torch.nn.BCELoss()
        self.model_info()

    def loss_function(self, logits, labels):
        loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
        return loss

    def predict(self, sequence_or_inputs, **kwargs):
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)

        logits = raw_outputs["logits"]
        last_hidden_state = raw_outputs["last_hidden_state"]

        predictions = []
        for i in range(logits.shape[0]):
            prediction = logits[i].ge(0.5).to(torch.int).cpu()
            predictions.append(prediction)

        outputs = {
            "predictions": (
                torch.vstack(predictions).to(self.model.device)
                if predictions[0].shape
                else torch.tensor(predictions).to(self.model.device)
            ),
            "logits": logits,
            "last_hidden_state": last_hidden_state,
        }

        return outputs

    def inference(self, sequence_or_inputs, **kwargs):
        return self.predict(sequence_or_inputs, **kwargs)
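
The classes above share the predict/inference surface documented in their docstrings. Below is a minimal usage sketch mirroring those docstring examples; the checkpoint path is a placeholder and the tokenizer construction is not shown in this diff.

# Hedged usage sketch following the docstring examples above; the checkpoint
# path is illustrative and tokenizer setup is outside this file.
from omnigenome.src.model.classification.model import OmniModelForSequenceClassification

tokenizer = ...  # an OmniGenome-compatible tokenizer for the chosen checkpoint
model = OmniModelForSequenceClassification("path/to/checkpoint", tokenizer)

batch = model.predict(["ATCGATCG", "GCTAGCTA"])  # class indices in batch["predictions"]
single = model.inference("ATCGATCG")  # label in single["predictions"], score in single["confidence"]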
omnigenome/src/model/embedding/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# file: __init__.py
# time: 01:51 06/05/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
"""
This package contains modules for embedding models.
"""