deepchopper-1.3.0-cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepchopper/__init__.py +9 -0
- deepchopper/__init__.pyi +67 -0
- deepchopper/__main__.py +4 -0
- deepchopper/cli.py +260 -0
- deepchopper/data/__init__.py +15 -0
- deepchopper/data/components/__init__.py +1 -0
- deepchopper/data/encode_fq.py +41 -0
- deepchopper/data/fq_datamodule.py +352 -0
- deepchopper/data/hg_data.py +39 -0
- deepchopper/data/only_fq.py +388 -0
- deepchopper/deepchopper.abi3.so +0 -0
- deepchopper/eval.py +86 -0
- deepchopper/models/__init__.py +4 -0
- deepchopper/models/basic_module.py +243 -0
- deepchopper/models/callbacks.py +57 -0
- deepchopper/models/cnn.py +54 -0
- deepchopper/models/components/__init__.py +1 -0
- deepchopper/models/dc_hg.py +163 -0
- deepchopper/models/llm/__init__.py +32 -0
- deepchopper/models/llm/caduceus.py +55 -0
- deepchopper/models/llm/components.py +99 -0
- deepchopper/models/llm/head.py +102 -0
- deepchopper/models/llm/hyena.py +41 -0
- deepchopper/models/llm/metric.py +44 -0
- deepchopper/models/llm/tokenizer.py +205 -0
- deepchopper/models/transformer.py +107 -0
- deepchopper/py.typed +0 -0
- deepchopper/train.py +109 -0
- deepchopper/ui/__init__.py +1 -0
- deepchopper/ui/main.py +189 -0
- deepchopper/utils/__init__.py +37 -0
- deepchopper/utils/instantiators.py +54 -0
- deepchopper/utils/logging_utils.py +53 -0
- deepchopper/utils/preprocess.py +62 -0
- deepchopper/utils/print.py +102 -0
- deepchopper/utils/pylogger.py +57 -0
- deepchopper/utils/rich_utils.py +100 -0
- deepchopper/utils/utils.py +138 -0
- deepchopper-1.3.0.dist-info/METADATA +254 -0
- deepchopper-1.3.0.dist-info/RECORD +43 -0
- deepchopper-1.3.0.dist-info/WHEEL +4 -0
- deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
- deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,243 @@
+from typing import Any
+
+import torch
+from huggingface_hub import PyTorchModelHubMixin
+from lightning import LightningModule
+from torch import nn
+from torchmetrics import MaxMetric, MeanMetric
+from torchmetrics.classification import F1Score, Precision, Recall
+
+
+class ContinuousIntervalLoss(nn.Module):
+    """A custom loss function that penalizes the model for predicting different classes in consecutive positions."""
+
+    def __init__(self, lambda_penalty: float = 0, **kwargs):
+        super().__init__()
+        self.base = torch.nn.CrossEntropyLoss(**kwargs)
+        self.lambda_penalty = lambda_penalty
+
+    @property
+    def ignore_index(self):
+        return self.base.ignore_index
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        loss = self.base(pred, target)
+        if self.lambda_penalty == 0:
+            return loss
+        valid_mask = target != self.ignore_index
+        true_pred = pred.argmax(-1)[valid_mask]
+        true_target = target[valid_mask]
+        penalty = self.lambda_penalty * (true_pred[1:] != true_target[:-1]).float().mean()
+        return loss + penalty
+
+
+class TokenClassificationLit(LightningModule, PyTorchModelHubMixin):
+    """A PyTorch Lightning module for training a token classification model."""
+
+    def __init__(
+        self,
+        net: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        scheduler: torch.optim.lr_scheduler.LRScheduler,
+        criterion: nn.Module,
+        *,
+        compile: bool,
+    ):
+        """Genomics Benchmark CNN model for PyTorch Lightning.
+
+        :param net: The CNN model.
+        :param scheduler: The learning rate scheduler to use for training.
+        """
+        super().__init__()
+
+        self.example_input_array = {
+            "input_ids": torch.randint(0, 11, (1, 1000)),
+            "input_quals": torch.rand(1, 1000),
+        }  # [batch, seq_len]
+
+        # this line allows to access init params with 'self.hparams' attribute
+        # also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False, ignore=["net", "criterion"])
+        self.net = net
+        # loss function
+        self.criterion = criterion
+
+        # metric objects for calculating and averaging accuracy across batches
+        self.train_acc = F1Score(
+            task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+        )
+        self.val_acc = F1Score(
+            task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+        )
+        self.test_acc = F1Score(
+            task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+        )
+
+        self.test_precision = Precision(
+            task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+        )
+        self.test_recall = Recall(
+            task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+        )
+
+        # for averaging loss across batches
+        self.train_loss = MeanMetric()
+        self.val_loss = MeanMetric()
+        self.test_loss = MeanMetric()
+        # for tracking best so far validation accuracy
+        self.val_acc_best = MaxMetric()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        input_quals: torch.Tensor,
+    ) -> torch.Tensor:
+        """Perform a forward pass through the model `self.net`.
+
+        :param x: A tensor of images.
+        :return: A tensor of logits.
+        """
+        return self.net(input_ids, input_quals)
+
+    def on_train_start(self) -> None:
+        """Lightning hook that is called when training begins."""
+        # by default lightning executes validation step sanity checks before training starts,
+        # so it's worth to make sure validation metrics don't store results from these checks
+        self.val_loss.reset()
+        self.val_acc.reset()
+        self.val_acc_best.reset()
+
+    def model_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Perform a single model step on a batch of data.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+        :return: A tuple containing (in order):
+            - A tensor of losses.
+            - A tensor of predictions.
+            - A tensor of target labels.
+        """
+        input_ids = batch["input_ids"]
+        input_quals = batch["input_quals"]
+        logits = self.forward(input_ids, input_quals)
+        loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch["labels"].long().view(-1))
+        preds = torch.argmax(logits, dim=-1)
+        return loss, preds, batch["labels"]
+
+    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """Perform a single training step on a batch of data from the training set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        :return: A tensor of losses between model predictions and targets.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.train_loss(loss)
+        self.train_acc(preds, targets)
+
+        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("train/f1", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        # return loss or backpropagation will fail
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        """Lightning hook that is called when a training epoch ends."""
+
+    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single validation step on a batch of data from the validation set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.val_loss(loss)
+        self.val_acc(preds, targets)
+
+        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("val/f1", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+    def on_validation_epoch_end(self) -> None:
+        """Lightning hook that is called when a validation epoch ends."""
+        acc = self.val_acc.compute()  # get current val acc
+        self.val_acc_best(acc)  # update best so far val acc
+        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+        # otherwise metric would be reset by lightning after each epoch
+        self.log("val/f1_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
+
+    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single test step on a batch of data from the test set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.test_loss(loss)
+        self.test_acc(preds, targets)
+
+        self.test_precision(preds, targets)
+        self.test_recall(preds, targets)
+
+        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/f1", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+    def on_test_epoch_end(self) -> None:
+        """Lightning hook that is called when a test epoch ends."""
+        self.log("test/precision", self.test_precision)
+        self.log("test/recall", self.test_recall)
+
+    def predict_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+        """Perform a single prediction step on a batch of data from the test set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        input_ids = batch["input_ids"]
+        input_quals = batch["input_quals"]
+        logits = self.forward(input_ids, input_quals)
+        return logits, batch["labels"]
+
+    def setup(self, stage: str) -> None:
+        """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
+
+        This is a good hook when you need to build models dynamically or adjust something about
+        them. This hook is called on every process when using DDP.
+
+        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+        """
+        if self.hparams.compile and stage == "fit":
+            self.net = torch.compile(self.net)
+
+    def configure_optimizers(self) -> dict[str, Any]:
+        """Choose what optimizers and learning-rate schedulers to use in your optimization.
+
+        Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+        Examples:
+            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+        """
+        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+
+        if self.hparams.scheduler is not None:
+            scheduler = self.hparams.scheduler(optimizer=optimizer)
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "monitor": "val/loss",
+                    "interval": "epoch",
+                    "frequency": 1,
+                },
+            }
+        return {"optimizer": optimizer}
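
For orientation, a minimal sketch (not part of the package) of exercising ContinuousIntervalLoss on a toy batch. The import path is inferred from the file listing above (deepchopper/models/basic_module.py), and the shapes and lambda_penalty value are illustrative assumptions:

import torch

from deepchopper.models.basic_module import ContinuousIntervalLoss  # path per the file listing

criterion = ContinuousIntervalLoss(lambda_penalty=0.1)
logits = torch.randn(8, 2, requires_grad=True)  # [positions, num_classes]
labels = torch.randint(0, 2, (8,))              # per-position class labels
loss = criterion(logits, labels)                # CE loss + lambda * mean(pred[1:] != target[:-1]) over valid positions
loss.backward()                                 # gradients flow through the cross-entropy term only

With lambda_penalty=0 (the value DeepChopper.from_checkpoint uses later in this diff) the loss reduces to plain torch.nn.CrossEntropyLoss.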
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+import torch
+from lightning.pytorch.callbacks import BasePredictionWriter
+
+
+class CustomWriter(BasePredictionWriter):
+    def __init__(self, output_dir, write_interval="epoch"):
+        super().__init__(write_interval)
+        self.output_dir = Path(output_dir)
+
+    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
+        folder = self.output_dir / str(dataloader_idx)
+        if not folder.exists():
+            folder.mkdir(parents=True, exist_ok=True)
+
+        save_prediction = {
+            "prediction": prediction[0].cpu(),
+            "target": prediction[1].to(torch.int64).cpu(),
+            "seq": batch["input_ids"].cpu(),
+            "qual": batch["input_quals"].cpu(),
+            "id": batch["id"].to(torch.int64).cpu(),
+        }
+
+        torch.save(save_prediction, folder / f"{trainer.global_rank}_{batch_idx}.pt")
+
+    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+        # WARN: This is a simple implementation that saves all predictions in a single file
+        if not self.output_dir.exists():
+            self.output_dir.mkdir(parents=False, exist_ok=True)
+
+        torch.save(predictions, self.output_dir / "predictions.pt")
+
+
+class PredictionWriter(BasePredictionWriter):
+    def __init__(self, output_dir, write_interval="epoch"):
+        super().__init__(write_interval)
+        self.output_dir = Path(output_dir)
+
+    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
+        folder = self.output_dir / str(dataloader_idx)
+        if not folder.exists():
+            folder.mkdir(parents=True, exist_ok=True)
+
+        save_prediction = {
+            "prediction": prediction[0].cpu(),
+            "id": batch["id"].to(torch.int64).cpu(),
+        }
+
+        torch.save(save_prediction, folder / f"{trainer.global_rank}_{batch_idx}.pt")
+
+    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+        # WARN: This is a simple implementation that saves all predictions in a single file
+        if not self.output_dir.exists():
+            self.output_dir.mkdir(parents=False, exist_ok=True)
+
+        torch.save(predictions, self.output_dir / "predictions.pt")
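
A minimal sketch (assumed usage, not taken from the package) of attaching one of these prediction writers to a Lightning Trainer; the output directory, model, and datamodule are placeholders, and the import path is inferred from the file listing (deepchopper/models/callbacks.py):

from lightning import Trainer

from deepchopper.models.callbacks import CustomWriter  # path per the file listing

# write_interval="batch" routes predictions through write_on_batch_end,
# saving one {rank}_{batch_idx}.pt file per batch under output_dir/<dataloader_idx>/
writer = CustomWriter(output_dir="predictions", write_interval="batch")
trainer = Trainer(accelerator="auto", callbacks=[writer])
# trainer.predict(model, datamodule=datamodule, return_predictions=False)  # model/datamodule are placeholders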
@@ -0,0 +1,54 @@
+import torch
+import torch.nn.functional as F  # noqa: N812
+from torch import nn
+
+
+class BenchmarkCNN(nn.Module):
+    """BenchmarkCNN."""
+
+    def __init__(self, number_of_classes, vocab_size, num_filters, filter_sizes, embedding_dim=100):
+        """Genomics Benchmark CNN model.
+
+        `embedding_dim` = 100 comes from:
+        https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
+        """
+        super().__init__()
+        self.number_of_classes = number_of_classes
+        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
+
+        self.qual_linear1 = nn.Sequential(
+            nn.Linear(1, embedding_dim),
+        )
+
+        layers = []
+        in_channels = embedding_dim
+        for idx, fs in enumerate(filter_sizes):
+            layers.append(
+                nn.Conv1d(
+                    in_channels=in_channels,
+                    out_channels=num_filters[idx],
+                    kernel_size=fs,
+                    padding="same",
+                )
+            )
+            layers.append(nn.BatchNorm1d(num_filters[idx]))
+            layers.append(nn.ReLU())
+            in_channels = num_filters[idx]
+
+        self.model = nn.Sequential(*layers)
+
+        # use number of kernel same as the length of the sequence and average pooling
+        # then flatten and use dense layers
+        self.dense_model = nn.Sequential(
+            nn.Linear(in_channels, number_of_classes),
+        )
+
+    def forward(
+        self, input_ids: torch.Tensor, input_quals: torch.Tensor
+    ):  # Adding `state` to be consistent with other models
+        x = self.embeddings(input_ids)
+        x = F.relu(x + self.qual_linear1(input_quals.unsqueeze(-1)))
+        x = x.transpose(1, 2)  # [batch_size, embedding_dim, input_len]
+        x = self.model(x)
+        x = x.transpose(1, 2)
+        return self.dense_model(x)  # [batch_size, input_len, num_filters]
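
A minimal sketch showing the tensor shapes BenchmarkCNN expects and returns; the hyperparameters are illustrative assumptions rather than package defaults, and the import path follows the file listing (deepchopper/models/cnn.py):

import torch

from deepchopper.models.cnn import BenchmarkCNN  # path per the file listing

model = BenchmarkCNN(number_of_classes=2, vocab_size=12, num_filters=[32, 32], filter_sizes=[3, 5])
input_ids = torch.randint(0, 12, (4, 1000))  # [batch, seq_len] token ids
input_quals = torch.rand(4, 1000)            # per-base quality values
logits = model(input_ids, input_quals)       # [4, 1000, 2] per-position class logits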
@@ -0,0 +1 @@
+"""model components."""
@@ -0,0 +1,163 @@
+from functools import partial
+
+import torch
+
+from . import basic_module, llm
+from .basic_module import TokenClassificationLit
+
+
+class DeepChopper:
+    """DeepChopper: A genomic language model for chimera artifact detection.
+
+    This class provides convenient methods to load DeepChopper models from checkpoints or
+    from pretrained models on the Hugging Face Hub, and to push trained models to the Hub.
+
+    Example:
+        Load a pretrained model:
+        >>> model = DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+        Load from a local checkpoint:
+        >>> model = DeepChopper.from_checkpoint("path/to/checkpoint.ckpt")
+
+        Push a model to Hugging Face Hub:
+        >>> model = DeepChopper.to_hub("username/model-name", "path/to/checkpoint.ckpt")
+    """
+
+    @staticmethod
+    def to_hub(
+        model_name: str,
+        checkpoint_path: str,
+        *,
+        commit_message: str = "Upload DeepChopper model",
+        private: bool = False,
+        token: str | None = None,
+    ):
+        """Load a model from a checkpoint and push it to the Hugging Face Hub.
+
+        Args:
+            model_name: The repository ID on Hugging Face Hub (format: username/model-name)
+            checkpoint_path: Path to the local checkpoint file (.ckpt)
+            commit_message: Commit message for the upload (default: "Upload DeepChopper model")
+            private: Whether to create a private repository (default: False)
+            token: Hugging Face API token. If None, uses the stored token from `huggingface-cli login`
+
+        Returns:
+            The loaded TokenClassificationLit model
+
+        Example:
+            >>> model = DeepChopper.to_hub(
+            ...     "username/deepchopper-v1",
+            ...     "epoch_012_f1_0.9947.ckpt",
+            ...     commit_message="Upload DeepChopper v1.0",
+            ...     private=False
+            ... )
+        """
+        model = DeepChopper.from_checkpoint(checkpoint_path)
+
+        # Prepare kwargs for push_to_hub
+        push_kwargs = {
+            "repo_id": model_name,
+            "commit_message": commit_message,
+            "private": private,
+        }
+
+        if token is not None:
+            push_kwargs["token"] = token
+
+        model.push_to_hub(**push_kwargs)
+        return model
+
+    @staticmethod
+    def from_checkpoint(checkpoint_path: str):
+        """Load a DeepChopper model from a local checkpoint file.
+
+        This method creates a new TokenClassificationLit model with the HyenaDNA backbone
+        and loads the weights from the specified checkpoint file.
+
+        Args:
+            checkpoint_path: Path to the checkpoint file (.ckpt) containing model weights
+
+        Returns:
+            A TokenClassificationLit model loaded with the checkpoint weights
+
+        Example:
+            >>> model = DeepChopper.from_checkpoint("epoch_012_f1_0.9947.ckpt")
+
+        Note:
+            This function loads checkpoints with weights_only=False to support PyTorch Lightning
+            checkpoints containing optimizer/scheduler configs. Only load checkpoints from trusted sources.
+        """
+        model = TokenClassificationLit(
+            net=llm.hyena.TokenClassificationModule(
+                number_of_classes=2,
+                backbone_name="hyenadna-small-32k-seqlen",
+                freeze_backbone=False,
+                head=llm.TokenClassificationHead(
+                    input_size=256,
+                    lin1_size=1024,
+                    lin2_size=1024,
+                    num_class=2,
+                    use_identity_layer_for_qual=True,
+                    use_qual=True,
+                ),
+            ),
+            optimizer=partial(  # type: ignore[arg-type]
+                torch.optim.Adam,
+                lr=0.0002,
+                weight_decay=0,
+            ),
+            scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.1, patience=10),  # type: ignore[arg-type]
+            criterion=basic_module.ContinuousIntervalLoss(lambda_penalty=0),
+            compile=False,
+        )
+        # weights_only=False is required for PyTorch Lightning checkpoints that contain
+        # optimizer/scheduler configs (functools.partial). Only use with trusted checkpoints.
+        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+        model.load_state_dict(checkpoint["state_dict"])
+        return model
+
+    @staticmethod
+    def from_pretrained(model_name: str):
+        """Load a pretrained DeepChopper model from the Hugging Face Hub.
+
+        This method downloads and loads a pretrained model from the Hugging Face Hub.
+        The model architecture is automatically configured to match the expected
+        HyenaDNA-based token classification setup.
+
+        Args:
+            model_name: The repository ID on Hugging Face Hub (e.g., "yangliz5/deepchopper")
+
+        Returns:
+            A TokenClassificationLit model loaded with pretrained weights
+
+        Example:
+            >>> model = DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+        Note:
+            This requires an internet connection to download the model from Hugging Face Hub.
+            The model will be cached locally after the first download.
+        """
+        return TokenClassificationLit.from_pretrained(
+            model_name,
+            net=llm.hyena.TokenClassificationModule(
+                number_of_classes=2,
+                backbone_name="hyenadna-small-32k-seqlen",
+                freeze_backbone=False,
+                head=llm.TokenClassificationHead(
+                    input_size=256,
+                    lin1_size=1024,
+                    lin2_size=1024,
+                    num_class=2,
+                    use_identity_layer_for_qual=True,
+                    use_qual=True,
+                ),
+            ),
+            optimizer=partial(  # type: ignore[arg-type]
+                torch.optim.Adam,
+                lr=0.0002,
+                weight_decay=0,
+            ),
+            scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.1, patience=10),  # type: ignore[arg-type]
+            criterion=basic_module.ContinuousIntervalLoss(lambda_penalty=0),
+            compile=False,
+        )
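
A minimal end-to-end sketch (assumptions marked) of loading the published weights and running a forward pass. The random tensors stand in for the package's FASTQ encoding and tokenization pipeline, the import path is inferred from the file listing (deepchopper/models/dc_hg.py), and the download requires network access:

import torch

from deepchopper.models.dc_hg import DeepChopper  # path per the file listing

model = DeepChopper.from_pretrained("yangliz5/deepchopper")
model.eval()
with torch.no_grad():
    input_ids = torch.randint(0, 11, (1, 1000))  # placeholder ids, mirroring example_input_array above
    input_quals = torch.rand(1, 1000)            # placeholder per-base quality values
    logits = model(input_ids, input_quals)       # per-position logits from the HyenaDNA-backed net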
@@ -0,0 +1,32 @@
+"""hyena model components and utilities."""
+
+from .components import (
+    HyenadnaMaxLengths,
+    TokenClassification,
+    TokenClassificationConfig,
+)
+from .head import TokenClassificationHead
+from .hyena import TokenClassificationModule
+from .metric import IGNORE_INDEX, compute_metrics
+from .tokenizer import (
+    DataCollatorForTokenClassificationWithQual,
+    load_tokenizer_from_hyena_model,
+    tokenize_and_align_labels_and_quals,
+    tokenize_and_align_labels_and_quals_ids,
+    tokenize_dataset,
+)
+
+__all__ = [
+    "IGNORE_INDEX",
+    "DataCollatorForTokenClassificationWithQual",
+    "HyenadnaMaxLengths",
+    "TokenClassification",
+    "TokenClassificationConfig",
+    "TokenClassificationHead",
+    "TokenClassificationModule",
+    "compute_metrics",
+    "load_tokenizer_from_hyena_model",
+    "tokenize_and_align_labels_and_quals",
+    "tokenize_and_align_labels_and_quals_ids",
+    "tokenize_dataset",
+]
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from transformers import AutoModel
+
+BACKBONES = [
+    "hyenadna-tiny-1k-seqlen",
+    "hyenadna-small-32k-seqlen",
+    "hyenadna-medium-160k-seqlen",
+    "hyenadna-medium-450k-seqlen",
+    "hyenadna-large-1m-seqlen",
+    "caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
+    "caduceus-ps_seqlen-131k_d_model-256_n_layer-16",
+]
+
+# https://github.com/kuleshov-group/caduceus
+
+
+class TokenClassificationModule(nn.Module):
+    """Token classification model."""
+
+    def __init__(
+        self,
+        number_of_classes: int,
+        head: nn.Module,
+        backbone_name: str = "caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
+    ):
+        super().__init__()
+        self.number_of_classes = number_of_classes
+        self.backbone_name = backbone_name
+
+        if "hyenadna" in backbone_name:
+            model_name = f"LongSafari/{backbone_name}-hf"
+        elif "caduceus" in backbone_name:
+            model_name = f"kuleshov-group/{backbone_name}"
+        else:
+            msg = f"Unknown backbone model: {backbone_name}"
+            raise ValueError(msg)
+
+        self.backbone = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+        self.head = head
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        input_quals: torch.Tensor,
+    ):
+        transformer_outputs = self.backbone(
+            input_ids,
+            inputs_embeds=None,
+            output_hidden_states=None,
+            return_dict=None,
+        )
+
+        hidden_states = transformer_outputs[0]
+        return self.head(hidden_states, input_quals)
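
The backbone selection in __init__ above maps short names to Hugging Face repositories. A standalone restatement of that rule (resolve_backbone is a hypothetical helper used only for illustration, not a package function):

def resolve_backbone(backbone_name: str) -> str:
    # Hypothetical helper mirroring the mapping in TokenClassificationModule.__init__ above
    if "hyenadna" in backbone_name:
        return f"LongSafari/{backbone_name}-hf"
    if "caduceus" in backbone_name:
        return f"kuleshov-group/{backbone_name}"
    msg = f"Unknown backbone model: {backbone_name}"
    raise ValueError(msg)


assert resolve_backbone("hyenadna-small-32k-seqlen") == "LongSafari/hyenadna-small-32k-seqlen-hf"
assert resolve_backbone("caduceus-ps_seqlen-131k_d_model-256_n_layer-16") == "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"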