omnigenome-0.3.0a0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release.

This version of omnigenome might be problematic.
Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
omnigenome/src/model/classification/model.py
@@ -0,0 +1,642 @@
+ # -*- coding: utf-8 -*-
+ # file: model.py
+ # time: 18:36 06/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+ import torch
+
+ from ...abc.abstract_model import OmniModel
+ from ..module_utils import OmniPooling
+
+
+ class OmniModelForTokenClassification(OmniModel):
+     """
+     Model for token classification tasks in genomics.
+
+     This model is designed for token-level classification tasks such as
+     sequence labeling, where each token in the input sequence needs to be
+     classified into different categories. It extends the base OmniModel
+     with token-level classification capabilities.
+
+     The model adds a classification head on top of the base model's hidden
+     states and applies softmax to produce probability distributions over
+     the label classes for each token.
+
+     Attributes:
+         softmax (torch.nn.Softmax): Softmax layer for probability computation.
+         classifier (torch.nn.Linear): Linear classification head.
+         loss_fn (torch.nn.CrossEntropyLoss): Loss function for training.
+     """
+
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         """
+         Initializes the token classification model.
+
+         Args:
+             config_or_model: Model configuration, pre-trained model path, or model instance.
+             tokenizer: The tokenizer associated with the model.
+             *args: Additional positional arguments.
+             **kwargs: Additional keyword arguments.
+
+         Example:
+             >>> model = OmniModelForTokenClassification("model_path", tokenizer)
+         """
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.softmax = torch.nn.Softmax(dim=-1)
+         self.classifier = torch.nn.Linear(
+             self.config.hidden_size, self.config.num_labels
+         )
+         self.loss_fn = torch.nn.CrossEntropyLoss()
+         self.model_info()
+
+     def forward(self, **inputs):
+         """
+         Forward pass for token classification.
+
+         This method performs the forward pass through the model, computing
+         logits for each token in the input sequence and applying softmax
+         to produce probability distributions.
+
+         Args:
+             **inputs: Input tensors including 'input_ids', 'attention_mask',
+                 and optionally 'labels'.
+
+         Returns:
+             dict: A dictionary containing:
+                 - logits: Token-level classification logits
+                 - last_hidden_state: Final hidden states from the base model
+                 - labels: Ground truth labels (if provided)
+
+         Example:
+             >>> outputs = model(
+             ... input_ids=torch.tensor([[1, 2, 3, 4]]),
+             ... attention_mask=torch.tensor([[1, 1, 1, 1]]),
+             ... labels=torch.tensor([[0, 1, 0, 1]])
+             ... )
+         """
+         labels = inputs.pop("labels", None)
+         last_hidden_state = self.last_hidden_state_forward(**inputs)
+         last_hidden_state = self.dropout(last_hidden_state)
+         last_hidden_state = self.activation(last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+         logits = self.softmax(logits)
+         outputs = {
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+             "labels": labels,
+         }
+         return outputs
+
+     def predict(self, sequence_or_inputs, **kwargs):
+         """
+         Performs token-level prediction on raw inputs.
+
+         This method takes raw sequences or tokenized inputs and returns
+         token-level predictions. It processes the inputs through the model
+         and returns the predicted class for each token.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Predicted class indices for each token
+                 - logits: Raw logits from the model
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Predict on a single sequence
+             >>> outputs = model.predict("ATCGATCG")
+             >>> print(outputs['predictions'].shape) # (seq_len,)
+
+             >>> # Predict on multiple sequences
+             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
+         """
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             predictions.append(logits[i].argmax(dim=-1).detach().cpu())
+
+         outputs = {
+             "predictions": (
+                 torch.vstack(predictions).to(self.model.device)
+                 if predictions[0].shape
+                 else torch.tensor(predictions).to(self.model.device)
+             ),
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+         }
+
+         return outputs
+
+     def inference(self, sequence_or_inputs, **kwargs):
+         """
+         Performs token-level inference with human-readable output.
+
+         This method provides processed, human-readable token-level predictions.
+         It converts logits to class labels and handles special tokens appropriately.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Human-readable class labels for each token
+                 - logits: Raw logits from the model
+                 - confidence: Confidence scores for predictions
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Inference on a single sequence
+             >>> results = model.inference("ATCGATCG")
+             >>> print(results['predictions']) # ['A', 'T', 'C', 'G', ...]
+         """
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+         inputs = raw_outputs["inputs"]
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             # Note that the first and last tokens are removed,
+             # and the length of outputs are calculated based on the tokenized inputs.
+             i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
+                 1:-1
+             ]
+             prediction = [
+                 self.config.id2label.get(x.item(), "") for x in i_logit.argmax(dim=-1)
+             ]
+             predictions.append(prediction)
+
+         if not isinstance(sequence_or_inputs, list):
+             outputs = {
+                 "predictions": predictions[0],
+                 "logits": logits[0],
+                 "confidence": torch.max(logits[0]),
+                 "last_hidden_state": last_hidden_state[0],
+             }
+         else:
+             outputs = {
+                 "predictions": predictions,
+                 "logits": logits,
+                 "confidence": torch.max(logits, dim=-1)[0],
+                 "last_hidden_state": last_hidden_state,
+             }
+
+         return outputs
+
+     def loss_function(self, logits, labels):
+         """
+         Calculates the cross-entropy loss for token classification.
+
+         This method computes the cross-entropy loss between the predicted
+         logits and the ground truth labels, ignoring padding tokens.
+
+         Args:
+             logits (torch.Tensor): Predicted logits from the model.
+             labels (torch.Tensor): Ground truth labels.
+
+         Returns:
+             torch.Tensor: The computed loss value.
+
+         Example:
+             >>> loss = model.loss_function(logits, labels)
+         """
+         loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
+         return loss
+
+
+ class OmniModelForSequenceClassification(OmniModel):
+     """
+     Model for sequence classification tasks in genomics.
+
+     This model is designed for sequence-level classification tasks where
+     the entire input sequence is classified into one of several categories.
+     It extends the base OmniModel with sequence-level classification capabilities.
+
+     The model uses a pooling mechanism to aggregate token-level representations
+     into a sequence-level representation, which is then classified using a
+     linear classifier.
+
+     Attributes:
+         pooler (OmniPooling): Pooling layer for sequence-level representation.
+         softmax (torch.nn.Softmax): Softmax layer for probability computation.
+         classifier (torch.nn.Linear): Linear classification head.
+         loss_fn (torch.nn.CrossEntropyLoss): Loss function for training.
+     """
+
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         """
+         Initializes the sequence classification model.
+
+         Args:
+             config_or_model: Model configuration, pre-trained model path, or model instance.
+             tokenizer: The tokenizer associated with the model.
+             *args: Additional positional arguments.
+             **kwargs: Additional keyword arguments.
+
+         Example:
+             >>> model = OmniModelForSequenceClassification("model_path", tokenizer)
+         """
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.pooler = OmniPooling(self.config)
+         self.softmax = torch.nn.Softmax(dim=-1)
+         self.classifier = torch.nn.Linear(
+             self.config.hidden_size, self.config.num_labels
+         )
+         self.loss_fn = torch.nn.CrossEntropyLoss()
+         self.model_info()
+
+     def forward(self, **inputs):
+         """
+         Forward pass for sequence classification.
+
+         This method performs the forward pass through the model, computing
+         sequence-level logits and applying softmax to produce probability
+         distributions over the label classes.
+
+         Args:
+             **inputs: Input tensors including 'input_ids', 'attention_mask',
+                 and optionally 'labels'.
+
+         Returns:
+             dict: A dictionary containing:
+                 - logits: Sequence-level classification logits
+                 - last_hidden_state: Final hidden states from the base model
+                 - labels: Ground truth labels (if provided)
+
+         Example:
+             >>> outputs = model(
+             ... input_ids=torch.tensor([[1, 2, 3, 4]]),
+             ... attention_mask=torch.tensor([[1, 1, 1, 1]]),
+             ... labels=torch.tensor([0])
+             ... )
+         """
+         labels = inputs.pop("labels", None)
+         last_hidden_state = self.last_hidden_state_forward(**inputs)
+         last_hidden_state = self.dropout(last_hidden_state)
+         last_hidden_state = self.activation(last_hidden_state)
+         last_hidden_state = self.pooler(inputs, last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+         logits = self.softmax(logits)
+         outputs = {
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+             "labels": labels,
+         }
+         return outputs
+
+     def predict(self, sequence_or_inputs, **kwargs):
+         """
+         Performs sequence-level prediction on raw inputs.
+
+         This method takes raw sequences or tokenized inputs and returns
+         sequence-level predictions. It processes the inputs through the model
+         and returns the predicted class for each sequence.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Predicted class indices for each sequence
+                 - logits: Raw logits from the model
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Predict on a single sequence
+             >>> outputs = model.predict("ATCGATCG")
+             >>> print(outputs['predictions']) # tensor([0])
+
+             >>> # Predict on multiple sequences
+             >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
+         """
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             predictions.append(logits[i].argmax(dim=-1))
+
+         outputs = {
+             "predictions": (
+                 torch.vstack(predictions).to(self.model.device)
+                 if predictions[0].shape
+                 else torch.tensor(predictions).to(self.model.device)
+             ),
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+         }
+
+         return outputs
+
+     def inference(self, sequence_or_inputs, **kwargs):
+         """
+         Performs sequence-level inference with human-readable output.
+
+         This method provides processed, human-readable sequence-level predictions.
+         It converts logits to class labels and provides confidence scores.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Human-readable class labels for each sequence
+                 - logits: Raw logits from the model
+                 - confidence: Confidence scores for predictions
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Inference on a single sequence
+             >>> results = model.inference("ATCGATCG")
+             >>> print(results['predictions']) # "positive"
+             >>> print(results['confidence']) # 0.95
+         """
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             predictions.append(
+                 self.config.id2label.get(logits[i].argmax(dim=-1).item(), "")
+             )
+
+         if not isinstance(sequence_or_inputs, list):
+             outputs = {
+                 "predictions": predictions[0],
+                 "logits": logits[0],
+                 "confidence": torch.max(logits[0]),
+                 "last_hidden_state": last_hidden_state[0],
+             }
+         else:
+             outputs = {
+                 "predictions": predictions,
+                 "logits": logits,
+                 "confidence": torch.max(logits, dim=-1)[0],
+                 "last_hidden_state": last_hidden_state,
+             }
+
+         return outputs
+
+     def loss_function(self, logits, labels):
+         """
+         Calculates the cross-entropy loss for sequence classification.
+
+         This method computes the cross-entropy loss between the predicted
+         logits and the ground truth labels.
+
+         Args:
+             logits (torch.Tensor): Predicted logits from the model.
+             labels (torch.Tensor): Ground truth labels.
+
+         Returns:
+             torch.Tensor: The computed loss value.
+
+         Example:
+             >>> loss = model.loss_function(logits, labels)
+         """
+         loss = self.loss_fn(logits.view(-1, self.config.num_labels), labels.view(-1))
+         return loss
+
+
+ class OmniModelForMultiLabelSequenceClassification(
+     OmniModelForSequenceClassification
+ ):
+     """
+     Model for multi-label sequence classification tasks in genomics.
+
+     This model is designed for multi-label classification tasks where
+     a single sequence can be assigned multiple labels simultaneously.
+     It extends the sequence classification model with multi-label capabilities.
+
+     The model uses sigmoid activation instead of softmax to allow multiple
+     labels per sequence and uses binary cross-entropy loss for training.
+
+     Attributes:
+         softmax (torch.nn.Sigmoid): Sigmoid layer for multi-label probability computation.
+         loss_fn (torch.nn.BCELoss): Binary cross-entropy loss function for training.
+     """
+
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         """
+         Initializes the multi-label sequence classification model.
+
+         Args:
+             config_or_model: Model configuration, pre-trained model path, or model instance.
+             tokenizer: The tokenizer associated with the model.
+             *args: Additional positional arguments.
+             **kwargs: Additional keyword arguments.
+
+         Example:
+             >>> model = OmniModelForMultiLabelSequenceClassification("model_path", tokenizer)
+         """
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.softmax = torch.nn.Sigmoid()
+         self.loss_fn = torch.nn.BCELoss()
+         self.model_info()
+
+     def loss_function(self, logits, labels):
+         """
+         Calculates the binary cross-entropy loss for multi-label classification.
+
+         This method computes the binary cross-entropy loss between the predicted
+         probabilities and the ground truth multi-label targets.
+
+         Args:
+             logits (torch.Tensor): Predicted logits from the model.
+             labels (torch.Tensor): Ground truth multi-label targets.
+
+         Returns:
+             torch.Tensor: The computed loss value.
+
+         Example:
+             >>> loss = model.loss_function(logits, labels)
+         """
+         loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
+         return loss
+
+     def predict(self, sequence_or_inputs, **kwargs):
+         """
+         Performs multi-label prediction on raw inputs.
+
+         This method takes raw sequences or tokenized inputs and returns
+         multi-label predictions. It applies a threshold to determine
+         which labels are active for each sequence.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Multi-label predictions for each sequence
+                 - logits: Raw logits from the model
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Predict on a single sequence
+             >>> outputs = model.predict("ATCGATCG")
+             >>> print(outputs['predictions']) # tensor([1, 0, 1, 0])
+         """
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             prediction = logits[i].ge(0.5).to(torch.int).cpu()
+             predictions.append(prediction)
+
+         outputs = {
+             "predictions": (
+                 torch.vstack(predictions).to(self.model.device)
+                 if predictions[0].shape
+                 else torch.tensor(predictions).to(self.model.device)
+             ),
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+         }
+
+         return outputs
+
+     def inference(self, sequence_or_inputs, **kwargs):
+         """
+         Performs multi-label inference with human-readable output.
+
+         This method provides processed, human-readable multi-label predictions.
+         It converts logits to binary labels and provides confidence scores.
+
+         Args:
+             sequence_or_inputs: A sequence (str), list of sequences, or
+                 tokenized inputs (dict/tuple).
+             **kwargs: Additional arguments for tokenization and inference.
+
+         Returns:
+             dict: A dictionary containing:
+                 - predictions: Human-readable binary labels for each sequence
+                 - logits: Raw logits from the model
+                 - confidence: Confidence scores for predictions
+                 - last_hidden_state: Final hidden states
+
+         Example:
+             >>> # Inference on a single sequence
+             >>> results = model.inference("ATCGATCG")
+             >>> print(results['predictions']) # tensor([1, 0, 1, 0])
+         """
+         return self.predict(sequence_or_inputs, **kwargs)
+
+
+ class OmniModelForTokenClassificationWith2DStructure(
+     OmniModelForTokenClassification
+ ):
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.pooler = OmniPooling(self.config)
+         self.model_info()
+
+     def forward(self, **inputs):
+         labels = inputs.pop("labels", None)
+         last_hidden_state = self.last_hidden_state_forward(**inputs)
+         last_hidden_state = self.dropout(last_hidden_state)
+         last_hidden_state = self.activation(last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+         logits = self.softmax(logits)
+         outputs = {
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+             "labels": labels,
+         }
+         return outputs
+
+
+ class OmniModelForSequenceClassificationWith2DStructure(
+     OmniModelForSequenceClassification
+ ):
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.pooler = OmniPooling(self.config)
+         self.model_info()
+
+     def forward(self, **inputs):
+         labels = inputs.pop("labels", None)
+         last_hidden_state = self.last_hidden_state_forward(**inputs)
+         last_hidden_state = self.dropout(last_hidden_state)
+         last_hidden_state = self.activation(last_hidden_state)
+         last_hidden_state = self.pooler(inputs, last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+         logits = self.softmax(logits)
+
+         outputs = {
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+             "labels": labels,
+         }
+         return outputs
+
+
+ class OmniModelForMultiLabelSequenceClassificationWith2DStructure(
+     OmniModelForSequenceClassificationWith2DStructure
+ ):
+     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
+         super().__init__(config_or_model, tokenizer, *args, **kwargs)
+         self.metadata["model_name"] = self.__class__.__name__
+         self.softmax = torch.nn.Sigmoid()
+         self.loss_fn = torch.nn.BCELoss()
+         self.model_info()
+
+     def loss_function(self, logits, labels):
+         loss = self.loss_fn(logits.view(-1), labels.view(-1).to(torch.float32))
+         return loss
+
+     def predict(self, sequence_or_inputs, **kwargs):
+         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
+
+         logits = raw_outputs["logits"]
+         last_hidden_state = raw_outputs["last_hidden_state"]
+
+         predictions = []
+         for i in range(logits.shape[0]):
+             prediction = logits[i].ge(0.5).to(torch.int).cpu()
+             predictions.append(prediction)
+
+         outputs = {
+             "predictions": (
+                 torch.vstack(predictions).to(self.model.device)
+                 if predictions[0].shape
+                 else torch.tensor(predictions).to(self.model.device)
+             ),
+             "logits": logits,
+             "last_hidden_state": last_hidden_state,
+         }
+
+         return outputs
+
+     def inference(self, sequence_or_inputs, **kwargs):
+         return self.predict(sequence_or_inputs, **kwargs)
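
Editor's note: the classes added in this file share the predict/inference pattern described in their docstrings. The following is a minimal usage sketch consolidating those docstring examples; the checkpoint path, the tokenizer object, and the printed values are illustrative assumptions, not values taken from this release.

import torch

from omnigenome.src.model.classification.model import (
    OmniModelForSequenceClassification,
    OmniModelForMultiLabelSequenceClassification,
)

# Placeholder: the paired tokenizer must be constructed separately (for example
# with the classes under omnigenome/src/tokenizer/); its loading API is not
# part of this file and is therefore not shown here.
tokenizer = ...

# Single-label sequence classification: softmax over num_labels classes.
clf = OmniModelForSequenceClassification("path/to/pretrained-model", tokenizer)
result = clf.inference("ATCGATCG")              # human-readable label per sequence
batch = clf.predict(["ATCGATCG", "GCTAGCTA"])   # argmax class indices per sequence
print(result["predictions"], result["confidence"])
print(batch["predictions"])

# Multi-label variant: sigmoid outputs thresholded at 0.5, so several labels
# can be active for one sequence at the same time.
ml = OmniModelForMultiLabelSequenceClassification("path/to/pretrained-model", tokenizer)
print(ml.predict("ATCGATCG")["predictions"])    # 0/1 vector over the label set
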
omnigenome/src/model/embedding/__init__.py
@@ -0,0 +1,12 @@
+ # -*- coding: utf-8 -*-
+ # file: __init__.py
+ # time: 01:51 06/05/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ """
+ This package contains modules for embedding models.
+ """
+