deepchopper-1.3.0-cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. deepchopper/__init__.py +9 -0
  2. deepchopper/__init__.pyi +67 -0
  3. deepchopper/__main__.py +4 -0
  4. deepchopper/cli.py +260 -0
  5. deepchopper/data/__init__.py +15 -0
  6. deepchopper/data/components/__init__.py +1 -0
  7. deepchopper/data/encode_fq.py +41 -0
  8. deepchopper/data/fq_datamodule.py +352 -0
  9. deepchopper/data/hg_data.py +39 -0
  10. deepchopper/data/only_fq.py +388 -0
  11. deepchopper/deepchopper.abi3.so +0 -0
  12. deepchopper/eval.py +86 -0
  13. deepchopper/models/__init__.py +4 -0
  14. deepchopper/models/basic_module.py +243 -0
  15. deepchopper/models/callbacks.py +57 -0
  16. deepchopper/models/cnn.py +54 -0
  17. deepchopper/models/components/__init__.py +1 -0
  18. deepchopper/models/dc_hg.py +163 -0
  19. deepchopper/models/llm/__init__.py +32 -0
  20. deepchopper/models/llm/caduceus.py +55 -0
  21. deepchopper/models/llm/components.py +99 -0
  22. deepchopper/models/llm/head.py +102 -0
  23. deepchopper/models/llm/hyena.py +41 -0
  24. deepchopper/models/llm/metric.py +44 -0
  25. deepchopper/models/llm/tokenizer.py +205 -0
  26. deepchopper/models/transformer.py +107 -0
  27. deepchopper/py.typed +0 -0
  28. deepchopper/train.py +109 -0
  29. deepchopper/ui/__init__.py +1 -0
  30. deepchopper/ui/main.py +189 -0
  31. deepchopper/utils/__init__.py +37 -0
  32. deepchopper/utils/instantiators.py +54 -0
  33. deepchopper/utils/logging_utils.py +53 -0
  34. deepchopper/utils/preprocess.py +62 -0
  35. deepchopper/utils/print.py +102 -0
  36. deepchopper/utils/pylogger.py +57 -0
  37. deepchopper/utils/rich_utils.py +100 -0
  38. deepchopper/utils/utils.py +138 -0
  39. deepchopper-1.3.0.dist-info/METADATA +254 -0
  40. deepchopper-1.3.0.dist-info/RECORD +43 -0
  41. deepchopper-1.3.0.dist-info/WHEEL +4 -0
  42. deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
  43. deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
deepchopper/train.py ADDED
@@ -0,0 +1,109 @@
+ import os
+ from typing import TYPE_CHECKING, Any
+
+ import hydra
+ import lightning as L  # noqa: N812
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
+ from omegaconf import DictConfig
+
+ from .utils import (
+     RankedLogger,
+     extras,
+     get_metric_value,
+     instantiate_callbacks,
+     instantiate_loggers,
+     log_hyperparameters,
+     task_wrapper,
+ )
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.loggers import Logger
+
+ import torch
+
+ torch.set_float32_matmul_precision("high")
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ @task_wrapper
+ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
+     """Train the model, optionally evaluating on a test set with the best weights obtained during training.
+
+     This method is wrapped in the optional @task_wrapper decorator, which controls behavior on
+     failure. Useful for multiruns, saving info about the crash, etc.
+
+     :param cfg: A DictConfig configuration composed by Hydra.
+     :return: A tuple with metrics and a dict of all instantiated objects.
+     """
+     # set seed for random number generators in pytorch, numpy and python.random
+     if cfg.get("seed"):
+         L.seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating callbacks...")
+     callbacks: list[Callback] = instantiate_callbacks(cfg.get("callbacks"))
+
+     log.info("Instantiating loggers...")
+     logger: list[Logger] = instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "callbacks": callbacks,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         log_hyperparameters(object_dict)
+
+     if cfg.get("train"):
+         log.info("Starting training!")
+         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+     train_metrics = trainer.callback_metrics
+
+     if cfg.get("test"):
+         log.info("Starting testing!")
+         ckpt_path = trainer.checkpoint_callback.best_model_path
+         if ckpt_path == "":
+             log.warning("Best ckpt not found! Using current weights for testing...")
+             ckpt_path = None
+         trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+         log.info(f"Best ckpt path: {ckpt_path}")
+
+     test_metrics = trainer.callback_metrics
+
+     # merge train and test metrics
+     metric_dict = {**train_metrics, **test_metrics}
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path=os.getenv("DC_CONFIG_PATH", "configs"), config_name="train.yaml")
+ def main(cfg: DictConfig) -> float | None:
+     """Main entry point for training.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     :return: Optional[float] with optimized metric value.
+     """
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     extras(cfg)
+
+     # train the model
+     metric_dict, _ = train(cfg)
+
+     # safely retrieve metric value for hydra-based hyperparameter optimization
+     return get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
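For illustration, `train` can also be driven directly from Python with a hand-assembled config. A minimal sketch: the stand-in `_target_` classes come from Lightning's demo module so the snippet can execute, and the `paths` key is an assumption about what this package's `task_wrapper` expects; DeepChopper's real Hydra configs point at its own datamodule and model classes.

    from omegaconf import OmegaConf

    from deepchopper.train import train

    cfg = OmegaConf.create(
        {
            "seed": 42,
            "train": True,
            "test": False,
            "paths": {"output_dir": "."},  # assumed key; template-style task wrappers often use it
            "data": {"_target_": "lightning.pytorch.demos.boring_classes.BoringDataModule"},  # stand-in
            "model": {"_target_": "lightning.pytorch.demos.boring_classes.BoringModel"},  # stand-in
            "trainer": {"_target_": "lightning.Trainer", "max_epochs": 1},
        }
    )
    metric_dict, object_dict = train(cfg)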
deepchopper/ui/__init__.py ADDED
@@ -0,0 +1 @@
+ from .main import main
deepchopper/ui/main.py ADDED
@@ -0,0 +1,189 @@
+ import multiprocessing
+ from functools import partial
+ from pathlib import Path
+
+ import gradio as gr
+ import lightning
+ import torch
+ from datasets import Dataset
+ from torch.utils.data import DataLoader
+
+ import deepchopper
+ from deepchopper.deepchopper import default, remove_intervals_and_keep_left, smooth_label_region
+ from deepchopper.models.llm import (
+     tokenize_and_align_labels_and_quals,
+ )
+ from deepchopper.utils import (
+     summary_predict,
+ )
+
+
+ def parse_fq_record(text: str):
+     """Parse FASTQ text and yield one dictionary per record."""
+     lines = text.strip().split("\n")
+     for i in range(0, len(lines), 4):
+         content = lines[i : i + 4]
+         record_id, seq, _, qual = content
+         assert len(seq) == len(qual)  # noqa: S101
+
+         input_qual = deepchopper.encode_qual(qual, default.QUAL_OFFSET)
+
+         yield {
+             "id": record_id,
+             "seq": seq,
+             "qual": torch.Tensor(input_qual),
+             "target": [0, 0],
+         }
+
+
+ def load_dataset(text: str, tokenizer):
+     """Load dataset from text."""
+     dataset = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text}).with_format("torch")
+     tokenized_dataset = dataset.map(
+         partial(
+             tokenize_and_align_labels_and_quals,
+             tokenizer=tokenizer,
+             max_length=tokenizer.max_len_single_sentence,
+         ),
+         num_proc=multiprocessing.cpu_count(),  # type: ignore
+     ).remove_columns(["id", "seq", "qual", "target"])
+     return dataset, tokenized_dataset
+
+
+ def predict(
+     text: str,
+     smooth_window_size: int = 21,
+     min_interval_size: int = 13,
+     approved_interval_number: int = 20,
+     max_process_intervals: int = 8,
+     batch_size: int = 1,
+     num_workers: int = 1,
+ ):
+     """Run the model on FASTQ text and return detected regions plus highlighted segments."""
+     tokenizer = deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name="hyenadna-small-32k-seqlen")
+     dataset, tokenized_dataset = load_dataset(text, tokenizer)
+
+     dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
+     model = deepchopper.DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+     accelerator = "gpu" if torch.cuda.is_available() else "cpu"
+     trainer = lightning.pytorch.trainer.Trainer(
+         accelerator=accelerator,
+         devices="auto",
+         deterministic=False,
+         logger=False,
+     )
+
+     predicts = trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)
+
+     assert len(predicts) == 1  # noqa: S101
+
+     smooth_interval_json: list[dict[str, int]] = []
+     highlighted_text: list[tuple[str, str | None]] = []
+
+     for idx, preds in enumerate(predicts):
+         true_prediction, _true_label = summary_predict(predictions=preds[0], labels=preds[1])
+
+         _id = dataset[idx]["id"]
+         seq = dataset[idx]["seq"]
+
+         smooth_predict_targets = smooth_label_region(
+             true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
+         )
+
+         if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
+             continue
+
+         # zip two consecutive elements
+         _selected_seqs, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
+         total_intervals = sorted(selected_intervals + smooth_predict_targets)
+
+         smooth_interval_json.extend({"start": i[0], "end": i[1]} for i in smooth_predict_targets)
+
+         highlighted_text.extend(
+             (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
+             for interval in total_intervals
+         )
+     return smooth_interval_json, highlighted_text
+
+
+ def process_input(text: str | None, file: str | None):
+     """Process the input and return the prediction."""
+     if not text and not file:
+         gr.Warning("Both text and file are empty")
+
+     if file:
+         MAX_LINES = 4
+         file_content = []
+         with Path(file).open() as f:
+             for idx, line in enumerate(f):
+                 if idx >= MAX_LINES:
+                     break
+                 file_content.append(line)
+         text = "".join(file_content)
+         return predict(text=text)
+
+     return predict(text=text)
+
+
+ def create_gradio_app():
+     """Create a Gradio app for DeepChopper."""
+     example = (
+ "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
+ "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
+ "+\n"
+ ".0==?SSSSSSSSSSSH2216<868;SSSSSSSSSQQSRSIIHEDDESSSSSSJIKMGEKISSJJICCBDQ?;;8:;,**(&$'+501)\"#$()+%&&0<5+*/('%'))))'''$##\"\"\"\"%&--$\"\"\"('%)1L3*'')'#\"#&+*$&\"\"#*(&'''+,,<;9<BHGF//.LKORQSK<###%*-89<FSSSSE=BAFHFDB???3313NN?>=ANOSJDCADHGMOQSSD=7>BRRSPIEEEOQSSQ4->LIC7EE045///03IIJQSSSNGE6('.5??@A@=,,EGRSPKJ<==<556GFLLQRANSSSSSSSSG...*%%%(***(%'3@LOOSSSSM...7BCMMSSSSSSSSSSSSSSSDFIPSSSGGGGPOQLIHIL4103HMSILLNOSSSSSSSSSS22CBCGSHHHHSSSSSSSSD??@<<<:DDDSSSSSSSSSSA@6688OSSSSSROJJKLSNNNMSSSSQPOOSOOQSSSSSRRHIHISSRSSSSSSSSSSSJFF=??@SSQRK:424<444FFG///1S@@@ASNNNNPN:4JMDDLPSSSSSSBA?B?@@+'&'BD**8EDEFQPIMLE$$&',79CSJJPSGA+***DN;3-('&(;>6(()/-,,)%')1FRNNJ-:=>GC;&;CHNFFDCEEKJLFA22/27A.....HSQLHL))8<=?JSSSFGSKIHDDCCEFDAA@CFJKLNL>:9/1>>?OSLK@+HPSA;>>>K;;;;SSSSOQLPPMORSSSSSQSSSSSSS=:9**?D889SSRFFEDKJJJEEDKSSSNNOSSS.---,&*++SSSSQRSSSSQPGED<<89<@GJ999:SSKBBBAJHK=SSSJJKNMGHKKHQA<<>OPKFEAACDHJKMORB/)'((6**)15DA99;JSQSSS2())+J))EGMQOMMKJF>?<<AA620..D..,/112SOIIJSQFNEEEOMF?066=>@4,3;B>87FSSSSSSSSSSSSSSS<<::5658@AHMMSSRECC448/=<<>SSCB:5546;<??KF==;;FFEDFHKKJG):C>=>BJHINJFDPPPPPPPPPPPPPP%'*%$%+-%'(-22&&%('''&&&#\"\"%&'+0,,0;:1&\"\"%'(+++8'**(\"$$#&$'**//.3497$\"3CFHLOSSSSR:887:;;FSSRPRSSS4433$#$%&$$-056>@:;>=@?AHEFEC;*EKMSSRSRRDB>=AFRSSSSBSOOPSMDAABHH976951-9DHPQO/---?@ELSSQSRJHKKBKKLSSLINSOSSQSRIMSSSSSS>?MKIINSSGSSSSSSSQQMK544MJKKNKHGGLFFGBDB?EHIKGD?@DHPPIIF555)&(+,ADSSSSRQSSSQSS=9/0JJMSQSOSSO/97=B@=:>"
+     )
+
+     custom_css = """
+     .header { text-align: center; margin-bottom: 30px; }
+     .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
+     """
+
+     with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+         gr.HTML(
+             """
+             <div class="header">
+                 <h1>🧬 DeepChopper: DNA Sequence Analysis</h1>
+                 <p>Analyze DNA sequences and detect artificial sequences</p>
+             </div>
+             """
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 text_input = gr.Textbox(
+                     label="Input DNA Sequence", placeholder="Paste your DNA sequence here...", lines=10
+                 )
+                 file_input = gr.File(label="Or upload a FASTQ file")
+                 submit_btn = gr.Button("Analyze", variant="primary")
+
+             with gr.Column(scale=1):
+                 json_output = gr.JSON(label="Detected Artificial Regions")
+                 highlighted_text = gr.HighlightedText(label="Highlighted Sequence")
+
+         submit_btn.click(fn=process_input, inputs=[text_input, file_input], outputs=[json_output, highlighted_text])
+
+         gr.Examples(
+             examples=[[example]],
+             inputs=[text_input],
+         )
+
+         gr.HTML(
+             """
+             <div class="footer">
+                 <p>DeepChopper - Powered by AI for DNA sequence analysis</p>
+             </div>
+             """
+         )
+
+     return demo
+
+
+ def main():
+     """Launch the Gradio app."""
+     app = create_gradio_app()
+     app.launch()
+
+
+ if __name__ == "__main__":
+     main()
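To serve the UI from your own script rather than the packaged entry point, the app object can be launched with any of Gradio's standard `launch` options; a small sketch (the port and share flag are example values, not package defaults):

    from deepchopper.ui.main import create_gradio_app

    app = create_gradio_app()
    app.launch(server_port=7860, share=False)  # standard gradio Blocks.launch options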
deepchopper/utils/__init__.py ADDED
@@ -0,0 +1,37 @@
+ """Utils."""
+
+ from .instantiators import instantiate_callbacks, instantiate_loggers
+ from .logging_utils import log_hyperparameters
+ from .preprocess import load_kmer2id, load_safetensor, save_ndarray_to_safetensor
+ from .print import (
+     alignment_predict,
+     highlight_target,
+     highlight_targets,
+     hightlight_predict,
+     hightlight_predicts,
+     summary_predict,
+ )
+ from .pylogger import RankedLogger
+ from .rich_utils import print_config_tree
+ from .utils import device, extras, get_metric_value, task_wrapper
+
+ __all__ = [
+     "RankedLogger",
+     "alignment_predict",
+     "device",
+     "extras",
+     "get_metric_value",
+     "highlight_target",
+     "highlight_targets",
+     "hightlight_predict",
+     "hightlight_predicts",
+     "instantiate_callbacks",
+     "instantiate_loggers",
+     "load_kmer2id",
+     "load_safetensor",
+     "log_hyperparameters",
+     "print_config_tree",
+     "save_ndarray_to_safetensor",
+     "summary_predict",
+     "task_wrapper",
+ ]
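Since everything above is re-exported at the subpackage level, downstream code can import the helpers in one line, for example:

    from deepchopper.utils import RankedLogger, highlight_target, task_wrapper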
deepchopper/utils/instantiators.py ADDED
@@ -0,0 +1,54 @@
+ import hydra
+ from lightning import Callback
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ from . import pylogger
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
+     """Instantiate callbacks from config.
+
+     :param callbacks_cfg: A DictConfig object containing callback configurations.
+     :return: A list of instantiated callbacks.
+     """
+     callbacks: list[Callback] = []
+
+     if not callbacks_cfg:
+         log.warning("No callback configs found! Skipping...")
+         return callbacks
+
+     if not isinstance(callbacks_cfg, DictConfig):
+         raise TypeError("Callbacks config must be a DictConfig!")
+
+     for _, cb_conf in callbacks_cfg.items():
+         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+             log.info(f"Instantiating callback <{cb_conf._target_}>")
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     return callbacks
+
+
+ def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:
+     """Instantiate loggers from config.
+
+     :param logger_cfg: A DictConfig object containing logger configurations.
+     :return: A list of instantiated loggers.
+     """
+     logger: list[Logger] = []
+
+     if not logger_cfg:
+         log.warning("No logger configs found! Skipping...")
+         return logger
+
+     if not isinstance(logger_cfg, DictConfig):
+         raise TypeError("Logger config must be a DictConfig!")
+
+     for _, lg_conf in logger_cfg.items():
+         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+             log.info(f"Instantiating logger <{lg_conf._target_}>")
+             logger.append(hydra.utils.instantiate(lg_conf))
+
+     return logger
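Both functions skip any entry that is not a `DictConfig` carrying a `_target_` key. A hedged sketch of the expected config shape, using a stock Lightning callback rather than anything shipped in this wheel:

    from omegaconf import OmegaConf

    from deepchopper.utils import instantiate_callbacks

    cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val/loss",
                "save_top_k": 1,
            }
        }
    )
    callbacks = instantiate_callbacks(cfg)  # [ModelCheckpoint(...)]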
deepchopper/utils/logging_utils.py ADDED
@@ -0,0 +1,53 @@
+ from typing import Any
+
+ from lightning_utilities.core.rank_zero import rank_zero_only
+ from omegaconf import OmegaConf
+
+ from . import pylogger
+
+ log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: dict[str, Any]) -> None:
+     """Control which config parts are saved by Lightning loggers.
+
+     Additionally saves:
+         - Number of model parameters
+
+     :param object_dict: A dictionary containing the following objects:
+         - `"cfg"`: A DictConfig object containing the main config.
+         - `"model"`: The Lightning model.
+         - `"trainer"`: The Lightning trainer.
+     """
+     hparams = {}
+
+     cfg = OmegaConf.to_container(object_dict["cfg"])
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     hparams["model"] = cfg["model"]
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+
+     hparams["data"] = cfg["data"]
+     hparams["trainer"] = cfg["trainer"]
+
+     hparams["callbacks"] = cfg.get("callbacks")
+     hparams["extras"] = cfg.get("extras")
+
+     hparams["task_name"] = cfg.get("task_name")
+     hparams["tags"] = cfg.get("tags")
+     hparams["ckpt_path"] = cfg.get("ckpt_path")
+     hparams["seed"] = cfg.get("seed")
+
+     # send hparams to all loggers
+     for logger in trainer.loggers:
+         logger.log_hyperparams(hparams)
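A minimal runnable sketch of the expected `object_dict`, with a toy model and config standing in for DeepChopper's own (the hparams written out are exactly the keys selected above):

    import lightning as L
    import torch
    from lightning.pytorch.loggers import CSVLogger
    from omegaconf import OmegaConf

    from deepchopper.utils import log_hyperparameters

    class ToyModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)  # gives the sketch some parameters to count

    cfg = OmegaConf.create({"model": {"hidden": 8}, "data": {}, "trainer": {}})
    trainer = L.Trainer(logger=CSVLogger("logs"), max_epochs=1)
    log_hyperparameters({"cfg": cfg, "model": ToyModel(), "trainer": trainer})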
deepchopper/utils/preprocess.py ADDED
@@ -0,0 +1,62 @@
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from safetensors.torch import load_file, save_file
+
+
+ def load_safetensor(file_path: Path, device="cpu") -> dict[str, torch.Tensor]:
+     """Load a dictionary of tensors from a safetensors file onto the given device."""
+     file_path = Path(file_path)
+     return load_file(file_path, device=device)
+
+
+ def save_ndarray_to_safetensor(ndarray: dict[str, np.ndarray], file_path: Path):
+     """Save a dictionary of NumPy arrays as PyTorch tensors to a file.
+
+     Parameters:
+     - ndarray (dict[str, np.ndarray]): A dictionary where keys are strings and values are NumPy arrays.
+     - file_path (Path): The file path where the PyTorch tensors will be saved.
+
+     Returns:
+     - None
+
+     Raises:
+     - ValueError: If the input dictionary is empty.
+     - TypeError: If the input dictionary values are not NumPy arrays.
+     - FileNotFoundError: If the specified file path does not exist.
+     - OSError: If there is an issue with writing to the file.
+
+     This function converts each NumPy array in the input dictionary to a PyTorch tensor and saves
+     the resulting dictionary to the file specified by file_path. If any errors occur during the
+     process, appropriate exceptions are raised.
+     """
+     tensor_dict = {key: torch.tensor(value) for key, value in ndarray.items()}
+     save_file(tensor_dict, file_path)
+
+
+ def load_kmer2id(kmer2id_file: Path) -> dict:
+     """Load a dictionary mapping kmer strings to integer IDs from a file.
+
+     Parameters:
+         kmer2id_file (Path): A Path object pointing to the file containing the kmer to ID mapping.
+
+     Returns:
+         dict: A dictionary mapping kmer strings to integer IDs.
+
+     Raises:
+         FileNotFoundError: If the specified file does not exist.
+         ValueError: If the file format is incorrect or if the kmer and ID cannot be properly parsed.
+
+     Example:
+         If the kmer2id_file contains the following lines:
+             ATG 1
+             CAG 2
+         The function will return {'ATG': 1, 'CAG': 2}.
+     """
+     kmer2id_file = Path(kmer2id_file)
+
+     kmer2id = {}
+     with kmer2id_file.open("r") as f:
+         for line in f:
+             kmer, idx = line.strip().split()
+             kmer2id[kmer] = int(idx)
+     return kmer2id
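A quick round trip with the two safetensors helpers (the file name is just for the demo):

    from pathlib import Path

    import numpy as np

    from deepchopper.utils import load_safetensor, save_ndarray_to_safetensor

    arrays = {"embeddings": np.arange(12, dtype=np.float32).reshape(3, 4)}
    save_ndarray_to_safetensor(arrays, Path("demo.safetensors"))
    tensors = load_safetensor(Path("demo.safetensors"))  # {"embeddings": tensor of shape (3, 4)}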
deepchopper/utils/print.py ADDED
@@ -0,0 +1,102 @@
+ import typing
+
+ import numpy as np
+ import torch
+ from rich.console import Console
+ from rich.highlighter import RegexHighlighter
+ from rich.text import Text
+ from rich.theme import Theme
+
+ from deepchopper.deepchopper import summary_predict as rust_summary_predict
+ from deepchopper.models.llm import IGNORE_INDEX
+
+
+ def hightlight_predicts(
+     seq: str,
+     targets: list[tuple[int, int]],
+     predicts: list[tuple[int, int]],
+     style: str = "bold magenta",
+     width: int = 80,
+ ):
+     """Highlight the predicted and labeled sequences."""
+     target_seq = Text(seq)
+     predict_seq = Text(seq)
+     console = Console()
+
+     for start, end in targets:
+         target_seq.stylize(style, start, end)
+
+     for start, end in predicts:
+         predict_seq.stylize(style, start, end)
+
+     front1 = "L:"
+     front2 = "P:"
+     for t1, t2 in zip(target_seq.wrap(console, width), predict_seq.wrap(console, width), strict=True):
+         console.print(front1, t1)
+         console.print(front2, t2)
+
+
+ def highlight_targets(seq: str, targets: list[tuple[int, int]], style="bold magenta"):
+     """Highlight the target sequences."""
+     text = Text(seq)
+     console = Console()
+     for start, end in targets:
+         text.stylize(style, start, end)
+     console.print(text)
+
+
+ def highlight_target(seq: str, start: int, end: int, style="bold magenta"):
+     """Highlight the target sequence."""
+     text = Text(seq)
+     console = Console()
+     text.stylize(style, start, end)
+     console.print(text)
+
+
+ def hightlight_predict(seq: str, target_start: int, target_end: int, predict_start: int, predict_end: int):
+     """Highlight the predicted sequence."""
+     text = Text(seq)
+     console = Console()
+
+     text.stylize("#adb0b1", target_start, target_end)
+     text.stylize("bold magenta", predict_start, predict_end)
+
+     console.print(text)
+
+
+ def summary_predict(predictions, labels):
+     """Summarize the predictions and labels."""
+     # move tensors to host memory before the argmax and the Rust helper
+     if isinstance(predictions, torch.Tensor):
+         predictions = predictions.cpu().numpy()
+
+     if isinstance(labels, torch.Tensor):
+         labels = labels.cpu().numpy()
+
+     predictions = np.argmax(predictions, axis=2)
+
+     true_predictions, true_labels = rust_summary_predict(predictions, labels, IGNORE_INDEX)
+     return true_predictions, true_labels
+
+
+ class LabelHighlighter(RegexHighlighter):
+     """Apply style to runs of 1s in a label string."""
+
+     base_style = "label."
+     highlights: typing.ClassVar = [r"(?P<label>1+)"]
+
+
+ def alignment_predict(prediction, label):
+     """Print the alignment of the predicted and labeled sequences."""
+     import textwrap
+
+     prediction_str = "".join(str(x) for x in prediction)
+     label_str = "".join(str(x) for x in label)
+
+     front2 = "L:"
+     front1 = "P:"
+     theme = Theme({"label.label": "bold magenta"})
+     console = Console(highlighter=LabelHighlighter(), theme=theme)
+     for l1, l2 in zip(textwrap.wrap(prediction_str), textwrap.wrap(label_str), strict=True):
+         ss = f"{front1}{l1}\n{front2}{l2}"
+         console.print(ss)
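A short demo of the console helpers with toy inputs:

    from deepchopper.utils import alignment_predict, highlight_targets

    # Style two intervals of a toy read in the terminal.
    highlight_targets("ACGTACGTACGTACGT", [(2, 6), (10, 14)])

    # Print a prediction/label alignment; runs of 1s are highlighted.
    alignment_predict([0, 0, 1, 1, 1, 0], [0, 1, 1, 1, 0, 0])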
deepchopper/utils/pylogger.py ADDED
@@ -0,0 +1,57 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+ if TYPE_CHECKING:
+     from collections.abc import Mapping
+
+
+ class RankedLogger(logging.LoggerAdapter):
+     """A multi-GPU-friendly python command line logger."""
+
+     def __init__(
+         self,
+         name: str = __name__,
+         *,
+         rank_zero_only: bool = False,
+         extra: Mapping[str, object] | None = None,
+     ) -> None:
+         """Initialize a multi-GPU-friendly python command line logger that logs on all processes
+         with their rank prefixed in the log message.
+
+         :param name: The name of the logger. Default is ``__name__``.
+         :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+         :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+         """
+         logger = logging.getLogger(name)
+         super().__init__(logger=logger, extra=extra)
+         self.rank_zero_only = rank_zero_only
+
+     def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:
+         """Delegate a log call to the underlying logger, after prefixing its message with the rank
+         of the process it is being logged from. If `rank` is provided, the log will only occur on
+         that rank/process.
+
+         :param level: The level to log at. See the `logging` module for more information.
+         :param msg: The message to log.
+         :param rank: The rank to log at.
+         :param args: Additional args to pass to the underlying logging function.
+         :param kwargs: Any additional keyword args to pass to the underlying logging function.
+         """
+         if self.isEnabledFor(level):
+             msg, kwargs = self.process(msg, kwargs)
+             current_rank = getattr(rank_zero_only, "rank", None)
+             if current_rank is None:
+                 raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+             msg = rank_prefixed_message(msg, current_rank)
+             if self.rank_zero_only:
+                 if current_rank == 0:
+                     self.logger.log(level, msg, *args, **kwargs)
+             elif rank is None:
+                 self.logger.log(level, msg, *args, **kwargs)
+             elif current_rank == rank:
+                 self.logger.log(level, msg, *args, **kwargs)
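A standalone sketch: outside a Lightning run, `rank_zero_only.rank` must be set by hand before the first call, since the adapter raises otherwise.

    import logging

    from lightning_utilities.core.rank_zero import rank_zero_only

    from deepchopper.utils import RankedLogger

    rank_zero_only.rank = 0  # normally set when Lightning initializes
    logging.basicConfig(level=logging.INFO)

    log = RankedLogger(__name__, rank_zero_only=True)
    log.info("visible only on rank 0, prefixed with its rank")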