deepchopper-1.3.0-cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepchopper/__init__.py +9 -0
- deepchopper/__init__.pyi +67 -0
- deepchopper/__main__.py +4 -0
- deepchopper/cli.py +260 -0
- deepchopper/data/__init__.py +15 -0
- deepchopper/data/components/__init__.py +1 -0
- deepchopper/data/encode_fq.py +41 -0
- deepchopper/data/fq_datamodule.py +352 -0
- deepchopper/data/hg_data.py +39 -0
- deepchopper/data/only_fq.py +388 -0
- deepchopper/deepchopper.abi3.so +0 -0
- deepchopper/eval.py +86 -0
- deepchopper/models/__init__.py +4 -0
- deepchopper/models/basic_module.py +243 -0
- deepchopper/models/callbacks.py +57 -0
- deepchopper/models/cnn.py +54 -0
- deepchopper/models/components/__init__.py +1 -0
- deepchopper/models/dc_hg.py +163 -0
- deepchopper/models/llm/__init__.py +32 -0
- deepchopper/models/llm/caduceus.py +55 -0
- deepchopper/models/llm/components.py +99 -0
- deepchopper/models/llm/head.py +102 -0
- deepchopper/models/llm/hyena.py +41 -0
- deepchopper/models/llm/metric.py +44 -0
- deepchopper/models/llm/tokenizer.py +205 -0
- deepchopper/models/transformer.py +107 -0
- deepchopper/py.typed +0 -0
- deepchopper/train.py +109 -0
- deepchopper/ui/__init__.py +1 -0
- deepchopper/ui/main.py +189 -0
- deepchopper/utils/__init__.py +37 -0
- deepchopper/utils/instantiators.py +54 -0
- deepchopper/utils/logging_utils.py +53 -0
- deepchopper/utils/preprocess.py +62 -0
- deepchopper/utils/print.py +102 -0
- deepchopper/utils/pylogger.py +57 -0
- deepchopper/utils/rich_utils.py +100 -0
- deepchopper/utils/utils.py +138 -0
- deepchopper-1.3.0.dist-info/METADATA +254 -0
- deepchopper-1.3.0.dist-info/RECORD +43 -0
- deepchopper-1.3.0.dist-info/WHEEL +4 -0
- deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
- deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
deepchopper/train.py
ADDED
@@ -0,0 +1,109 @@
+import os
+from typing import TYPE_CHECKING, Any
+
+import hydra
+import lightning as L  # noqa: N812
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from omegaconf import DictConfig
+
+from .utils import (
+    RankedLogger,
+    extras,
+    get_metric_value,
+    instantiate_callbacks,
+    instantiate_loggers,
+    log_hyperparameters,
+    task_wrapper,
+)
+
+if TYPE_CHECKING:
+    from lightning.pytorch.loggers import Logger
+
+import torch
+
+torch.set_float32_matmul_precision("high")
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@task_wrapper
+def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
+    """Train the model. Can additionally evaluate on a test set, using the best weights obtained during training.
+
+    This method is wrapped in the optional @task_wrapper decorator, which controls the behavior during
+    failure. Useful for multiruns, saving info about the crash, etc.
+
+    :param cfg: A DictConfig configuration composed by Hydra.
+    :return: A tuple with metrics and a dict with all instantiated objects.
+    """
+    # set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        L.seed_everything(cfg.seed, workers=True)
+
+    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+    log.info("Instantiating callbacks...")
+    callbacks: list[Callback] = instantiate_callbacks(cfg.get("callbacks"))
+
+    log.info("Instantiating loggers...")
+    logger: list[Logger] = instantiate_loggers(cfg.get("logger"))
+
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+    object_dict = {
+        "cfg": cfg,
+        "datamodule": datamodule,
+        "model": model,
+        "callbacks": callbacks,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        log_hyperparameters(object_dict)
+
+    if cfg.get("train"):
+        log.info("Starting training!")
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+    train_metrics = trainer.callback_metrics
+
+    if cfg.get("test"):
+        log.info("Starting testing!")
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            log.warning("Best ckpt not found! Using current weights for testing...")
+            ckpt_path = None
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        log.info(f"Best ckpt path: {ckpt_path}")
+
+    test_metrics = trainer.callback_metrics
+
+    # merge train and test metrics
+    metric_dict = {**train_metrics, **test_metrics}
+
+    return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.3", config_path=os.getenv("DC_CONFIG_PATH", "configs"), config_name="train.yaml")
+def main(cfg: DictConfig) -> float | None:
+    """Main entry point for training.
+
+    :param cfg: DictConfig configuration composed by Hydra.
+    :return: Optional[float] with the optimized metric value.
+    """
+    # apply extra utilities
+    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+    extras(cfg)
+
+    # train the model
+    metric_dict, _ = train(cfg)
+
+    # safely retrieve metric value for hydra-based hyperparameter optimization
+    return get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
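Note: `main` is a standard Hydra entry point, so the usual way to drive `train` outside the CLI is to compose the same `train.yaml` config programmatically. A minimal sketch, assuming a local config directory and illustrative overrides (the `seed`, `train`, and `test` keys are read via `cfg.get(...)` above; the path and override values are placeholders, not taken from the package):

    from hydra import compose, initialize_config_dir

    # Compose the config that `main` would receive, then call `train` directly.
    with initialize_config_dir(config_dir="/abs/path/to/configs", version_base="1.3"):
        cfg = compose(config_name="train.yaml", overrides=["seed=42", "test=true"])

    metric_dict, object_dict = train(cfg)  # merged train/test metrics plus instantiated objects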
deepchopper/ui/__init__.py
ADDED
@@ -0,0 +1 @@
+from .main import main
deepchopper/ui/main.py
ADDED
@@ -0,0 +1,189 @@
+import multiprocessing
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import lightning
+import torch
+from datasets import Dataset
+from torch.utils.data import DataLoader
+
+import deepchopper
+from deepchopper.deepchopper import default, remove_intervals_and_keep_left, smooth_label_region
+from deepchopper.models.llm import (
+    tokenize_and_align_labels_and_quals,
+)
+from deepchopper.utils import (
+    summary_predict,
+)
+
+
+def parse_fq_record(text: str):
+    """Parse a single FASTQ record into a dictionary."""
+    lines = text.strip().split("\n")
+    for i in range(0, len(lines), 4):
+        content = lines[i : i + 4]
+        record_id, seq, _, qual = content
+        assert len(seq) == len(qual)  # noqa: S101
+
+        input_qual = deepchopper.encode_qual(qual, default.QUAL_OFFSET)
+
+        yield {
+            "id": record_id,
+            "seq": seq,
+            "qual": torch.Tensor(input_qual),
+            "target": [0, 0],
+        }
+
+
+def load_dataset(text: str, tokenizer):
+    """Load dataset from text."""
+    dataset = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text}).with_format("torch")
+    tokenized_dataset = dataset.map(
+        partial(
+            tokenize_and_align_labels_and_quals,
+            tokenizer=tokenizer,
+            max_length=tokenizer.max_len_single_sentence,
+        ),
+        num_proc=multiprocessing.cpu_count(),  # type: ignore
+    ).remove_columns(["id", "seq", "qual", "target"])
+    return dataset, tokenized_dataset
+
+
+def predict(
+    text: str,
+    smooth_window_size: int = 21,
+    min_interval_size: int = 13,
+    approved_interval_number: int = 20,
+    max_process_intervals: int = 8,  # default is 4
+    batch_size: int = 1,
+    num_workers: int = 1,
+):
+    tokenizer = deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name="hyenadna-small-32k-seqlen")
+    dataset, tokenized_dataset = load_dataset(text, tokenizer)
+
+    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
+    model = deepchopper.DeepChopper.from_pretrained("yangliz5/deepchopper")
+
+    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
+    trainer = lightning.pytorch.trainer.Trainer(
+        accelerator=accelerator,
+        devices="auto",
+        deterministic=False,
+        logger=False,
+    )
+
+    predicts = trainer.predict(model=model, dataloaders=dataloader, return_predictions=True)
+
+    assert len(predicts) == 1  # noqa: S101
+
+    smooth_interval_json: list[dict[str, int]] = []
+    highlighted_text: list[tuple[str, str | None]] = []
+
+    for idx, preds in enumerate(predicts):
+        true_prediction, _true_label = summary_predict(predictions=preds[0], labels=preds[1])
+
+        _id = dataset[idx]["id"]
+        seq = dataset[idx]["seq"]
+
+        smooth_predict_targets = smooth_label_region(
+            true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
+        )
+
+        if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
+            continue
+
+        # cut out the predicted intervals, keeping the flanking sequence segments
+        _selected_seqs, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
+        total_intervals = sorted(selected_intervals + smooth_predict_targets)
+
+        smooth_interval_json.extend({"start": i[0], "end": i[1]} for i in smooth_predict_targets)
+
+        highlighted_text.extend(
+            (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
+            for interval in total_intervals
+        )
+    return smooth_interval_json, highlighted_text
+
+
+def process_input(text: str | None, file: str | None):
+    """Process the input and return the prediction."""
+    if not text and not file:
+        gr.Warning("Both text and file are empty")
+
+    if file:
+        MAX_LINES = 4
+        file_content = []
+        with Path(file).open() as f:
+            for idx, line in enumerate(f):
+                if idx >= MAX_LINES:
+                    break
+                file_content.append(line)
+        text = "".join(file_content)
+        return predict(text=text)
+
+    return predict(text=text)
+
+
+def create_gradio_app():
+    """Create a Gradio app for DeepChopper."""
+    example = (
+        "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
+        "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
+        "+\n"
+        ".0==?SSSSSSSSSSSH2216<868;SSSSSSSSSQQSRSIIHEDDESSSSSSJIKMGEKISSJJICCBDQ?;;8:;,**(&$'+501)\"#$()+%&&0<5+*/('%'))))'''$##\"\"\"\"%&--$\"\"\"('%)1L3*'')'#\"#&+*$&\"\"#*(&'''+,,<;9<BHGF//.LKORQSK<###%*-89<FSSSSE=BAFHFDB???3313NN?>=ANOSJDCADHGMOQSSD=7>BRRSPIEEEOQSSQ4->LIC7EE045///03IIJQSSSNGE6('.5??@A@=,,EGRSPKJ<==<556GFLLQRANSSSSSSSSG...*%%%(***(%'3@LOOSSSSM...7BCMMSSSSSSSSSSSSSSSDFIPSSSGGGGPOQLIHIL4103HMSILLNOSSSSSSSSSS22CBCGSHHHHSSSSSSSSD??@<<<:DDDSSSSSSSSSSA@6688OSSSSSROJJKLSNNNMSSSSQPOOSOOQSSSSSRRHIHISSRSSSSSSSSSSSJFF=??@SSQRK:424<444FFG///1S@@@ASNNNNPN:4JMDDLPSSSSSSBA?B?@@+'&'BD**8EDEFQPIMLE$$&',79CSJJPSGA+***DN;3-('&(;>6(()/-,,)%')1FRNNJ-:=>GC;&;CHNFFDCEEKJLFA22/27A.....HSQLHL))8<=?JSSSFGSKIHDDCCEFDAA@CFJKLNL>:9/1>>?OSLK@+HPSA;>>>K;;;;SSSSOQLPPMORSSSSSQSSSSSSS=:9**?D889SSRFFEDKJJJEEDKSSSNNOSSS.---,&*++SSSSQRSSSSQPGED<<89<@GJ999:SSKBBBAJHK=SSSJJKNMGHKKHQA<<>OPKFEAACDHJKMORB/)'((6**)15DA99;JSQSSS2())+J))EGMQOMMKJF>?<<AA620..D..,/112SOIIJSQFNEEEOMF?066=>@4,3;B>87FSSSSSSSSSSSSSSS<<::5658@AHMMSSRECC448/=<<>SSCB:5546;<??KF==;;FFEDFHKKJG):C>=>BJHINJFDPPPPPPPPPPPPPP%'*%$%+-%'(-22&&%('''&&&#\"\"%&'+0,,0;:1&\"\"%'(+++8'**(\"$$#&$'**//.3497$\"3CFHLOSSSSR:887:;;FSSRPRSSS4433$#$%&$$-056>@:;>=@?AHEFEC;*EKMSSRSRRDB>=AFRSSSSBSOOPSMDAABHH976951-9DHPQO/---?@ELSSQSRJHKKBKKLSSLINSOSSQSRIMSSSSSS>?MKIINSSGSSSSSSSQQMK544MJKKNKHGGLFFGBDB?EHIKGD?@DHPPIIF555)&(+,ADSSSSRQSSSQSS=9/0JJMSQSOSSO/97=B@=:>"
+    )
+
+    custom_css = """
+    .header { text-align: center; margin-bottom: 30px; }
+    .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
+    """
+
+    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
+        gr.HTML(
+            """
+            <div class="header">
+                <h1>🧬 DeepChopper: DNA Sequence Analysis</h1>
+                <p>Analyze DNA sequences and detect artificial sequences</p>
+            </div>
+            """
+        )
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                text_input = gr.Textbox(
+                    label="Input DNA Sequence", placeholder="Paste your DNA sequence here...", lines=10
+                )
+                file_input = gr.File(label="Or upload a FASTQ file")
+                submit_btn = gr.Button("Analyze", variant="primary")
+
+            with gr.Column(scale=1):
+                json_output = gr.JSON(label="Detected Artificial Regions")
+                highlighted_text = gr.HighlightedText(label="Highlighted Sequence")
+
+        submit_btn.click(fn=process_input, inputs=[text_input, file_input], outputs=[json_output, highlighted_text])
+
+        gr.Examples(
+            examples=[[example]],
+            inputs=[text_input],
+        )
+
+        gr.HTML(
+            """
+            <div class="footer">
+                <p>DeepChopper - Powered by AI for DNA sequence analysis</p>
+            </div>
+            """
+        )
+
+    return demo
+
+
+def main():
+    """Launch the Gradio app."""
+    app = create_gradio_app()
+    app.launch()
+
+
+if __name__ == "__main__":
+    main()
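For reference, a minimal sketch of exercising this prediction path without the UI (the FASTQ record is a placeholder; `predict` pulls the pretrained `yangliz5/deepchopper` checkpoint on first call):

    from deepchopper.ui.main import predict

    # Hypothetical four-line FASTQ record: ID, sequence, separator, quality.
    record = "\n".join([
        "@example_read",
        "ACGTACGTACGTACGT",
        "+",
        "IIIIIIIIIIIIIIII",  # quality string must match the sequence length
    ])

    regions, segments = predict(text=record)
    print(regions)  # [{"start": ..., "end": ...}, ...] for detected artificial regions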
deepchopper/utils/__init__.py
ADDED
@@ -0,0 +1,37 @@
+"""Utils."""
+
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logging_utils import log_hyperparameters
+from .preprocess import load_kmer2id, load_safetensor, save_ndarray_to_safetensor
+from .print import (
+    alignment_predict,
+    highlight_target,
+    highlight_targets,
+    hightlight_predict,
+    hightlight_predicts,
+    summary_predict,
+)
+from .pylogger import RankedLogger
+from .rich_utils import print_config_tree
+from .utils import device, extras, get_metric_value, task_wrapper
+
+__all__ = [
+    "RankedLogger",
+    "alignment_predict",
+    "device",
+    "extras",
+    "get_metric_value",
+    "highlight_target",
+    "highlight_targets",
+    "hightlight_predict",
+    "hightlight_predicts",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "load_kmer2id",
+    "load_safetensor",
+    "log_hyperparameters",
+    "print_config_tree",
+    "save_ndarray_to_safetensor",
+    "summary_predict",
+    "task_wrapper",
+]
deepchopper/utils/instantiators.py
ADDED
@@ -0,0 +1,54 @@
+import hydra
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+from . import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
+    """Instantiates callbacks from config.
+
+    :param callbacks_cfg: A DictConfig object containing callback configurations.
+    :return: A list of instantiated callbacks.
+    """
+    callbacks: list[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("No callback configs found! Skipping...")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:
+    """Instantiates loggers from config.
+
+    :param logger_cfg: A DictConfig object containing logger configurations.
+    :return: A list of instantiated loggers.
+    """
+    logger: list[Logger] = []
+
+    if not logger_cfg:
+        log.warning("No logger configs found! Skipping...")
+        return logger
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            log.info(f"Instantiating logger <{lg_conf._target_}>")
+            logger.append(hydra.utils.instantiate(lg_conf))
+
+    return logger
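Both helpers walk a mapping of sub-configs and instantiate every entry that carries a `_target_` key. A minimal sketch with a hypothetical callbacks config (the callback classes are stock Lightning, not part of this package; the monitor key is illustrative):

    from omegaconf import OmegaConf

    callbacks_cfg = OmegaConf.create({
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val/loss",
        },
        "early_stopping": {
            "_target_": "lightning.pytorch.callbacks.EarlyStopping",
            "monitor": "val/loss",
            "patience": 3,
        },
    })

    callbacks = instantiate_callbacks(callbacks_cfg)  # [ModelCheckpoint(...), EarlyStopping(...)]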
deepchopper/utils/logging_utils.py
ADDED
@@ -0,0 +1,53 @@
+from typing import Any
+
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import OmegaConf
+
+from . import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict[str, Any]) -> None:
+    """Controls which config parts are saved by Lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+
+    :param object_dict: A dictionary containing the following objects:
+        - `"cfg"`: A DictConfig object containing the main config.
+        - `"model"`: The Lightning model.
+        - `"trainer"`: The Lightning trainer.
+    """
+    hparams = {}
+
+    cfg = OmegaConf.to_container(object_dict["cfg"])
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
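The parameter counts above come straight from `numel()` over `model.parameters()`, split by `requires_grad`. Illustrated on a toy module (not package code):

    import torch

    model = torch.nn.Linear(10, 2)    # weight 10*2 + bias 2 = 22 parameters
    model.bias.requires_grad = False  # freeze the bias

    total = sum(p.numel() for p in model.parameters())                         # 22
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)  # 20
    non_trainable = total - trainable                                          # 2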
deepchopper/utils/preprocess.py
ADDED
@@ -0,0 +1,62 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+from safetensors.torch import load_file, save_file
+
+
+def load_safetensor(file_path: Path, device="cpu") -> dict[str, torch.Tensor]:
+    file_path = Path(file_path)
+    return load_file(file_path, device=device)
+
+
+def save_ndarray_to_safetensor(ndarray: dict[str, np.ndarray], file_path: Path):
+    """Save a dictionary of NumPy arrays as PyTorch tensors to a file.
+
+    Parameters:
+    - ndarray (dict[str, np.ndarray]): A dictionary where keys are strings and values are NumPy arrays.
+    - file_path (Path): The file path where the PyTorch tensors will be saved.
+
+    Returns:
+    - None
+
+    Raises:
+    - ValueError: If the input dictionary is empty.
+    - TypeError: If the input dictionary values are not NumPy arrays.
+    - FileNotFoundError: If the specified file path does not exist.
+    - OSError: If there is an issue with writing to the file.
+
+    This function converts each NumPy array in the input dictionary to a PyTorch tensor and saves the resulting dictionary
+    to a file specified by the file path. If any errors occur during the process, appropriate exceptions are raised.
+    """
+    tensor_dict = {key: torch.tensor(value) for key, value in ndarray.items()}
+    save_file(tensor_dict, file_path)
+
+
+def load_kmer2id(kmer2id_file: Path) -> dict:
+    """Load a dictionary mapping kmer strings to integer IDs from a file.
+
+    Parameters:
+        kmer2id_file (Path): A Path object pointing to the file containing the kmer to ID mapping.
+
+    Returns:
+        dict: A dictionary mapping kmer strings to integer IDs.
+
+    Raises:
+        FileNotFoundError: If the specified file does not exist.
+        ValueError: If the file format is incorrect or if the kmer and ID cannot be properly parsed.
+
+    Example:
+        If the kmer2id_file contains the following lines:
+            ATG 1
+            CAG 2
+        The function will return {'ATG': 1, 'CAG': 2}.
+    """
+    kmer2id_file = Path(kmer2id_file)
+
+    kmer2id = {}
+    with kmer2id_file.open("r") as f:
+        for line in f:
+            kmer, idx = line.strip().split()
+            kmer2id[kmer] = int(idx)
+    return kmer2id
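A round-trip sketch for the two safetensors helpers (the file name is a placeholder):

    import numpy as np
    from pathlib import Path

    arrays = {"embeddings": np.zeros((4, 8), dtype=np.float32)}
    save_ndarray_to_safetensor(arrays, Path("example.safetensors"))

    tensors = load_safetensor(Path("example.safetensors"))  # dict[str, torch.Tensor]
    assert tuple(tensors["embeddings"].shape) == (4, 8)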
deepchopper/utils/print.py
ADDED
@@ -0,0 +1,102 @@
+import typing
+
+import numpy as np
+import torch
+from rich.console import Console
+from rich.highlighter import RegexHighlighter
+from rich.text import Text
+from rich.theme import Theme
+
+from deepchopper.deepchopper import summary_predict as rust_summary_predict
+from deepchopper.models.llm import IGNORE_INDEX
+
+
+def hightlight_predicts(
+    seq: str,
+    targets: list[tuple[int, int]],
+    predicts: list[tuple[int, int]],
+    style: str = "bold magenta",
+    width: int = 80,
+):
+    """Highlight the predicted and labeled sequences."""
+    target_seq = Text(seq)
+    predict_seq = Text(seq)
+    console = Console()
+
+    for start, end in targets:
+        target_seq.stylize(style, start, end)
+
+    for start, end in predicts:
+        predict_seq.stylize(style, start, end)
+
+    front1 = "L:"
+    front2 = "P:"
+    for t1, t2 in zip(target_seq.wrap(console, width), predict_seq.wrap(console, width), strict=True):
+        console.print(front1, t1)
+        console.print(front2, t2)
+
+
+def highlight_targets(seq: str, targets: list[tuple[int, int]], style="bold magenta"):
+    """Highlight the target sequences."""
+    text = Text(seq)
+    console = Console()
+    for start, end in targets:
+        text.stylize(style, start, end)
+    console.print(text)
+
+
+def highlight_target(seq: str, start: int, end: int, style="bold magenta"):
+    """Highlight the target sequence."""
+    text = Text(seq)
+    console = Console()
+    text.stylize(style, start, end)
+    console.print(text)
+
+
+def hightlight_predict(seq: str, target_start: int, target_end: int, predict_start: int, predict_end: int):
+    """Highlight the predicted sequence."""
+    text = Text(seq)
+    console = Console()
+
+    text.stylize("#adb0b1", target_start, target_end)
+    text.stylize("bold magenta", predict_start, predict_end)
+
+    console.print(text)
+
+
+def summary_predict(predictions, labels):
+    """Summarize the predictions and labels."""
+    predictions = np.argmax(predictions, axis=2)
+    # ensure plain NumPy arrays before handing off to the Rust helper
+
+    if isinstance(predictions, torch.Tensor):
+        predictions = predictions.cpu().numpy()
+
+    if isinstance(labels, torch.Tensor):
+        labels = labels.cpu().numpy()
+
+    true_predictions, true_labels = rust_summary_predict(predictions, labels, IGNORE_INDEX)
+    return true_predictions, true_labels
+
+
+class LabelHighlighter(RegexHighlighter):
+    """Apply style to contiguous runs of 1 labels."""
+
+    base_style = "label."
+    highlights: typing.ClassVar = [r"(?P<label>1+)"]
+
+
+def alignment_predict(prediction, label):
+    """Print the alignment of the predicted and labeled sequences."""
+    import textwrap
+
+    prediction_str = "".join(str(x) for x in prediction)
+    label_str = "".join(str(x) for x in label)
+
+    front2 = "L:"
+    front1 = "P:"
+    theme = Theme({"label.label": "bold magenta"})
+    console = Console(highlighter=LabelHighlighter(), theme=theme)
+    for l1, l2 in zip(textwrap.wrap(prediction_str), textwrap.wrap(label_str), strict=True):
+        ss = f"{front1}{l1}\n{front2}{l2}"
+        console.print(ss)
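A short sketch of the console helpers on toy inputs: `highlight_target` styles one span of a sequence, and `alignment_predict` prints prediction/label rows with runs of 1s highlighted:

    highlight_target("ACGTACGTACGT", start=4, end=8)

    alignment_predict(prediction=[0, 0, 1, 1, 0], label=[0, 1, 1, 1, 0])
    # P:00110
    # L:01110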
deepchopper/utils/pylogger.py
ADDED
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        *,
+        rank_zero_only: bool = False,
+        extra: Mapping[str, object] | None = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `rank` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. See the `logging` module for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
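Usage sketch: `rank_zero_only.rank` is normally set by Lightning once worker processes are spawned; the adapter then prefixes every message with the emitting rank:

    log = RankedLogger(__name__)               # logs on every rank, rank-prefixed
    log.info("step finished", rank=0)          # emitted only on rank 0

    log0 = RankedLogger(__name__, rank_zero_only=True)
    log0.info("training started")              # rank 0 only, regardless of `rank`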