deepchopper 1.3.0__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. deepchopper/__init__.py +9 -0
  2. deepchopper/__init__.pyi +67 -0
  3. deepchopper/__main__.py +4 -0
  4. deepchopper/cli.py +260 -0
  5. deepchopper/data/__init__.py +15 -0
  6. deepchopper/data/components/__init__.py +1 -0
  7. deepchopper/data/encode_fq.py +41 -0
  8. deepchopper/data/fq_datamodule.py +352 -0
  9. deepchopper/data/hg_data.py +39 -0
  10. deepchopper/data/only_fq.py +388 -0
  11. deepchopper/deepchopper.abi3.so +0 -0
  12. deepchopper/eval.py +86 -0
  13. deepchopper/models/__init__.py +4 -0
  14. deepchopper/models/basic_module.py +243 -0
  15. deepchopper/models/callbacks.py +57 -0
  16. deepchopper/models/cnn.py +54 -0
  17. deepchopper/models/components/__init__.py +1 -0
  18. deepchopper/models/dc_hg.py +163 -0
  19. deepchopper/models/llm/__init__.py +32 -0
  20. deepchopper/models/llm/caduceus.py +55 -0
  21. deepchopper/models/llm/components.py +99 -0
  22. deepchopper/models/llm/head.py +102 -0
  23. deepchopper/models/llm/hyena.py +41 -0
  24. deepchopper/models/llm/metric.py +44 -0
  25. deepchopper/models/llm/tokenizer.py +205 -0
  26. deepchopper/models/transformer.py +107 -0
  27. deepchopper/py.typed +0 -0
  28. deepchopper/train.py +109 -0
  29. deepchopper/ui/__init__.py +1 -0
  30. deepchopper/ui/main.py +189 -0
  31. deepchopper/utils/__init__.py +37 -0
  32. deepchopper/utils/instantiators.py +54 -0
  33. deepchopper/utils/logging_utils.py +53 -0
  34. deepchopper/utils/preprocess.py +62 -0
  35. deepchopper/utils/print.py +102 -0
  36. deepchopper/utils/pylogger.py +57 -0
  37. deepchopper/utils/rich_utils.py +100 -0
  38. deepchopper/utils/utils.py +138 -0
  39. deepchopper-1.3.0.dist-info/METADATA +254 -0
  40. deepchopper-1.3.0.dist-info/RECORD +43 -0
  41. deepchopper-1.3.0.dist-info/WHEEL +4 -0
  42. deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
  43. deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
deepchopper/models/basic_module.py
@@ -0,0 +1,243 @@
+ from typing import Any
+
+ import torch
+ from huggingface_hub import PyTorchModelHubMixin
+ from lightning import LightningModule
+ from torch import nn
+ from torchmetrics import MaxMetric, MeanMetric
+ from torchmetrics.classification import F1Score, Precision, Recall
+
+
+ class ContinuousIntervalLoss(nn.Module):
+     """A custom loss function that penalizes the model for predicting different classes in consecutive positions."""
+
+     def __init__(self, lambda_penalty: float = 0, **kwargs):
+         super().__init__()
+         self.base = torch.nn.CrossEntropyLoss(**kwargs)
+         self.lambda_penalty = lambda_penalty
+
+     @property
+     def ignore_index(self):
+         return self.base.ignore_index
+
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         loss = self.base(pred, target)
+         if self.lambda_penalty == 0:
+             return loss
+         valid_mask = target != self.ignore_index
+         true_pred = pred.argmax(-1)[valid_mask]
+         true_target = target[valid_mask]
+         penalty = self.lambda_penalty * (true_pred[1:] != true_target[:-1]).float().mean()
+         return loss + penalty
+
+
+ class TokenClassificationLit(LightningModule, PyTorchModelHubMixin):
+     """A PyTorch Lightning module for training a token classification model."""
+
+     def __init__(
+         self,
+         net: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         criterion: nn.Module,
+         *,
+         compile: bool,
+     ):
+         """Token classification model for PyTorch Lightning.
+
+         :param net: The backbone network to train.
+         :param optimizer: The optimizer to use for training.
+         :param scheduler: The learning rate scheduler to use for training.
+         :param criterion: The loss function.
+         :param compile: Whether to compile the model with `torch.compile` before fitting.
+         """
+         super().__init__()
+
+         self.example_input_array = {
+             "input_ids": torch.randint(0, 11, (1, 1000)),
+             "input_quals": torch.rand(1, 1000),
+         }  # [batch, seq_len]
+
+         # this line allows to access init params with 'self.hparams' attribute
+         # also ensures init params will be stored in ckpt
+         self.save_hyperparameters(logger=False, ignore=["net", "criterion"])
+         self.net = net
+         # loss function
+         self.criterion = criterion
+
+         # metric objects for calculating and averaging accuracy across batches
+         self.train_acc = F1Score(
+             task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+         )
+         self.val_acc = F1Score(
+             task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+         )
+         self.test_acc = F1Score(
+             task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+         )
+
+         self.test_precision = Precision(
+             task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+         )
+         self.test_recall = Recall(
+             task="binary", num_classes=net.number_of_classes, ignore_index=self.criterion.ignore_index
+         )
+
+         # for averaging loss across batches
+         self.train_loss = MeanMetric()
+         self.val_loss = MeanMetric()
+         self.test_loss = MeanMetric()
+         # for tracking best so far validation accuracy
+         self.val_acc_best = MaxMetric()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         input_quals: torch.Tensor,
+     ) -> torch.Tensor:
+         """Perform a forward pass through the model `self.net`.
+
+         :param input_ids: A tensor of tokenized base IDs.
+         :param input_quals: A tensor of per-base quality scores.
+         :return: A tensor of logits.
+         """
+         return self.net(input_ids, input_quals)
+
+     def on_train_start(self) -> None:
+         """Lightning hook that is called when training begins."""
+         # by default lightning executes validation step sanity checks before training starts,
+         # so it's worth to make sure validation metrics don't store results from these checks
+         self.val_loss.reset()
+         self.val_acc.reset()
+         self.val_acc_best.reset()
+
+     def model_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Perform a single model step on a batch of data.
+
+         :param batch: A batch of data (a dict) containing the input tensors and target labels.
+         :return: A tuple containing (in order):
+             - A tensor of losses.
+             - A tensor of predictions.
+             - A tensor of target labels.
+         """
+         input_ids = batch["input_ids"]
+         input_quals = batch["input_quals"]
+         logits = self.forward(input_ids, input_quals)
+         loss = self.criterion(logits.reshape(-1, logits.size(-1)), batch["labels"].long().view(-1))
+         preds = torch.argmax(logits, dim=-1)
+         return loss, preds, batch["labels"]
+
+     def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """Perform a single training step on a batch of data from the training set.
+
+         :param batch: A batch of data (a dict) containing the input tensors and target labels.
+         :param batch_idx: The index of the current batch.
+         :return: A tensor of losses between model predictions and targets.
+         """
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.train_loss(loss)
+         self.train_acc(preds, targets)
+
+         self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("train/f1", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+         # return loss or backpropagation will fail
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         """Lightning hook that is called when a training epoch ends."""
+
+     def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
+         """Perform a single validation step on a batch of data from the validation set.
+
+         :param batch: A batch of data (a dict) containing the input tensors and target labels.
+         :param batch_idx: The index of the current batch.
+         """
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.val_loss(loss)
+         self.val_acc(preds, targets)
+
+         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("val/f1", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+     def on_validation_epoch_end(self) -> None:
+         """Lightning hook that is called when a validation epoch ends."""
+         acc = self.val_acc.compute()  # get current val acc
+         self.val_acc_best(acc)  # update best so far val acc
+         # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+         # otherwise metric would be reset by lightning after each epoch
+         self.log("val/f1_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
+
+     def test_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
+         """Perform a single test step on a batch of data from the test set.
+
+         :param batch: A batch of data (a dict) containing the input tensors and target labels.
+         :param batch_idx: The index of the current batch.
+         """
+         loss, preds, targets = self.model_step(batch)
+
+         # update and log metrics
+         self.test_loss(loss)
+         self.test_acc(preds, targets)
+
+         self.test_precision(preds, targets)
+         self.test_recall(preds, targets)
+
+         self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log("test/f1", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+     def on_test_epoch_end(self) -> None:
+         """Lightning hook that is called when a test epoch ends."""
+         self.log("test/precision", self.test_precision)
+         self.log("test/recall", self.test_recall)
+
+     def predict_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+         """Perform a single prediction step on a batch of data from the test set.
+
+         :param batch: A batch of data (a dict) containing the input tensors and target labels.
+         :param batch_idx: The index of the current batch.
+         """
+         input_ids = batch["input_ids"]
+         input_quals = batch["input_quals"]
+         logits = self.forward(input_ids, input_quals)
+         return logits, batch["labels"]
+
+     def setup(self, stage: str) -> None:
+         """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
+
+         This is a good hook when you need to build models dynamically or adjust something about
+         them. This hook is called on every process when using DDP.
+
+         :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+         """
+         if self.hparams.compile and stage == "fit":
+             self.net = torch.compile(self.net)
+
+     def configure_optimizers(self) -> dict[str, Any]:
+         """Choose what optimizers and learning-rate schedulers to use in your optimization.
+
+         Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+         Examples:
+             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+         :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+         """
+         optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+
+         if self.hparams.scheduler is not None:
+             scheduler = self.hparams.scheduler(optimizer=optimizer)
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": scheduler,
+                     "monitor": "val/loss",
+                     "interval": "epoch",
+                     "frequency": 1,
+                 },
+             }
+         return {"optimizer": optimizer}
deepchopper/models/callbacks.py
@@ -0,0 +1,57 @@
+ from pathlib import Path
+
+ import torch
+ from lightning.pytorch.callbacks import BasePredictionWriter
+
+
+ class CustomWriter(BasePredictionWriter):
+     def __init__(self, output_dir, write_interval="epoch"):
+         super().__init__(write_interval)
+         self.output_dir = Path(output_dir)
+
+     def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
+         folder = self.output_dir / str(dataloader_idx)
+         if not folder.exists():
+             folder.mkdir(parents=True, exist_ok=True)
+
+         save_prediction = {
+             "prediction": prediction[0].cpu(),
+             "target": prediction[1].to(torch.int64).cpu(),
+             "seq": batch["input_ids"].cpu(),
+             "qual": batch["input_quals"].cpu(),
+             "id": batch["id"].to(torch.int64).cpu(),
+         }
+
+         torch.save(save_prediction, folder / f"{trainer.global_rank}_{batch_idx}.pt")
+
+     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+         # WARN: This is a simple implementation that saves all predictions in a single file
+         if not self.output_dir.exists():
+             self.output_dir.mkdir(parents=False, exist_ok=True)
+
+         torch.save(predictions, self.output_dir / "predictions.pt")
+
+
+ class PredictionWriter(BasePredictionWriter):
+     def __init__(self, output_dir, write_interval="epoch"):
+         super().__init__(write_interval)
+         self.output_dir = Path(output_dir)
+
+     def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
+         folder = self.output_dir / str(dataloader_idx)
+         if not folder.exists():
+             folder.mkdir(parents=True, exist_ok=True)
+
+         save_prediction = {
+             "prediction": prediction[0].cpu(),
+             "id": batch["id"].to(torch.int64).cpu(),
+         }
+
+         torch.save(save_prediction, folder / f"{trainer.global_rank}_{batch_idx}.pt")
+
+     def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+         # WARN: This is a simple implementation that saves all predictions in a single file
+         if not self.output_dir.exists():
+             self.output_dir.mkdir(parents=False, exist_ok=True)
+
+         torch.save(predictions, self.output_dir / "predictions.pt")
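
A sketch of wiring CustomWriter into a Lightning Trainer; the model and prediction dataloader are placeholders assumed to exist elsewhere. With write_interval="epoch", only write_on_epoch_end fires and everything is written to predictions/predictions.pt:

    from lightning import Trainer

    from deepchopper.models.callbacks import CustomWriter

    writer = CustomWriter(output_dir="predictions", write_interval="epoch")
    trainer = Trainer(accelerator="auto", callbacks=[writer])
    # trainer.predict(model, dataloaders=predict_loader)  # model and predict_loader are placeholders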
deepchopper/models/cnn.py
@@ -0,0 +1,54 @@
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ from torch import nn
+
+
+ class BenchmarkCNN(nn.Module):
+     """BenchmarkCNN."""
+
+     def __init__(self, number_of_classes, vocab_size, num_filters, filter_sizes, embedding_dim=100):
+         """Genomics Benchmark CNN model.
+
+         `embedding_dim` = 100 comes from:
+         https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments/torch_cnn_experiments
+         """
+         super().__init__()
+         self.number_of_classes = number_of_classes
+         self.embeddings = nn.Embedding(vocab_size, embedding_dim)
+
+         self.qual_linear1 = nn.Sequential(
+             nn.Linear(1, embedding_dim),
+         )
+
+         layers = []
+         in_channels = embedding_dim
+         for idx, fs in enumerate(filter_sizes):
+             layers.append(
+                 nn.Conv1d(
+                     in_channels=in_channels,
+                     out_channels=num_filters[idx],
+                     kernel_size=fs,
+                     padding="same",
+                 )
+             )
+             layers.append(nn.BatchNorm1d(num_filters[idx]))
+             layers.append(nn.ReLU())
+             in_channels = num_filters[idx]
+
+         self.model = nn.Sequential(*layers)
+
+         # use a number of kernels matching the sequence length with average pooling,
+         # then flatten and use dense layers
+         self.dense_model = nn.Sequential(
+             nn.Linear(in_channels, number_of_classes),
+         )
+
+     def forward(
+         self, input_ids: torch.Tensor, input_quals: torch.Tensor
+     ):  # signature kept consistent with the other models
+         x = self.embeddings(input_ids)
+         x = F.relu(x + self.qual_linear1(input_quals.unsqueeze(-1)))
+         x = x.transpose(1, 2)  # [batch_size, embedding_dim, input_len]
+         x = self.model(x)
+         x = x.transpose(1, 2)
+         return self.dense_model(x)  # [batch_size, input_len, number_of_classes]
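
A forward-pass sketch for BenchmarkCNN; the sizes are illustrative, and num_filters needs one entry per filter size:

    import torch

    from deepchopper.models.cnn import BenchmarkCNN

    model = BenchmarkCNN(number_of_classes=2, vocab_size=12, num_filters=[64, 64], filter_sizes=[5, 9])
    input_ids = torch.randint(0, 12, (4, 1000))  # (batch, seq_len) token IDs
    input_quals = torch.rand(4, 1000)            # per-base quality scores
    logits = model(input_ids, input_quals)       # (batch, seq_len, number_of_classes)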
deepchopper/models/components/__init__.py
@@ -0,0 +1 @@
+ """model components."""
deepchopper/models/dc_hg.py
@@ -0,0 +1,163 @@
+ from functools import partial
+
+ import torch
+
+ from . import basic_module, llm
+ from .basic_module import TokenClassificationLit
+
+
+ class DeepChopper:
+     """DeepChopper: A genomic language model for chimera artifact detection.
+
+     This class provides convenient methods to load DeepChopper models from checkpoints or
+     from pretrained models on the Hugging Face Hub, and to push trained models to the Hub.
+
+     Example:
+         Load a pretrained model:
+         >>> model = DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+         Load from a local checkpoint:
+         >>> model = DeepChopper.from_checkpoint("path/to/checkpoint.ckpt")
+
+         Push a model to Hugging Face Hub:
+         >>> model = DeepChopper.to_hub("username/model-name", "path/to/checkpoint.ckpt")
+     """
+
+     @staticmethod
+     def to_hub(
+         model_name: str,
+         checkpoint_path: str,
+         *,
+         commit_message: str = "Upload DeepChopper model",
+         private: bool = False,
+         token: str | None = None,
+     ):
+         """Load a model from a checkpoint and push it to the Hugging Face Hub.
+
+         Args:
+             model_name: The repository ID on Hugging Face Hub (format: username/model-name)
+             checkpoint_path: Path to the local checkpoint file (.ckpt)
+             commit_message: Commit message for the upload (default: "Upload DeepChopper model")
+             private: Whether to create a private repository (default: False)
+             token: Hugging Face API token. If None, uses the stored token from `huggingface-cli login`
+
+         Returns:
+             The loaded TokenClassificationLit model
+
+         Example:
+             >>> model = DeepChopper.to_hub(
+             ...     "username/deepchopper-v1",
+             ...     "epoch_012_f1_0.9947.ckpt",
+             ...     commit_message="Upload DeepChopper v1.0",
+             ...     private=False
+             ... )
+         """
+         model = DeepChopper.from_checkpoint(checkpoint_path)
+
+         # Prepare kwargs for push_to_hub
+         push_kwargs = {
+             "repo_id": model_name,
+             "commit_message": commit_message,
+             "private": private,
+         }
+
+         if token is not None:
+             push_kwargs["token"] = token
+
+         model.push_to_hub(**push_kwargs)
+         return model
+
+     @staticmethod
+     def from_checkpoint(checkpoint_path: str):
+         """Load a DeepChopper model from a local checkpoint file.
+
+         This method creates a new TokenClassificationLit model with the HyenaDNA backbone
+         and loads the weights from the specified checkpoint file.
+
+         Args:
+             checkpoint_path: Path to the checkpoint file (.ckpt) containing model weights
+
+         Returns:
+             A TokenClassificationLit model loaded with the checkpoint weights
+
+         Example:
+             >>> model = DeepChopper.from_checkpoint("epoch_012_f1_0.9947.ckpt")
+
+         Note:
+             This function loads checkpoints with weights_only=False to support PyTorch Lightning
+             checkpoints containing optimizer/scheduler configs. Only load checkpoints from trusted sources.
+         """
+         model = TokenClassificationLit(
+             net=llm.hyena.TokenClassificationModule(
+                 number_of_classes=2,
+                 backbone_name="hyenadna-small-32k-seqlen",
+                 freeze_backbone=False,
+                 head=llm.TokenClassificationHead(
+                     input_size=256,
+                     lin1_size=1024,
+                     lin2_size=1024,
+                     num_class=2,
+                     use_identity_layer_for_qual=True,
+                     use_qual=True,
+                 ),
+             ),
+             optimizer=partial(  # type: ignore[arg-type]
+                 torch.optim.Adam,
+                 lr=0.0002,
+                 weight_decay=0,
+             ),
+             scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.1, patience=10),  # type: ignore[arg-type]
+             criterion=basic_module.ContinuousIntervalLoss(lambda_penalty=0),
+             compile=False,
+         )
+         # weights_only=False is required for PyTorch Lightning checkpoints that contain
+         # optimizer/scheduler configs (functools.partial). Only use with trusted checkpoints.
+         checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+         model.load_state_dict(checkpoint["state_dict"])
+         return model
+
+     @staticmethod
+     def from_pretrained(model_name: str):
+         """Load a pretrained DeepChopper model from the Hugging Face Hub.
+
+         This method downloads and loads a pretrained model from the Hugging Face Hub.
+         The model architecture is automatically configured to match the expected
+         HyenaDNA-based token classification setup.
+
+         Args:
+             model_name: The repository ID on Hugging Face Hub (e.g., "yangliz5/deepchopper")
+
+         Returns:
+             A TokenClassificationLit model loaded with pretrained weights
+
+         Example:
+             >>> model = DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+         Note:
+             This requires an internet connection to download the model from Hugging Face Hub.
+             The model will be cached locally after the first download.
+         """
+         return TokenClassificationLit.from_pretrained(
+             model_name,
+             net=llm.hyena.TokenClassificationModule(
+                 number_of_classes=2,
+                 backbone_name="hyenadna-small-32k-seqlen",
+                 freeze_backbone=False,
+                 head=llm.TokenClassificationHead(
+                     input_size=256,
+                     lin1_size=1024,
+                     lin2_size=1024,
+                     num_class=2,
+                     use_identity_layer_for_qual=True,
+                     use_qual=True,
+                 ),
+             ),
+             optimizer=partial(  # type: ignore[arg-type]
+                 torch.optim.Adam,
+                 lr=0.0002,
+                 weight_decay=0,
+             ),
+             scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.1, patience=10),  # type: ignore[arg-type]
+             criterion=basic_module.ContinuousIntervalLoss(lambda_penalty=0),
+             compile=False,
+         )
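
A prediction sketch using the pretrained checkpoint named in the docstring above. It downloads the HyenaDNA backbone and model weights on first use; the toy inputs mirror example_input_array in basic_module.py:

    import torch

    from deepchopper.models.dc_hg import DeepChopper

    model = DeepChopper.from_pretrained("yangliz5/deepchopper").eval()
    input_ids = torch.randint(0, 11, (1, 1000))  # toy token IDs
    input_quals = torch.rand(1, 1000)            # toy quality scores
    with torch.no_grad():
        logits = model(input_ids, input_quals)   # per-base logits; argmax gives the predicted class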
deepchopper/models/llm/__init__.py
@@ -0,0 +1,32 @@
+ """hyena model components and utilities."""
+
+ from .components import (
+     HyenadnaMaxLengths,
+     TokenClassification,
+     TokenClassificationConfig,
+ )
+ from .head import TokenClassificationHead
+ from .hyena import TokenClassificationModule
+ from .metric import IGNORE_INDEX, compute_metrics
+ from .tokenizer import (
+     DataCollatorForTokenClassificationWithQual,
+     load_tokenizer_from_hyena_model,
+     tokenize_and_align_labels_and_quals,
+     tokenize_and_align_labels_and_quals_ids,
+     tokenize_dataset,
+ )
+
+ __all__ = [
+     "IGNORE_INDEX",
+     "DataCollatorForTokenClassificationWithQual",
+     "HyenadnaMaxLengths",
+     "TokenClassification",
+     "TokenClassificationConfig",
+     "TokenClassificationHead",
+     "TokenClassificationModule",
+     "compute_metrics",
+     "load_tokenizer_from_hyena_model",
+     "tokenize_and_align_labels_and_quals",
+     "tokenize_and_align_labels_and_quals_ids",
+     "tokenize_dataset",
+ ]
deepchopper/models/llm/caduceus.py
@@ -0,0 +1,55 @@
+ import torch
+ from torch import nn
+ from transformers import AutoModel
+
+ BACKBONES = [
+     "hyenadna-tiny-1k-seqlen",
+     "hyenadna-small-32k-seqlen",
+     "hyenadna-medium-160k-seqlen",
+     "hyenadna-medium-450k-seqlen",
+     "hyenadna-large-1m-seqlen",
+     "caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
+     "caduceus-ps_seqlen-131k_d_model-256_n_layer-16",
+ ]
+
+ # https://github.com/kuleshov-group/caduceus
+
+
+ class TokenClassificationModule(nn.Module):
+     """Token classification model."""
+
+     def __init__(
+         self,
+         number_of_classes: int,
+         head: nn.Module,
+         backbone_name: str = "caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
+     ):
+         super().__init__()
+         self.number_of_classes = number_of_classes
+         self.backbone_name = backbone_name
+
+         if "hyenadna" in backbone_name:
+             model_name = f"LongSafari/{backbone_name}-hf"
+         elif "caduceus" in backbone_name:
+             model_name = f"kuleshov-group/{backbone_name}"
+         else:
+             msg = f"Unknown backbone model: {backbone_name}"
+             raise ValueError(msg)
+
+         self.backbone = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.head = head
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         input_quals: torch.Tensor,
+     ):
+         transformer_outputs = self.backbone(
+             input_ids,
+             inputs_embeds=None,
+             output_hidden_states=None,
+             return_dict=None,
+         )
+
+         hidden_states = transformer_outputs[0]
+         return self.head(hidden_states, input_quals)
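
A construction sketch for this backbone wrapper; the head hyperparameters are copied from dc_hg.py above, and instantiating the module downloads the backbone weights from the Hugging Face Hub (trust_remote_code=True):

    import torch

    from deepchopper.models.llm import TokenClassificationHead
    from deepchopper.models.llm.caduceus import TokenClassificationModule

    head = TokenClassificationHead(
        input_size=256,  # hidden size used with hyenadna-small-32k-seqlen in dc_hg.py
        lin1_size=1024,
        lin2_size=1024,
        num_class=2,
        use_identity_layer_for_qual=True,
        use_qual=True,
    )
    net = TokenClassificationModule(number_of_classes=2, head=head, backbone_name="hyenadna-small-32k-seqlen")
    logits = net(torch.randint(0, 11, (1, 1000)), torch.rand(1, 1000))  # (batch, seq_len, num_class)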