deepchopper 1.3.0__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepchopper/__init__.py +9 -0
- deepchopper/__init__.pyi +67 -0
- deepchopper/__main__.py +4 -0
- deepchopper/cli.py +260 -0
- deepchopper/data/__init__.py +15 -0
- deepchopper/data/components/__init__.py +1 -0
- deepchopper/data/encode_fq.py +41 -0
- deepchopper/data/fq_datamodule.py +352 -0
- deepchopper/data/hg_data.py +39 -0
- deepchopper/data/only_fq.py +388 -0
- deepchopper/deepchopper.abi3.so +0 -0
- deepchopper/eval.py +86 -0
- deepchopper/models/__init__.py +4 -0
- deepchopper/models/basic_module.py +243 -0
- deepchopper/models/callbacks.py +57 -0
- deepchopper/models/cnn.py +54 -0
- deepchopper/models/components/__init__.py +1 -0
- deepchopper/models/dc_hg.py +163 -0
- deepchopper/models/llm/__init__.py +32 -0
- deepchopper/models/llm/caduceus.py +55 -0
- deepchopper/models/llm/components.py +99 -0
- deepchopper/models/llm/head.py +102 -0
- deepchopper/models/llm/hyena.py +41 -0
- deepchopper/models/llm/metric.py +44 -0
- deepchopper/models/llm/tokenizer.py +205 -0
- deepchopper/models/transformer.py +107 -0
- deepchopper/py.typed +0 -0
- deepchopper/train.py +109 -0
- deepchopper/ui/__init__.py +1 -0
- deepchopper/ui/main.py +189 -0
- deepchopper/utils/__init__.py +37 -0
- deepchopper/utils/instantiators.py +54 -0
- deepchopper/utils/logging_utils.py +53 -0
- deepchopper/utils/preprocess.py +62 -0
- deepchopper/utils/print.py +102 -0
- deepchopper/utils/pylogger.py +57 -0
- deepchopper/utils/rich_utils.py +100 -0
- deepchopper/utils/utils.py +138 -0
- deepchopper-1.3.0.dist-info/METADATA +254 -0
- deepchopper-1.3.0.dist-info/RECORD +43 -0
- deepchopper-1.3.0.dist-info/WHEEL +4 -0
- deepchopper-1.3.0.dist-info/entry_points.txt +2 -0
- deepchopper-1.3.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
import multiprocessing
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from functools import partial
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import pyfastx
|
|
8
|
+
from datasets import Dataset as HuggingFaceDataset
|
|
9
|
+
from lightning import LightningDataModule
|
|
10
|
+
from torch.utils.data import DataLoader, Dataset
|
|
11
|
+
from transformers import AutoTokenizer
|
|
12
|
+
|
|
13
|
+
import deepchopper
|
|
14
|
+
from deepchopper.models.llm import (
|
|
15
|
+
DataCollatorForTokenClassificationWithQual,
|
|
16
|
+
tokenize_and_align_labels_and_quals,
|
|
17
|
+
tokenize_and_align_labels_and_quals_ids,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def parse_fastq_file(file_path: Path, has_targets: bool = True) -> Iterator[dict]:
    """Stream the records of a FastQ file as dictionaries.

    The file is opened with ``pyfastx.Fastx`` (``uppercase=True``). Each record is
    validated and then yielded; nothing is accumulated in memory by this function.

    Args:
        file_path: Path to the FastQ file (.fq, .fastq, .fq.gz, .fastq.gz)
        has_targets: When true, the target is parsed from each record identifier
            with ``deepchopper.parse_target_from_id``; otherwise ``[0, 0]`` is used.

    Yields:
        A dict with keys ``id`` (record name), ``seq`` (sequence), ``qual``
        (quality string passed through ``deepchopper.encode_qual``) and ``target``.

    Raises:
        ValueError: If a record is incomplete, its sequence and quality lengths
            differ, its target cannot be parsed, or the file yields no records.
        RuntimeError: If pyfastx reports a parsing error or any other unexpected
            exception occurs while reading.
    """
    try:
        reader = pyfastx.Fastx(str(file_path), uppercase=True)

        yielded = 0
        for position, (name, seq, qual) in enumerate(reader):
            # Every field of the record must be non-empty.
            if not (name and seq and qual):
                msg = f"Incomplete FASTQ record at position {position} in {file_path}"
                raise ValueError(msg)

            # One quality character is expected per base.
            seq_len, qual_len = len(seq), len(qual)
            if seq_len != qual_len:
                msg = f"Sequence/quality length mismatch in record '{name}': seq={seq_len}, qual={qual_len}"
                raise ValueError(msg)

            if has_targets:
                try:
                    target = deepchopper.parse_target_from_id(name)
                except Exception as e:
                    msg = f"Failed to parse target from ID '{name}': {e}"
                    raise ValueError(msg) from e
            else:
                # Placeholder target for unlabeled input.
                target = [0, 0]

            yield {
                "id": name,
                "seq": seq,
                "qual": deepchopper.encode_qual(qual, deepchopper.default.QUAL_OFFSET),
                "target": target,
            }
            yielded = position + 1

        # An input that produced no records at all is treated as an error.
        if not yielded:
            msg = f"No valid records found in {file_path}"
            raise ValueError(msg)

    except pyfastx.FastxError as e:
        msg = f"FASTQ parsing error in {file_path}: {e}"
        raise RuntimeError(msg) from e
    except (ValueError, RuntimeError):
        # Validation errors raised above already carry a useful message.
        raise
    except Exception as e:
        msg = f"Error parsing FastQ file {file_path}: {e}"
        raise RuntimeError(msg) from e
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class OnlyFqDataModule(LightningDataModule):
    """PyTorch Lightning DataModule for genomic sequence data in FASTQ format.

    Each split is read with ``parse_fastq_file``, materialised as a HuggingFace
    ``Dataset`` and tokenized with the tokenizer given at construction time.
    Batches are assembled by ``DataCollatorForTokenClassificationWithQual``.

    Expected input:
        - FASTQ files (.fq, .fastq, .fq.gz, .fastq.gz).
        - Train/validation/test files are parsed with ``has_targets=True``, so
          their record identifiers must contain a parsable target. Predict files
          are parsed with ``has_targets=False``.

    Implements the standard LightningDataModule interface:
        - prepare_data
        - setup
        - train_dataloader
        - val_dataloader
        - test_dataloader
        - predict_dataloader
        - teardown

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    # Predict files larger than this (in GiB) get their dataset-building
    # parallelism capped, to avoid memory issues.
    _LARGE_FILE_GB = 1.0
    _LARGE_FILE_MAX_PROC = 4

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        train_data_path: Path,
        val_data_path: Path | None = None,
        test_data_path: Path | None = None,
        predict_data_path: Path | None = None,
        batch_size: int = 12,
        num_workers: int = 0,
        max_train_samples: int | None = None,
        max_val_samples: int | None = None,
        max_test_samples: int | None = None,
        max_predict_samples: int | None = None,
        *,
        pin_memory: bool = False,
    ) -> None:
        """Initialize an `OnlyFqDataModule`.

        :param tokenizer: Tokenizer used for tokenization and by the data collator.
        :param train_data_path: FASTQ file with training records.
        :param val_data_path: FASTQ file with validation records. Required by `setup` for non-predict stages.
        :param test_data_path: FASTQ file with test records. Required by `setup` for non-predict stages.
        :param predict_data_path: FASTQ file to run prediction on. Required by `setup("predict")`.
        :param batch_size: The total batch size; divided by the trainer's world size in `setup`. Defaults to `12`.
        :param num_workers: The number of DataLoader workers; also bounds the number of processes used to build
            the datasets. Defaults to `0`.
        :param max_train_samples: If set, keep only the first N training records.
        :param max_val_samples: If set, keep only the first N validation records.
        :param max_test_samples: If set, keep only the first N test records.
        :param max_predict_samples: If set, keep only the first N prediction records.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train: Dataset | None = None
        self.data_val: Dataset | None = None
        self.data_test: Dataset | None = None
        # Initialised like the other splits so `predict_dataloader` does not hit an
        # AttributeError when it is reached before `setup("predict")`.
        self.data_predict: Dataset | None = None
        self.batch_size_per_device = batch_size
        self.data_collator = DataCollatorForTokenClassificationWithQual(tokenizer)

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        :return: The number of classes (always 2).
        """
        return 2

    def prepare_data(self) -> None:
        """Check that the configured train/val/test data files exist.

        When a predict data path is configured the check is skipped entirely.

        :raises ValueError: If one of the checked data files does not exist.
        """
        if self.hparams.predict_data_path is not None:
            # no need to prepare data for prediction
            return

        candidates = (
            self.hparams.train_data_path,
            self.hparams.val_data_path,
            self.hparams.test_data_path,
        )
        for data_path in candidates:
            if data_path is None:
                continue
            if not Path(data_path).exists():
                msg = f"Data file {data_path} does not exist."
                raise ValueError(msg)

    def setup(self, stage: str | None = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`, `self.data_predict`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`. Also, it is called after `self.prepare_data()` and there is a barrier in between
        which ensures that all the processes proceed to `self.setup()` once the data is prepared and available
        for use.

        For `stage == "predict"` only the predict dataset is built. For every other stage the train, validation
        and test datasets are built together, once.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the trainer's world size.
        :raises ValueError: If a data path required for the stage is missing.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            world_size = self.trainer.world_size
            if self.hparams.batch_size % world_size != 0:
                msg = f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({world_size})."
                raise RuntimeError(msg)
            self.batch_size_per_device = self.hparams.batch_size // world_size

        if stage == "predict":
            self._setup_predict()
            return

        # load datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            self._setup_train_val_test()

    def _num_proc(self) -> int:
        """Return the process count for dataset building: at least 1, at most `num_workers` and `cpu_count - 1`."""
        return max(1, min(self.hparams.num_workers, multiprocessing.cpu_count() - 1))

    def _build_split(
        self,
        data_path,
        *,
        has_targets: bool,
        max_samples: int | None,
        tokenize_fn,
        drop_columns: list[str],
        num_proc: int,
    ):
        """Parse one FASTQ file, optionally truncate it, tokenize it and drop the raw columns."""
        dataset = HuggingFaceDataset.from_generator(
            parse_fastq_file,
            gen_kwargs={"file_path": data_path, "has_targets": has_targets},
            num_proc=num_proc,
        ).with_format("torch")

        if max_samples is not None:
            # Never ask for more rows than the file provided.
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        tokenizer = self.hparams.tokenizer
        return dataset.map(
            partial(
                tokenize_fn,
                tokenizer=tokenizer,
                max_length=tokenizer.max_len_single_sentence,
            ),
            num_proc=num_proc,  # type: ignore
        ).remove_columns(drop_columns)

    def _setup_predict(self) -> None:
        """Build `self.data_predict` from the predict FASTQ file."""
        predict_path = self.hparams.predict_data_path
        if not predict_path:
            msg = "Predict data path is required for prediction stage."
            raise ValueError(msg)

        num_proc = self._num_proc()

        # Reduce parallelism for large files to avoid memory issues
        file_size_gb = Path(predict_path).stat().st_size / (1024**3)
        if file_size_gb > self._LARGE_FILE_GB:
            num_proc = min(num_proc, self._LARGE_FILE_MAX_PROC)
            import logging

            logging.info("Large file detected (%.2fGB), limiting num_proc to %s", file_size_gb, num_proc)

        # The "id" column is kept for prediction; only the raw inputs are dropped.
        self.data_predict = self._build_split(
            predict_path,
            has_targets=False,
            max_samples=self.hparams.max_predict_samples,
            tokenize_fn=tokenize_and_align_labels_and_quals_ids,
            drop_columns=["seq", "qual", "target"],
            num_proc=num_proc,
        )

    def _setup_train_val_test(self) -> None:
        """Build `self.data_train`, `self.data_val` and `self.data_test`."""
        if self.hparams.val_data_path is None:
            msg = "Please provide a validation data path."
            raise ValueError(msg)

        if self.hparams.test_data_path is None:
            msg = "Please provide a test data path."
            raise ValueError(msg)

        num_proc = self._num_proc()
        raw_columns = ["id", "seq", "qual", "target"]

        self.data_train = self._build_split(
            self.hparams.train_data_path,
            has_targets=True,
            max_samples=self.hparams.max_train_samples,
            tokenize_fn=tokenize_and_align_labels_and_quals,
            drop_columns=raw_columns,
            num_proc=num_proc,
        )
        self.data_val = self._build_split(
            self.hparams.val_data_path,
            has_targets=True,
            max_samples=self.hparams.max_val_samples,
            tokenize_fn=tokenize_and_align_labels_and_quals,
            drop_columns=raw_columns,
            num_proc=num_proc,
        )
        self.data_test = self._build_split(
            self.hparams.test_data_path,
            has_targets=True,
            max_samples=self.hparams.max_test_samples,
            tokenize_fn=tokenize_and_align_labels_and_quals,
            drop_columns=raw_columns,
            num_proc=num_proc,
        )

    def _make_dataloader(self, dataset: Dataset | None, *, shuffle: bool) -> DataLoader[Any]:
        """Create a DataLoader over `dataset` with the module's shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.data_collator.torch_call,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader (shuffled).
        """
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return self._make_dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self) -> DataLoader[Any]:
        """Create and return the predict dataloader.

        :return: The predict dataloader.
        """
        return self._make_dataloader(self.data_predict, shuffle=False)

    def teardown(self, stage: str | None = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,.

        `trainer.test()`, and `trainer.predict()`. Nothing is cleaned up here.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """

    def state_dict(self) -> dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save (currently empty).
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule.

        `state_dict()`. Nothing is restored here.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
|
|
Binary file
|
deepchopper/eval.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
import hydra
|
|
5
|
+
from omegaconf import DictConfig
|
|
6
|
+
|
|
7
|
+
from .utils import (
|
|
8
|
+
RankedLogger,
|
|
9
|
+
extras,
|
|
10
|
+
instantiate_callbacks,
|
|
11
|
+
instantiate_loggers,
|
|
12
|
+
log_hyperparameters,
|
|
13
|
+
task_wrapper,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
|
18
|
+
from lightning.pytorch.loggers import Logger
|
|
19
|
+
|
|
20
|
+
log = RankedLogger(__name__, rank_zero_only=True)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@task_wrapper
def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset, or runs prediction.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    If the instantiated datamodule has a non-``None`` ``predict_data_path`` hyperparameter,
    ``trainer.predict`` is run; otherwise ``trainer.test`` is run.

    :param cfg: DictConfig configuration composed by Hydra. ``cfg.ckpt_path`` must be set.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    :raises AssertionError: If ``cfg.ckpt_path`` is empty.
    """
    # Raised explicitly (instead of a bare `assert`) so the check is not stripped
    # under `python -O`; the exception type is kept for existing handlers.
    if not cfg.ckpt_path:
        msg = "cfg.ckpt_path must be set to the checkpoint to evaluate."
        raise AssertionError(msg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Datamodules that do not define `predict_data_path` are evaluated on their
    # test set rather than failing on the attribute lookup.
    predict_data_path = getattr(datamodule.hparams, "predict_data_path", None)

    if predict_data_path is None:
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
    else:
        # for predictions use trainer.predict(...)
        import multiprocess.context as ctx

        # NOTE(review): relies on a private `multiprocess` API to force the "spawn"
        # start method before prediction — confirm it exists in the pinned version.
        ctx._force_start_method("spawn")
        trainer.predict(model=model, dataloaders=datamodule, ckpt_path=cfg.ckpt_path, return_predictions=False)

    metric_dict = trainer.callback_metrics
    return metric_dict, object_dict
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@hydra.main(
    version_base="1.3",
    config_path=os.environ.get("DC_CONFIG_PATH", "configs"),
    config_name="eval.yaml",
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: apply the extra utilities, then run the evaluation.

    The config directory is taken from the ``DC_CONFIG_PATH`` environment variable
    and falls back to ``configs``.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Optional extras first (e.g. ask for tags if none are provided in cfg,
    # print cfg tree, etc.).
    extras(cfg)

    evaluate(cfg)
|