careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script contains the functions/classes to compute loss and metrics used to train and evaluate the performance of the model.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from skimage.metrics import structural_similarity
|
|
8
|
+
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
|
|
9
|
+
|
|
10
|
+
from careamics.models.lvae.utils import allow_numpy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RunningPSNR:
    """
    Running PSNR computed over a validation set, one batch at a time.

    Per-image MSE values, the number of valid images and the global min/max of
    the targets are accumulated with `update`. `get` returns the PSNR over
    everything seen so far. Call `reset` (usually at the end of each epoch) to
    start over.
    """

    def __init__(self):
        # number of (non-NaN) images seen so far during the epoch
        self.N = None
        # running sum of the per-image MSE over the self.N images seen so far
        self.mse_sum = None
        # running max and min values of the target images seen so far
        self.max = self.min = None
        self.reset()

    def reset(self):
        """
        Reset the running statistics (usually called at the end of each epoch).
        """
        self.mse_sum = 0
        self.N = 0
        self.max = self.min = None

    def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
        """
        Update the running statistics with a batch of reconstructions and targets.

        Parameters
        ----------
        rec : torch.Tensor
            Batch of reconstructed images (B, H, W).
        tar : torch.Tensor
            Batch of target images (B, H, W).

        Notes
        -----
        Images whose MSE is NaN are excluded from both the MSE sum and the
        image count. The min/max range is taken over the whole of `tar`.
        """
        ins_max = torch.max(tar).item()
        ins_min = torch.min(tar).item()
        if self.max is None:
            assert self.min is None
            self.max = ins_max
            self.min = ins_min
        else:
            self.max = max(self.max, ins_max)
            self.min = min(self.min, ins_min)

        mse = (rec - tar) ** 2
        # `reshape` (unlike `view`) also works on non-contiguous batches,
        # e.g. transposed or strided slices.
        elementwise_mse = torch.mean(mse.reshape(len(mse), -1), dim=1)
        self.mse_sum += torch.nansum(elementwise_mse)
        self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))

    def get(self):
        """
        Return the PSNR over all images seen since the last `reset`.

        Returns
        -------
        torch.Tensor or None
            The PSNR as a 0-d tensor, or None if no valid image was seen yet.
        """
        # Test for None first so `== 0` is never evaluated on a None value.
        if self.N is None or self.N == 0:
            return None
        rmse = torch.sqrt(self.mse_sum / self.N)
        return 20 * torch.log10((self.max - self.min) / rmse)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def zero_mean(x):
    """
    Remove the mean along dimension 1 from `x`.

    Parameters
    ----------
    x : torch.Tensor
        Tensor with at least two dimensions; in this module it is a batch of
        flattened images of shape (B, N).

    Returns
    -------
    torch.Tensor
        `x` minus its mean over dim 1 (kept as a singleton dim for broadcasting).
    """
    row_mean = x.mean(dim=1, keepdim=True)
    return x - row_mean
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def fix_range(gt, x):
    """
    Rescale each row of `x` by the least-squares factor that best fits `gt`.

    For every row the scalar ``a = <gt, x> / <x, x>`` (the minimiser of
    ``||gt - a * x||^2``) is computed and ``a * x`` is returned.

    Parameters
    ----------
    gt : torch.Tensor
        Reference rows, shape (B, N).
    x : torch.Tensor
        Rows to rescale, same shape as `gt`.

    Returns
    -------
    torch.Tensor
        The rescaled `x`.
    """
    numerator = (gt * x).sum(dim=1, keepdim=True)
    denominator = (x * x).sum(dim=1, keepdim=True)
    scale = numerator / denominator
    return x * scale
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def fix(gt, x):
    """
    Zero-centre `gt` and `x` row-wise, then rescale `x` to best fit `gt`.

    Parameters
    ----------
    gt : torch.Tensor
        Reference rows, shape (B, N).
    x : torch.Tensor
        Rows to adjust, same shape as `gt`.

    Returns
    -------
    torch.Tensor
        The zero-mean `x`, scaled by the least-squares factor that matches
        the zero-mean `gt` (see `fix_range`).
    """
    return fix_range(zero_mean(gt), zero_mean(x))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _PSNR_internal(gt, pred, range_=None):
|
|
87
|
+
if range_ is None:
|
|
88
|
+
range_ = torch.max(gt, dim=1).values - torch.min(gt, dim=1).values
|
|
89
|
+
|
|
90
|
+
mse = torch.mean((gt - pred) ** 2, dim=1)
|
|
91
|
+
return 20 * torch.log10(range_ / torch.sqrt(mse))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@allow_numpy
def PSNR(gt, pred, range_=None):
    """
    Compute the PSNR of each image in a batch.

    Parameters
    ----------
    gt : array
        Ground truth images, shape (batch, H, W).
    pred : array
        Predicted images, with as many elements per image as `gt`.
    range_ : torch.Tensor, optional
        Data range per image. If None, the per-image ``max - min`` of `gt`
        is used.

    Returns
    -------
    torch.Tensor
        PSNR of each image, shape (batch,).

    Notes
    -----
    NumPy inputs go through the `allow_numpy` decorator (presumably converted
    to tensors - TODO confirm against its implementation).
    """
    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"

    # `reshape` (unlike `view`) also accepts non-contiguous tensors.
    gt = gt.reshape(len(gt), -1)
    pred = pred.reshape(len(gt), -1)
    return _PSNR_internal(gt, pred, range_=range_)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@allow_numpy
def RangeInvariantPsnr(gt: torch.Tensor, pred: torch.Tensor):
    """
    Compute a range-invariant PSNR for each image in a batch.

    NOTE: Works only for grayscale images.
    Adapted from https://github.com/juglab/ScaleInvPSNR/blob/master/psnr.py

    Each ground-truth image is standardised (zero mean, unit std). Each
    prediction is zero-centred and rescaled by the least-squares factor that
    best matches the standardised ground truth. The score therefore does not
    change under an offset or a non-zero scaling of the prediction. The data
    range used is the ground-truth range divided by its standard deviation.

    Parameters
    ----------
    gt : torch.Tensor
        Ground truth images, shape (batch, H, W).
    pred : torch.Tensor
        Predicted images, with as many elements per image as `gt`.

    Returns
    -------
    torch.Tensor
        Range-invariant PSNR of each image, shape (batch,).
    """
    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"

    # `reshape` (unlike `view`) also accepts non-contiguous tensors.
    gt = gt.reshape(len(gt), -1)
    pred = pred.reshape(len(gt), -1)
    ra = (torch.max(gt, dim=1).values - torch.min(gt, dim=1).values) / torch.std(
        gt, dim=1
    )
    gt_ = zero_mean(gt) / torch.std(gt, dim=1, keepdim=True)
    return _PSNR_internal(zero_mean(gt_), fix(gt_, pred), ra)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _avg_psnr(target, prediction, psnr_fn):
|
|
131
|
+
output = np.mean(
|
|
132
|
+
[
|
|
133
|
+
psnr_fn(target[i : i + 1], prediction[i : i + 1]).item()
|
|
134
|
+
for i in range(len(prediction))
|
|
135
|
+
]
|
|
136
|
+
)
|
|
137
|
+
return round(output, 2)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def avg_range_inv_psnr(target, prediction):
    """
    Mean range-invariant PSNR over a batch, rounded to 2 decimals.

    Each image pair is scored separately with `RangeInvariantPsnr`.
    """
    return _avg_psnr(target, prediction, psnr_fn=RangeInvariantPsnr)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def avg_psnr(target, prediction):
    """
    Mean PSNR over a batch, rounded to 2 decimals.

    Each image pair is scored separately with `PSNR`.
    """
    return _avg_psnr(target, prediction, psnr_fn=PSNR)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def compute_masked_psnr(mask, tar1, tar2, pred1, pred2):
    """
    Range-invariant PSNR of two target/prediction pairs, restricted to a mask.

    Parameters
    ----------
    mask : array
        Mask supporting ``.astype`` (e.g. a NumPy array). It is cast to bool
        and only index 0 of its last axis is used.
    tar1, pred1 : array
        First target and its prediction, indexable by the boolean mask.
    tar2, pred2 : array
        Second target and its prediction.

    Returns
    -------
    tuple
        ``(psnr1, psnr2)``: the average range-invariant PSNR of each pair.

    Notes
    -----
    The masked pixels are reshaped to ``(len(tar), -1, 1)``. This presumably
    assumes every sample has the same number of masked pixels - TODO confirm
    with callers.
    """
    bool_mask = mask.astype(bool)[..., 0]

    def _select(arr, n_samples):
        # Keep the masked pixels, grouped per sample, as (n_samples, -1, 1).
        return arr[bool_mask].reshape((n_samples, -1, 1))

    psnr1 = avg_range_inv_psnr(_select(tar1, len(tar1)), _select(pred1, len(tar1)))
    psnr2 = avg_range_inv_psnr(_select(tar2, len(tar2)), _select(pred2, len(tar2)))
    return psnr1, psnr2
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def avg_ssim(target, prediction):
    """
    Mean and standard deviation of the per-image SSIM over a batch.

    Each image is scored with scikit-image's `structural_similarity`, using
    that target image's ``max - min`` as ``data_range``.

    Parameters
    ----------
    target : array
        Batch of ground-truth images.
    prediction : array
        Batch of predicted images, indexed like `target`.

    Returns
    -------
    tuple
        ``(mean, std)`` of the per-image SSIM values.
    """
    scores = []
    for idx, tar_img in enumerate(target):
        pred_img = prediction[idx]
        data_range = tar_img.max() - tar_img.min()
        scores.append(structural_similarity(tar_img, pred_img, data_range=data_range))
    return np.mean(scores), np.std(scores)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@allow_numpy
def range_invariant_multiscale_ssim(gt_, pred_):
    """
    Computes range invariant multiscale SSIM for one channel.

    Both inputs are flattened per image and zero-centred. The prediction is
    then rescaled by the least-squares factor that best matches the ground
    truth (see `fix`), so the score is invariant to scalar multiplications of
    the prediction. MS-SSIM is computed with torchmetrics, using the range of
    the zero-centred ground truth as ``data_range``.

    Parameters
    ----------
    gt_ : array
        Ground-truth images of one channel; `compute_multiscale_ssim` passes
        shape (N, H, W).
    pred_ : array
        Predicted images, same shape as `gt_`.

    Returns
    -------
    float
        The MS-SSIM value.
    """
    original_shape = gt_.shape
    n_images = original_shape[0]

    gt_flat = zero_mean(torch.Tensor(gt_.reshape((n_images, -1))))
    pred_flat = zero_mean(torch.Tensor(pred_.reshape((n_images, -1))))
    pred_img = fix(gt_flat, pred_flat).reshape(original_shape)
    gt_img = gt_flat.reshape(original_shape)

    metric = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=gt_img.max() - gt_img.min()
    )
    # Insert a channel axis: (N, H, W) -> (N, 1, H, W).
    return metric(torch.Tensor(pred_img[:, None]), torch.Tensor(gt_img[:, None])).item()
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def compute_multiscale_ssim(gt_, pred_, range_invariant=True):
    """
    Compute multiscale SSIM separately for each channel.

    Parameters
    ----------
    gt_ : array
        Ground-truth images with shape (N, H, W, C).
    pred_ : array
        Predicted images with shape (N, H, W, C).
    range_invariant : bool, optional
        If True (default), use `range_invariant_multiscale_ssim`. Otherwise use
        plain torchmetrics MS-SSIM with the channel's ground-truth
        ``max - min`` as ``data_range``.

    Returns
    -------
    list of float
        One MS-SSIM value per channel, in channel order.
    """
    scores = []
    for ch_idx in range(gt_.shape[-1]):
        tar_ch = gt_[..., ch_idx]
        pred_ch = pred_[..., ch_idx]
        if range_invariant:
            scores.append(range_invariant_multiscale_ssim(tar_ch, pred_ch))
            continue
        metric = MultiScaleStructuralSimilarityIndexMeasure(
            data_range=tar_ch.max() - tar_ch.min()
        )
        score = metric(torch.Tensor(pred_ch[:, None]), torch.Tensor(tar_ch[:, None]))
        scores.append(score.item())
    return scores
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script is meant to load data, initialize the model, and provide the logic for training it.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import glob
|
|
6
|
+
import os
|
|
7
|
+
import socket
|
|
8
|
+
import sys
|
|
9
|
+
from typing import Dict
|
|
10
|
+
|
|
11
|
+
import pytorch_lightning as pl
|
|
12
|
+
import torch
|
|
13
|
+
from absl import app, flags
|
|
14
|
+
from ml_collections.config_flags import config_flags
|
|
15
|
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
|
16
|
+
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
17
|
+
from pytorch_lightning.loggers import WandbLogger
|
|
18
|
+
from torch.utils.data import DataLoader
|
|
19
|
+
|
|
20
|
+
# Make the directory two levels above this file importable when the script is
# run directly. (The debug `print(sys.path)` that used to fire on every import
# was removed.)
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
|
22
|
+
|
|
23
|
+
from careamics.lvae_training.data_modules import LCMultiChDloader, MultiChDloader
|
|
24
|
+
from careamics.lvae_training.data_utils import DataSplitType
|
|
25
|
+
from careamics.lvae_training.lightning_module import LadderVAELight
|
|
26
|
+
from careamics.lvae_training.train_utils import *
|
|
27
|
+
|
|
28
|
+
FLAGS = flags.FLAGS

# Hard-coded absolute default locations for logs and data.
_DEFAULT_LOGDIR = "/group/jug/federico/wandb_backup/"
_DEFAULT_DATADIR = "/group/jug/federico/careamics_training/data/BioSR"

# Configuration file parsed by ml_collections. lock_config=False presumably
# lets `main` add fields (workdir, exptname, ...) to it after parsing.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False
)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_string("logdir", _DEFAULT_LOGDIR, "The folder name for storing logging")
flags.DEFINE_string("datadir", _DEFAULT_DATADIR, "Data directory.")
flags.DEFINE_boolean(
    "use_max_version", False, "Overwrite the max version of the model"
)
flags.DEFINE_string(
    "load_ckptfpath",
    "",
    "The path to a previous ckpt from which the weights should be loaded",
)
flags.mark_flags_as_required(["workdir", "config", "mode"])
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _padding_kwargs_from_config(data_config):
    """Build the `LCMultiChDloader` padding options from ``config.data``."""
    # Missing or None entries fall back to reflect padding without a constant.
    padding_mode = (
        data_config.padding_mode if "padding_mode" in data_config else None
    )
    padding_value = (
        data_config.padding_value if "padding_value" in data_config else None
    )
    return {
        "mode": "reflect" if padding_mode is None else padding_mode,
        "constant_values": padding_value,
    }


def create_dataset(
    config,
    datadir,
    eval_datasplit_type=DataSplitType.Val,
    raw_data_dict=None,
    skip_train_dataset=False,
    kwargs_dict=None,
):
    """
    Create the training and evaluation datasets.

    The lateral-contextualization loader (`LCMultiChDloader`) is used when
    ``config.data.multiscale_lowres_count`` is set; otherwise the vanilla
    `MultiChDloader` is used. Both use a 10% validation / 10% test split. The
    evaluation dataset gets the training set's ``max_val`` and is normalized
    with the training set's mean and std.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Full configuration; only ``config.data`` is read here.
    datadir : str
        Path to the data.
    eval_datasplit_type : DataSplitType, optional
        Split used for the evaluation dataset (default: validation).
    raw_data_dict : optional
        Currently unused.
    skip_train_dataset : bool, optional
        If True, no training dataset is built and None is returned in its
        place. No ``max_val`` is then passed to the evaluation loader and its
        mean/std are not set here. The loader presumably falls back to its own
        defaults - TODO confirm.
    kwargs_dict : dict, optional
        Extra keyword arguments forwarded to both loaders. A
        ``"padding_kwargs"`` entry, if present, replaces the padding options
        derived from the config (LC loader only). The dict is not modified.

    Returns
    -------
    tuple
        ``(train_data, val_data)``.
    """
    # Work on a copy: popping "padding_kwargs" below must not mutate the
    # caller's dict.
    kwargs_dict = {} if kwargs_dict is None else dict(kwargs_dict)

    datapath = datadir

    # Hard-coded parameters (used to be in the config file)
    normalized_input = True
    use_one_mu_std = True
    train_aug_rotate = False
    enable_random_cropping = True
    lowres_supervision = False

    # 1) Data loader for Lateral Contextualization
    if (
        "multiscale_lowres_count" in config.data
        and config.data.multiscale_lowres_count is not None
    ):
        if "padding_kwargs" not in kwargs_dict:
            padding_kwargs = _padding_kwargs_from_config(config.data)
        else:
            padding_kwargs = kwargs_dict.pop("padding_kwargs")

        train_data = (
            None
            if skip_train_dataset
            else LCMultiChDloader(
                config.data,
                datapath,
                datasplit_type=DataSplitType.Train,
                val_fraction=0.1,
                test_fraction=0.1,
                normalized_input=normalized_input,
                use_one_mu_std=use_one_mu_std,
                enable_rotation_aug=train_aug_rotate,
                enable_random_cropping=enable_random_cropping,
                num_scales=config.data.multiscale_lowres_count,
                lowres_supervision=lowres_supervision,
                padding_kwargs=padding_kwargs,
                **kwargs_dict,
                allow_generation=True,
            )
        )
        # Without a training set there is no max value to forward.
        max_val_kwargs = (
            {} if train_data is None else {"max_val": train_data.get_max_val()}
        )

        val_data = LCMultiChDloader(
            config.data,
            datapath,
            datasplit_type=eval_datasplit_type,
            val_fraction=0.1,
            test_fraction=0.1,
            normalized_input=normalized_input,
            use_one_mu_std=use_one_mu_std,
            enable_rotation_aug=False,  # No rotation aug on validation
            # No random cropping on validation. Validation is evaluated on
            # deterministic grids.
            enable_random_cropping=False,
            num_scales=config.data.multiscale_lowres_count,
            lowres_supervision=lowres_supervision,
            padding_kwargs=padding_kwargs,
            allow_generation=False,
            **kwargs_dict,
            **max_val_kwargs,
        )
    # 2) Vanilla data loader
    else:
        train_data_kwargs = {"allow_generation": True, **kwargs_dict}
        val_data_kwargs = {"allow_generation": False, **kwargs_dict}

        train_data_kwargs["enable_random_cropping"] = enable_random_cropping
        val_data_kwargs["enable_random_cropping"] = False

        train_data = (
            None
            if skip_train_dataset
            else MultiChDloader(
                data_config=config.data,
                fpath=datapath,
                datasplit_type=DataSplitType.Train,
                val_fraction=0.1,
                test_fraction=0.1,
                normalized_input=normalized_input,
                use_one_mu_std=use_one_mu_std,
                enable_rotation_aug=train_aug_rotate,
                **train_data_kwargs,
            )
        )
        # Without a training set there is no max value to forward.
        max_val_kwargs = (
            {} if train_data is None else {"max_val": train_data.get_max_val()}
        )

        val_data = MultiChDloader(
            data_config=config.data,
            fpath=datapath,
            datasplit_type=eval_datasplit_type,
            val_fraction=0.1,
            test_fraction=0.1,
            normalized_input=normalized_input,
            use_one_mu_std=use_one_mu_std,
            enable_rotation_aug=False,  # No rotation aug on validation
            **max_val_kwargs,
            **val_data_kwargs,
        )

    # For normalizing, we should be using the training data's mean and std.
    if train_data is not None:
        mean_val, std_val = train_data.compute_mean_std()
        train_data.set_mean_std(mean_val, std_val)
        val_data.set_mean_std(mean_val, std_val)

    return train_data, val_data
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def create_model_and_train(
    config: ml_collections.ConfigDict,
    data_mean: Dict[str, torch.Tensor],
    data_std: Dict[str, torch.Tensor],
    logger: WandbLogger,
    checkpoint_callback: ModelCheckpoint,
    train_loader: DataLoader,
    val_loader: DataLoader,
):
    """
    Instantiate the LadderVAE Lightning module and fit it on GPU.

    First, TensorBoard event files (``events*``) and checkpoints (``*.ckpt``)
    left in ``config.workdir`` are deleted. If
    ``config.training.pre_trained_ckpt_fpath`` is set, its ``"state_dict"`` is
    loaded non-strictly into the model, and any missing or unexpected keys are
    printed. Training uses early stopping on ``config.model.monitor`` (default
    ``"val_loss"``), the given checkpoint callback, and epoch-level
    learning-rate monitoring.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Full configuration.
    data_mean, data_std : Dict[str, torch.Tensor]
        Normalization statistics passed to the model.
    logger : WandbLogger
        Logger; the config is recorded in its experiment.
    checkpoint_callback : ModelCheckpoint
        Checkpointing callback.
    train_loader, val_loader : DataLoader
        Training and validation data.
    """
    # Remove stale tensorboard event files and checkpoints from a previous run.
    for pattern in ("events*", "*.ckpt"):
        for filename in glob.glob(os.path.join(config.workdir, pattern)):
            os.remove(filename)

    if "num_targets" in config.model:
        target_ch = config.model.num_targets
    else:
        target_ch = config.data.get("num_channels", 2)

    # Instantiate the model (lightning wrapper)
    model = LadderVAELight(
        data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch
    )

    # Load pre-trained weights if any
    if config.training.pre_trained_ckpt_fpath:
        print("Starting with pre-trained model", config.training.pre_trained_ckpt_fpath)
        # map_location="cpu" lets a checkpoint saved on GPU be loaded on any
        # machine; load_state_dict then copies the weights into the model.
        checkpoint = torch.load(
            config.training.pre_trained_ckpt_fpath, map_location="cpu"
        )
        incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
        # strict=False skips mismatching keys silently; make them visible.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print("Missing keys:", incompatible.missing_keys)
            print("Unexpected keys:", incompatible.unexpected_keys)

    estop_monitor = config.model.get("monitor", "val_loss")
    estop_mode = MetricMonitor(estop_monitor).mode()

    callbacks = [
        EarlyStopping(
            monitor=estop_monitor,
            min_delta=1e-6,
            patience=config.training.earlystop_patience,
            verbose=True,
            mode=estop_mode,
        ),
        checkpoint_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ]

    logger.experiment.config.update(config.to_dict())
    trainer = pl.Trainer(
        accelerator="gpu",
        max_epochs=config.training.max_epochs,
        gradient_clip_val=config.training.grad_clip_norm_value,
        gradient_clip_algorithm=config.training.gradient_clip_algorithm,
        logger=logger,
        callbacks=callbacks,
        precision=config.training.precision,
    )
    trainer.fit(model, train_loader, val_loader)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def train_network(
    train_loader: DataLoader,
    val_loader: DataLoader,
    data_mean: Dict[str, torch.Tensor],
    data_std: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
    model_name: str,
    logdir: str,
):
    """
    Set up checkpointing and Weights & Biases logging, then train the model.

    Checkpoints are written to ``config.workdir``. The best one, according to
    ``config.model.monitor`` (default ``"val_loss"``), is named
    ``<model_name>_best`` and the last one ``<model_name>_last``. The W&B run
    is named by joining ``config.hostname`` and ``config.exptname``, belongs
    to project ``"Disentanglement"``, and stores its files under `logdir`.
    """
    monitored_metric = config.model.get("monitor", "val_loss")
    checkpoint_callback = ModelCheckpoint(
        monitor=monitored_metric,
        dirpath=config.workdir,
        filename=model_name + "_best",
        save_last=True,
        save_top_k=1,
        mode=MetricMonitor(monitored_metric).mode(),
    )
    # Give the "last" checkpoint the model name as well.
    checkpoint_callback.CHECKPOINT_NAME_LAST = model_name + "_last"

    wandb_logger = WandbLogger(
        name=os.path.join(config.hostname, config.exptname),
        save_dir=logdir,
        project="Disentanglement",
    )

    create_model_and_train(
        config=config,
        data_mean=data_mean,
        data_std=data_std,
        logger=wandb_logger,
        checkpoint_callback=checkpoint_callback,
        train_loader=train_loader,
        val_loader=val_loader,
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _make_dataloader(dataset, config, shuffle):
    """Wrap `dataset` in a DataLoader using ``config.training`` batch size/workers."""
    return DataLoader(
        dataset,
        pin_memory=False,
        num_workers=config.training.num_workers,
        shuffle=shuffle,
        batch_size=config.training.batch_size,
    )


def main(argv):
    """
    Entry point called by `absl.app.run`.

    Resolves the run's work directory and stores run metadata in the config.
    In "train" mode it then freezes and pickles the config, builds the
    datasets and data loaders, and trains the model. "eval" mode currently
    does nothing.

    Parameters
    ----------
    argv : list of str
        Remaining command-line arguments (unused).

    Raises
    ------
    FileNotFoundError
        If ``--workdir`` does not exist.
    ValueError
        If ``--mode`` is not recognized.
    """
    config = FLAGS.config

    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if not os.path.exists(FLAGS.workdir):
        raise FileNotFoundError(f"Work directory {FLAGS.workdir} does not exist.")

    cur_workdir, relative_path = get_workdir(
        config, FLAGS.workdir, FLAGS.use_max_version
    )
    print(f"Saving training to {cur_workdir}")

    config.workdir = cur_workdir
    config.exptname = relative_path
    config.hostname = socket.gethostname()
    config.datadir = FLAGS.datadir
    config.training.pre_trained_ckpt_fpath = FLAGS.load_ckptfpath

    if FLAGS.mode == "train":
        set_logger(workdir=cur_workdir)
        raw_data_dict = None

        # From now on, config cannot be changed.
        config = ml_collections.FrozenConfigDict(config)
        log_config(config, cur_workdir)

        train_data, val_data = create_dataset(
            config, FLAGS.datadir, raw_data_dict=raw_data_dict
        )

        mean_dict, std_dict = get_mean_std_dict_for_model(config, train_data)

        train_network(
            train_loader=_make_dataloader(train_data, config, shuffle=True),
            val_loader=_make_dataloader(val_data, config, shuffle=False),
            data_mean=mean_dict,
            data_std=std_dict,
            config=config,
            model_name="BaselineVAECL",
            logdir=FLAGS.logdir,
        )

    elif FLAGS.mode == "eval":
        pass
    else:
        raise ValueError(f"Mode {FLAGS.mode} not recognized.")
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
if __name__ == "__main__":
    # absl parses the command-line flags, then calls main(argv).
    app.run(main)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script contains the utility functions for training the LVAE model.
|
|
3
|
+
These functions are mainly used in `train.py` script.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import pickle
|
|
9
|
+
import time
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import ml_collections
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def log_config(config: ml_collections.ConfigDict, cur_workdir: str) -> None:
    """
    Pickle `config` to ``<cur_workdir>/config.pkl`` and report where it went.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Configuration to save.
    cur_workdir : str
        Existing directory in which the file is written.
    """
    config_path = os.path.join(cur_workdir, "config.pkl")
    with open(config_path, "wb") as fh:
        pickle.dump(config, fh)
    print(f"Saved config to {cur_workdir}/config.pkl")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def set_logger(workdir: str) -> None:
    """
    Mirror root-logger records to ``<workdir>/stdout.txt`` and set level INFO.

    `workdir` is created if needed, and an existing ``stdout.txt`` is
    overwritten. Each call adds one more handler to the root logger.

    Parameters
    ----------
    workdir : str
        Directory in which the log file is written.
    """
    os.makedirs(workdir, exist_ok=True)
    # FileHandler owns its file object and closes it on handler.close() /
    # logging shutdown, unlike a bare open() passed to a StreamHandler.
    handler = logging.FileHandler(os.path.join(workdir, "stdout.txt"), mode="w")
    handler.setFormatter(
        logging.Formatter("%(levelname)s - %(filename)s - %(asctime)s - %(message)s")
    )
    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel("INFO")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_new_model_version(model_dir: str) -> str:
    """
    Return the next free integer version for a model directory.

    A model will have multiple runs, each stored in an integer-named entry of
    `model_dir` (``0``, ``1``, ...). The result is one more than the largest
    existing version, or ``"0"`` if there is none.

    Parameters
    ----------
    model_dir : str
        Directory holding the version subdirectories.

    Returns
    -------
    str
        The new version number.

    Raises
    ------
    SystemExit
        With status 1, after printing a message, if an entry of `model_dir`
        is not an integer.
    """
    versions = []
    for version_dir in os.listdir(model_dir):
        try:
            versions.append(int(version_dir))
        except ValueError:
            print(
                f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed"
            )
            # `exit()` only exists when the `site` module is loaded and reports
            # success (status 0); signal the failure explicitly instead.
            raise SystemExit(1) from None
    if not versions:
        return "0"
    return f"{max(versions) + 1}"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_model_name(config: ml_collections.ConfigDict) -> str:
    """
    Return the model name, used as a subdirectory of the work directory.

    The name is currently fixed; `config` is accepted but not used.
    """
    model_name = "LVAE_denoiSplit"
    return model_name
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_workdir(
    config: ml_collections.ConfigDict,
    root_dir: str,
    use_max_version: bool,
    nested_call: int = 0,
):
    """
    Create (or reuse) the work directory of a training run.

    The layout is ``<root_dir>/<yymm>/<model_name>/<version>``. Normally a new
    version directory is created. With `use_max_version`, the latest existing
    version (``0`` if none exists) is reused, and its directory may already
    exist.

    If creating a new version directory fails because it already exists
    (e.g. another process created it first), the function retries after 2.5 s.
    It gives up with ValueError once `nested_call` exceeds 10.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Configuration, used to derive the model name.
    root_dir : str
        Existing root directory.
    use_max_version : bool
        Reuse the latest version instead of creating a new one.
    nested_call : int, optional
        Retry counter used by the recursive retries.

    Returns
    -------
    tuple of str
        ``(cur_workdir, rel_path)``: the work directory (joined onto
        `root_dir`) and its path relative to `root_dir`.

    Raises
    ------
    ValueError
        If no new directory could be created after the retries.
    """
    rel_path = datetime.now().strftime("%y%m")
    cur_workdir = os.path.join(root_dir, rel_path)
    Path(cur_workdir).mkdir(exist_ok=True)

    rel_path = os.path.join(rel_path, get_model_name(config))
    cur_workdir = os.path.join(root_dir, rel_path)
    Path(cur_workdir).mkdir(exist_ok=True)

    if use_max_version:
        # Used for debugging: reuse the latest existing version.
        next_version = int(get_new_model_version(cur_workdir))
        version = str(max(next_version - 1, 0))
    else:
        version = get_new_model_version(cur_workdir)
    rel_path = os.path.join(rel_path, version)

    cur_workdir = os.path.join(root_dir, rel_path)
    try:
        # When reusing the max version, the directory normally exists already;
        # only a brand-new version must not collide with an existing one.
        Path(cur_workdir).mkdir(exist_ok=use_max_version)
    except FileExistsError:
        print(
            f"Workdir {cur_workdir} already exists. Probably because someother program also created the exact same directory. Trying to get a new version."
        )
        # Check the retry budget before waiting, so giving up is immediate.
        if nested_call > 10:
            raise ValueError(
                f"Cannot create a new directory. {cur_workdir} already exists."
            )
        time.sleep(2.5)
        return get_workdir(config, root_dir, use_max_version, nested_call + 1)

    return cur_workdir, rel_path
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_mean_std_dict_for_model(config, train_dset):
    """
    Return deep copies of the training set's mean and std, to pass to the model.

    Parameters
    ----------
    config :
        Unused.
    train_dset :
        Dataset whose ``get_mean_std()`` returns ``(mean_dict, std_dict)``.

    Returns
    -------
    tuple
        ``(mean_dict, std_dict)``, deep-copied so they share no state with
        the dataset.
    """
    mean_dict, std_dict = train_dset.get_mean_std()
    mean_copy = deepcopy(mean_dict)
    std_copy = deepcopy(std_dict)
    return mean_copy, std_copy
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class MetricMonitor:
    """
    Map a monitored validation metric to its optimisation direction.

    `mode()` gives the value expected by the ``mode`` argument of Lightning
    callbacks such as `EarlyStopping` / `ModelCheckpoint`: losses are
    minimised and PSNR is maximised.

    Parameters
    ----------
    metric : str
        Either ``"val_loss"`` or ``"val_psnr"``.

    Raises
    ------
    ValueError
        If `metric` is not supported.
    """

    # Optimisation direction of each supported metric.
    _MODES = {"val_loss": "min", "val_psnr": "max"}

    def __init__(self, metric):
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if metric not in self._MODES:
            raise ValueError(f"Invalid metric:{metric}")
        self.metric = metric

    def mode(self):
        """Return "min" or "max" for the monitored metric."""
        try:
            return self._MODES[self.metric]
        except KeyError:
            raise ValueError(f"Invalid metric:{self.metric}") from None
|