deepchopper 1.3.0__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. deepchopper/__init__.py +9 -0
  2. deepchopper/__init__.pyi +67 -0
  3. deepchopper/__main__.py +4 -0
  4. deepchopper/cli.py +260 -0
  5. deepchopper/data/__init__.py +15 -0
  6. deepchopper/data/components/__init__.py +1 -0
  7. deepchopper/data/encode_fq.py +41 -0
  8. deepchopper/data/fq_datamodule.py +352 -0
  9. deepchopper/data/hg_data.py +39 -0
  10. deepchopper/data/only_fq.py +388 -0
  11. deepchopper/deepchopper.abi3.so +0 -0
  12. deepchopper/eval.py +86 -0
  13. deepchopper/models/__init__.py +4 -0
  14. deepchopper/models/basic_module.py +243 -0
  15. deepchopper/models/callbacks.py +57 -0
  16. deepchopper/models/cnn.py +54 -0
  17. deepchopper/models/components/__init__.py +1 -0
  18. deepchopper/models/dc_hg.py +163 -0
  19. deepchopper/models/llm/__init__.py +32 -0
  20. deepchopper/models/llm/caduceus.py +55 -0
  21. deepchopper/models/llm/components.py +99 -0
  22. deepchopper/models/llm/head.py +102 -0
  23. deepchopper/models/llm/hyena.py +41 -0
  24. deepchopper/models/llm/metric.py +44 -0
  25. deepchopper/models/llm/tokenizer.py +205 -0
  26. deepchopper/models/transformer.py +107 -0
  27. deepchopper/py.typed +0 -0
  28. deepchopper/train.py +109 -0
  29. deepchopper/ui/__init__.py +1 -0
  30. deepchopper/ui/main.py +189 -0
  31. deepchopper/utils/__init__.py +37 -0
  32. deepchopper/utils/instantiators.py +54 -0
  33. deepchopper/utils/logging_utils.py +53 -0
  34. deepchopper/utils/preprocess.py +62 -0
  35. deepchopper/utils/print.py +102 -0
  36. deepchopper/utils/pylogger.py +57 -0
  37. deepchopper/utils/rich_utils.py +100 -0
  38. deepchopper/utils/utils.py +138 -0
  39. deepchopper-1.3.0.dist-info/METADATA +254 -0
  40. deepchopper-1.3.0.dist-info/RECORD +43 -0
  41. deepchopper-1.3.0.dist-info/WHEEL +4 -0
  42. deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
  43. deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
deepchopper/data/only_fq.py ADDED
@@ -0,0 +1,388 @@
+ import multiprocessing
+ from collections.abc import Iterator
+ from functools import partial
+ from pathlib import Path
+ from typing import Any
+
+ import pyfastx
+ from datasets import Dataset as HuggingFaceDataset
+ from lightning import LightningDataModule
+ from torch.utils.data import DataLoader, Dataset
+ from transformers import AutoTokenizer
+
+ import deepchopper
+ from deepchopper.models.llm import (
+     DataCollatorForTokenClassificationWithQual,
+     tokenize_and_align_labels_and_quals,
+     tokenize_and_align_labels_and_quals_ids,
+ )
+
+
+ def parse_fastq_file(file_path: Path, has_targets: bool = True) -> Iterator[dict]:
+     """Parse a FastQ file using pyfastx and yield one record dictionary per read.
+
+     Args:
+         file_path: Path to the FastQ file (.fq, .fastq, .fq.gz, .fastq.gz)
+         has_targets: Whether the file contains target labels in the identifier line
+
+     Yields:
+         dict with keys "id", "seq", "qual" (encoded quality scores), and "target"
+
+     Raises:
+         ValueError: If the file is empty, corrupted, or contains invalid records
+         RuntimeError: If parsing fails
+     """
+     try:
+         # Use pyfastx to parse the file
+         fq = pyfastx.Fastx(str(file_path), uppercase=True)
+
+         record_count = 0
+         for name, seq, qual in fq:
+             # Validate record completeness
+             if not name or not seq or not qual:
+                 msg = f"Incomplete FASTQ record at position {record_count} in {file_path}"
+                 raise ValueError(msg)
+
+             # Validate sequence and quality lengths match
+             if len(seq) != len(qual):
+                 msg = f"Sequence/quality length mismatch in record '{name}': seq={len(seq)}, qual={len(qual)}"
+                 raise ValueError(msg)
+
+             # Parse target if present
+             target = [0, 0]
+             if has_targets:
+                 try:
+                     target = deepchopper.parse_target_from_id(name)
+                 except Exception as e:
+                     msg = f"Failed to parse target from ID '{name}': {e}"
+                     raise ValueError(msg) from e
+
+             encoded_qual = deepchopper.encode_qual(qual, deepchopper.default.QUAL_OFFSET)
+
+             yield {
+                 "id": name,
+                 "seq": seq,
+                 "qual": encoded_qual,
+                 "target": target,
+             }
+
+             record_count += 1
+
+         # Ensure we read at least one record
+         if record_count == 0:
+             msg = f"No valid records found in {file_path}"
+             raise ValueError(msg)
+
+     except pyfastx.FastxError as e:
+         msg = f"FASTQ parsing error in {file_path}: {e}"
+         raise RuntimeError(msg) from e
+     except Exception as e:
+         # Re-raise ValueError and RuntimeError as-is
+         if isinstance(e, (ValueError, RuntimeError)):
+             raise
+         msg = f"Error parsing FastQ file {file_path}: {e}"
+         raise RuntimeError(msg) from e
+
+
+ class OnlyFqDataModule(LightningDataModule):
+     """PyTorch Lightning DataModule for genomic sequence data in FASTQ format.
+
+     This DataModule is designed to handle FASTQ files containing DNA or RNA sequences,
+     along with their associated quality scores and optional target labels embedded in the sequence identifiers.
+     It parses FASTQ files using pyfastx, encodes quality scores, and extracts targets for supervised learning tasks.
+
+     The module provides train, validation, test, and predict dataloaders compatible with PyTorch Lightning workflows.
+     It supports integration with HuggingFace tokenizers and custom data collators for token classification tasks.
+
+     Expected input:
+         - FASTQ files (.fq, .fastq, .fq.gz, .fastq.gz) with sequence identifiers optionally containing target labels.
+         - Each record includes a sequence, quality string, and (optionally) a target label.
+
+     Key features:
+         - Efficient parsing of large FASTQ files using pyfastx.
+         - Encoding of quality scores for model input.
+         - Extraction of target labels from sequence identifiers.
+         - Customizable data collation and tokenization for downstream models.
+
+     Implements the standard LightningDataModule interface:
+         - prepare_data
+         - setup
+         - train_dataloader
+         - val_dataloader
+         - test_dataloader
+         - predict_dataloader
+         - teardown
+
+     This allows you to share a full dataset without explaining how to download,
+     split, transform and process the data.
+
+     Read the docs:
+         https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+     """
+
+     def __init__(
+         self,
+         tokenizer: AutoTokenizer,
+         train_data_path: Path,
+         val_data_path: Path | None = None,
+         test_data_path: Path | None = None,
+         predict_data_path: Path | None = None,
+         batch_size: int = 12,
+         num_workers: int = 0,
+         max_train_samples: int | None = None,
+         max_val_samples: int | None = None,
+         max_test_samples: int | None = None,
+         max_predict_samples: int | None = None,
+         *,
+         pin_memory: bool = False,
+     ) -> None:
+         """Initialize an `OnlyFqDataModule`.
+
+         :param tokenizer: The HuggingFace tokenizer used for tokenization and collation.
+         :param train_data_path: Path to the training FASTQ file.
+         :param val_data_path: Path to the validation FASTQ file. Defaults to `None`.
+         :param test_data_path: Path to the test FASTQ file. Defaults to `None`.
+         :param predict_data_path: Path to the prediction FASTQ file. Defaults to `None`.
+         :param batch_size: The batch size. Defaults to `12`.
+         :param num_workers: The number of workers. Defaults to `0`.
+         :param max_train_samples: Optional cap on the number of training samples. Defaults to `None`.
+         :param max_val_samples: Optional cap on the number of validation samples. Defaults to `None`.
+         :param max_test_samples: Optional cap on the number of test samples. Defaults to `None`.
+         :param max_predict_samples: Optional cap on the number of prediction samples. Defaults to `None`.
+         :param pin_memory: Whether to pin memory. Defaults to `False`.
+         """
+         super().__init__()
+
+         # this line allows access to init params via the 'self.hparams' attribute
+         # and ensures init params are stored in the checkpoint
+         self.save_hyperparameters(logger=False)
+
+         self.data_train: Dataset | None = None
+         self.data_val: Dataset | None = None
+         self.data_test: Dataset | None = None
+         self.batch_size_per_device = batch_size
+         self.data_collator = DataCollatorForTokenClassificationWithQual(tokenizer)
+
+     @property
+     def num_classes(self) -> int:
+         """Get the number of classes.
+
+         :return: The number of classes (2).
+         """
+         return 2
+
+     def prepare_data(self) -> None:
+         """Check that the configured FastQ data files exist."""
+         data_paths = [self.hparams.train_data_path]
+
+         if self.hparams.val_data_path is not None:
+             data_paths.append(self.hparams.val_data_path)
+
+         if self.hparams.test_data_path is not None:
+             data_paths.append(self.hparams.test_data_path)
+
+         if self.hparams.predict_data_path is not None:
+             data_paths.append(self.hparams.predict_data_path)
+             # no need to prepare data for prediction
+             return
+
+         for data_path in data_paths:
+             if not Path(data_path).exists():
+                 msg = f"Data file {data_path} does not exist."
+                 raise ValueError(msg)
+
+     def setup(self, stage: str | None = None) -> None:
+         """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+         This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+         `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+         `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+         `self.setup()` once the data is prepared and available for use.
+
+         :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+         """
+         # Divide batch size by the number of devices.
+         if self.trainer is not None:
+             if self.hparams.batch_size % self.trainer.world_size != 0:
+                 msg = f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+                 raise RuntimeError(msg)
+             self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+         if stage == "predict":
+             if not self.hparams.predict_data_path:
+                 msg = "Predict data path is required for prediction stage."
+                 raise ValueError(msg)
+
+             # Calculate appropriate num_proc based on file size
+             file_path = Path(self.hparams.predict_data_path)
+             file_size_gb = file_path.stat().st_size / (1024**3)
+
+             base_num_proc = min(self.hparams.num_workers, multiprocessing.cpu_count() - 1)
+
+             # Reduce parallelism for large files to avoid memory issues
+             if file_size_gb > 1.0:  # Files larger than 1GB
+                 num_proc = min(max(1, base_num_proc), 4)
+                 import logging
+
+                 logging.info(f"Large file detected ({file_size_gb:.2f}GB), limiting num_proc to {num_proc}")
+             else:
+                 num_proc = max(1, base_num_proc)
+
+             predict_dataset = HuggingFaceDataset.from_generator(
+                 parse_fastq_file,
+                 gen_kwargs={"file_path": self.hparams.predict_data_path, "has_targets": False},
+                 num_proc=num_proc,
+             ).with_format("torch")
+
+             self.data_predict = predict_dataset.map(
+                 partial(
+                     tokenize_and_align_labels_and_quals_ids,
+                     tokenizer=self.hparams.tokenizer,
+                     max_length=self.hparams.tokenizer.max_len_single_sentence,
+                 ),
+                 num_proc=max(1, num_proc),  # type: ignore
+             ).remove_columns(["seq", "qual", "target"])
+             del predict_dataset
+             return
+
+         # load and split datasets only if not loaded already
+         if not self.data_train and not self.data_val and not self.data_test:
+             num_proc = min(self.hparams.num_workers, multiprocessing.cpu_count() - 1)
+             data_files = {}
+             data_files["train"] = self.hparams.train_data_path
+
+             if self.hparams.val_data_path is None:
+                 msg = "Please provide a validation data path."
+                 raise ValueError(msg)
+
+             if self.hparams.test_data_path is None:
+                 msg = "Please provide a test data path."
+                 raise ValueError(msg)
+
+             train_dataset = HuggingFaceDataset.from_generator(
+                 parse_fastq_file,
+                 gen_kwargs={"file_path": self.hparams.train_data_path, "has_targets": True},
+                 num_proc=max(1, num_proc),
+             ).with_format("torch")
+
+             val_dataset = HuggingFaceDataset.from_generator(
+                 parse_fastq_file,
+                 gen_kwargs={"file_path": self.hparams.val_data_path, "has_targets": True},
+                 num_proc=max(1, num_proc),
+             ).with_format("torch")
+
+             test_dataset = HuggingFaceDataset.from_generator(
+                 parse_fastq_file,
+                 gen_kwargs={"file_path": self.hparams.test_data_path, "has_targets": True},
+                 num_proc=max(1, num_proc),
+             ).with_format("torch")
+
+             if self.hparams.max_train_samples is not None:
+                 max_train_samples = min(self.hparams.max_train_samples, len(train_dataset))
+                 train_dataset = train_dataset.select(range(max_train_samples))
+
+             if self.hparams.max_val_samples is not None:
+                 max_val_samples = min(self.hparams.max_val_samples, len(val_dataset))
+                 val_dataset = val_dataset.select(range(max_val_samples))
+
+             if self.hparams.max_test_samples is not None:
+                 max_test_samples = min(self.hparams.max_test_samples, len(test_dataset))
+                 test_dataset = test_dataset.select(range(max_test_samples))
+
+             self.data_train = train_dataset.map(
+                 partial(
+                     tokenize_and_align_labels_and_quals,
+                     tokenizer=self.hparams.tokenizer,
+                     max_length=self.hparams.tokenizer.max_len_single_sentence,
+                 ),
+                 num_proc=max(1, num_proc),  # type: ignore
+             ).remove_columns(["id", "seq", "qual", "target"])
+
+             self.data_val = val_dataset.map(
+                 partial(
+                     tokenize_and_align_labels_and_quals,
+                     tokenizer=self.hparams.tokenizer,
+                     max_length=self.hparams.tokenizer.max_len_single_sentence,
+                 ),
+                 num_proc=max(1, num_proc),  # type: ignore
+             ).remove_columns(["id", "seq", "qual", "target"])
+
+             self.data_test = test_dataset.map(
+                 partial(
+                     tokenize_and_align_labels_and_quals,
+                     tokenizer=self.hparams.tokenizer,
+                     max_length=self.hparams.tokenizer.max_len_single_sentence,
+                 ),
+                 num_proc=max(1, num_proc),  # type: ignore
+             ).remove_columns(["id", "seq", "qual", "target"])
+
+             del train_dataset, val_dataset, test_dataset
+
+     def train_dataloader(self) -> DataLoader[Any]:
+         """Create and return the train dataloader.
+
+         :return: The train dataloader.
+         """
+         return DataLoader(
+             dataset=self.data_train,
+             batch_size=self.batch_size_per_device,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             collate_fn=self.data_collator.torch_call,
+             shuffle=True,
+         )
+
+     def val_dataloader(self) -> DataLoader[Any]:
+         """Create and return the validation dataloader.
+
+         :return: The validation dataloader.
+         """
+         return DataLoader(
+             dataset=self.data_val,
+             batch_size=self.batch_size_per_device,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             collate_fn=self.data_collator.torch_call,
+             shuffle=False,
+         )
+
+     def test_dataloader(self) -> DataLoader[Any]:
+         """Create and return the test dataloader.
+
+         :return: The test dataloader.
+         """
+         return DataLoader(
+             dataset=self.data_test,
+             batch_size=self.batch_size_per_device,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             collate_fn=self.data_collator.torch_call,
+             shuffle=False,
+         )
+
+     def predict_dataloader(self) -> DataLoader[Any]:
+         """Create and return the predict dataloader.
+
+         :return: The predict dataloader.
+         """
+         return DataLoader(
+             dataset=self.data_predict,
+             batch_size=self.batch_size_per_device,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             collate_fn=self.data_collator.torch_call,
+             shuffle=False,
+         )
+
+     def teardown(self, stage: str | None = None) -> None:
+         """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.
+
+         :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+             Defaults to ``None``.
+         """
+
+     def state_dict(self) -> dict[Any, Any]:
+         """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+         :return: A dictionary containing the datamodule state that you want to save.
+         """
+         return {}
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         """Called when loading a checkpoint. Implement to reload the datamodule state from the given `state_dict()`.
+
+         :param state_dict: The datamodule state returned by `self.state_dict()`.
+         """
deepchopper/deepchopper.abi3.so ADDED
Binary file
deepchopper/eval.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ from typing import TYPE_CHECKING, Any
+
+ import hydra
+ from omegaconf import DictConfig
+
+ from .utils import (
+     RankedLogger,
+     extras,
+     instantiate_callbacks,
+     instantiate_loggers,
+     log_hyperparameters,
+     task_wrapper,
+ )
+
+ if TYPE_CHECKING:
+     from lightning import Callback, LightningDataModule, LightningModule, Trainer
+     from lightning.pytorch.loggers import Logger
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ @task_wrapper
+ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
+     """Evaluate a given checkpoint on the datamodule's test set.
+
+     This method is wrapped in the optional @task_wrapper decorator, which controls behavior during
+     failure. Useful for multiruns, saving info about the crash, etc.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     :return: Tuple[dict, dict] with metrics and a dict with all instantiated objects.
+     """
+     assert cfg.ckpt_path
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating callbacks...")
+     callbacks: list[Callback] = instantiate_callbacks(cfg.get("callbacks"))
+
+     log.info("Instantiating loggers...")
+     logger: list[Logger] = instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         log_hyperparameters(object_dict)
+
+     if datamodule.hparams.predict_data_path is None:
+         log.info("Starting testing!")
+         trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+     else:
+         # for predictions use trainer.predict(...)
+         import multiprocess.context as ctx
+
+         ctx._force_start_method("spawn")
+         trainer.predict(model=model, dataloaders=datamodule, ckpt_path=cfg.ckpt_path, return_predictions=False)
+
+     metric_dict = trainer.callback_metrics
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path=os.getenv("DC_CONFIG_PATH", "configs"), config_name="eval.yaml")
+ def main(cfg: DictConfig) -> None:
+     """Main entry point for evaluation.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     """
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     extras(cfg)
+
+     evaluate(cfg)
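evaluate() leans on Hydra's `_target_` convention: cfg.data, cfg.model, and cfg.trainer are each config nodes whose `_target_` names the class (or callable) to construct, and hydra.utils.instantiate builds it with the remaining keys as arguments. Below is a minimal sketch of that mechanism for the data node only; the values are illustrative placeholders, not the eval.yaml shipped in the package's config directory (resolved from DC_CONFIG_PATH or "configs").

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Placeholder node mirroring the shape of `cfg.data` that evaluate() expects.
data_cfg = OmegaConf.create(
    {
        "_target_": "deepchopper.data.only_fq.OnlyFqDataModule",
        "tokenizer": {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "my-dna-tokenizer",  # hypothetical checkpoint
        },
        "train_data_path": "train.fq.gz",
        "val_data_path": "val.fq.gz",
        "test_data_path": "test.fq.gz",
        "predict_data_path": None,  # None -> evaluate() runs trainer.test(); a path -> trainer.predict()
    }
)

datamodule = instantiate(data_cfg)  # recursively instantiates the tokenizer, then the DataModule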
deepchopper/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """models."""
+
+ from . import basic_module, callbacks
+ from .dc_hg import DeepChopper
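With these re-exports in place, downstream code can pull the model entry point straight from the subpackage rather than from the defining module, for example:

from deepchopper.models import DeepChopper             # re-exported from deepchopper.models.dc_hg
from deepchopper.models import basic_module, callbacks  # submodules imported by this __init__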