careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,214 @@
+"""
+This script contains the functions and classes used to compute the losses and metrics needed to train and evaluate the model.
+"""
+
+import numpy as np
+import torch
+from skimage.metrics import structural_similarity
+from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
+
+from careamics.models.lvae.utils import allow_numpy
+
+
+class RunningPSNR:
+    """
+    This class computes a running PSNR during the validation step of training.
+    In this way the PSNR over the entire validation set can be accumulated one batch at a time.
+    """
+
+    def __init__(self):
+        # number of elements seen so far during the epoch
+        self.N = None
+        # running sum of the MSE over the self.N elements seen so far
+        self.mse_sum = None
+        # running max and min values of the self.N target images seen so far
+        self.max = self.min = None
+        self.reset()
+
+    def reset(self):
+        """
+        Reset the running PSNR (usually called at the end of each epoch).
+        """
+        self.mse_sum = 0
+        self.N = 0
+        self.max = self.min = None
+
+    def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
+        """
+        Given a batch of reconstructed and target images, update the running MSE and min/max statistics.
+
+        Parameters
+        ----------
+        rec: torch.Tensor
+            Batch of reconstructed images (B, H, W).
+        tar: torch.Tensor
+            Batch of target images (B, H, W).
+        """
+        ins_max = torch.max(tar).item()
+        ins_min = torch.min(tar).item()
+        if self.max is None:
+            assert self.min is None
+            self.max = ins_max
+            self.min = ins_min
+        else:
+            self.max = max(self.max, ins_max)
+            self.min = min(self.min, ins_min)
+
+        mse = (rec - tar) ** 2
+        elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1)
+        self.mse_sum += torch.nansum(elementwise_mse)
+        self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))
+
+    def get(self):
+        """
+        Return the PSNR value computed from the running statistics.
+        """
+        if self.N == 0 or self.N is None:
+            return None
+        rmse = torch.sqrt(self.mse_sum / self.N)
+        return 20 * torch.log10((self.max - self.min) / rmse)
+
+
+def zero_mean(x):
+    return x - torch.mean(x, dim=1, keepdim=True)
+
+
+def fix_range(gt, x):
+    a = torch.sum(gt * x, dim=1, keepdim=True) / (torch.sum(x * x, dim=1, keepdim=True))
+    return x * a
+
+
+def fix(gt, x):
+    gt_ = zero_mean(gt)
+    return fix_range(gt_, zero_mean(x))
+
+
+def _PSNR_internal(gt, pred, range_=None):
+    if range_ is None:
+        range_ = torch.max(gt, dim=1).values - torch.min(gt, dim=1).values
+
+    mse = torch.mean((gt - pred) ** 2, dim=1)
+    return 20 * torch.log10(range_ / torch.sqrt(mse))
+
+
+@allow_numpy
+def PSNR(gt, pred, range_=None):
+    """
+    Compute PSNR.
+
+    Parameters
+    ----------
+    gt: array
+        Ground truth image.
+    pred: array
+        Predicted image.
+    """
+    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"
+
+    gt = gt.view(len(gt), -1)
+    pred = pred.view(len(gt), -1)
+    return _PSNR_internal(gt, pred, range_=range_)
+
+
+@allow_numpy
+def RangeInvariantPsnr(gt: torch.Tensor, pred: torch.Tensor):
+    """
+    NOTE: Works only for grayscale images.
+    Adapted from https://github.com/juglab/ScaleInvPSNR/blob/master/psnr.py
+    It rescales the prediction to ensure that the prediction has the same range as the ground truth.
+    """
+    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"
+    gt = gt.view(len(gt), -1)
+    pred = pred.view(len(gt), -1)
+    ra = (torch.max(gt, dim=1).values - torch.min(gt, dim=1).values) / torch.std(
+        gt, dim=1
+    )
+    gt_ = zero_mean(gt) / torch.std(gt, dim=1, keepdim=True)
+    return _PSNR_internal(zero_mean(gt_), fix(gt_, pred), ra)
+
+
+def _avg_psnr(target, prediction, psnr_fn):
+    output = np.mean(
+        [
+            psnr_fn(target[i : i + 1], prediction[i : i + 1]).item()
+            for i in range(len(prediction))
+        ]
+    )
+    return round(output, 2)
+
+
+def avg_range_inv_psnr(target, prediction):
+    return _avg_psnr(target, prediction, RangeInvariantPsnr)
+
+
+def avg_psnr(target, prediction):
+    return _avg_psnr(target, prediction, PSNR)
+
+
+def compute_masked_psnr(mask, tar1, tar2, pred1, pred2):
+    mask = mask.astype(bool)
+    mask = mask[..., 0]
+    tmp_tar1 = tar1[mask].reshape((len(tar1), -1, 1))
+    tmp_pred1 = pred1[mask].reshape((len(tar1), -1, 1))
+    tmp_tar2 = tar2[mask].reshape((len(tar2), -1, 1))
+    tmp_pred2 = pred2[mask].reshape((len(tar2), -1, 1))
+    psnr1 = avg_range_inv_psnr(tmp_tar1, tmp_pred1)
+    psnr2 = avg_range_inv_psnr(tmp_tar2, tmp_pred2)
+    return psnr1, psnr2
+
+
+def avg_ssim(target, prediction):
+    ssim = [
+        structural_similarity(
+            target[i], prediction[i], data_range=(target[i].max() - target[i].min())
+        )
+        for i in range(len(target))
+    ]
+    return np.mean(ssim), np.std(ssim)
+
+
+@allow_numpy
+def range_invariant_multiscale_ssim(gt_, pred_):
+    """
+    Computes range-invariant multiscale SSIM for one channel.
+    This has the benefit that it is invariant to scalar multiplications in the prediction.
+    """
+    shape = gt_.shape
+    gt_ = torch.Tensor(gt_.reshape((shape[0], -1)))
+    pred_ = torch.Tensor(pred_.reshape((shape[0], -1)))
+    gt_ = zero_mean(gt_)
+    pred_ = zero_mean(pred_)
+    pred_ = fix(gt_, pred_)
+    pred_ = pred_.reshape(shape)
+    gt_ = gt_.reshape(shape)
+
+    ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(
+        data_range=gt_.max() - gt_.min()
+    )
+    return ms_ssim(torch.Tensor(pred_[:, None]), torch.Tensor(gt_[:, None])).item()
+
+
+def compute_multiscale_ssim(gt_, pred_, range_invariant=True):
+    """
+    Computes multiscale SSIM for each channel.
+    Args:
+        gt_: ground truth image with shape (N, H, W, C)
+        pred_: predicted image with shape (N, H, W, C)
+        range_invariant: whether to use range-invariant multiscale SSIM
+    """
+    ms_ssim_values = {i: None for i in range(gt_.shape[-1])}
+    for ch_idx in range(gt_.shape[-1]):
+        tar_tmp = gt_[..., ch_idx]
+        pred_tmp = pred_[..., ch_idx]
+        if range_invariant:
+            ms_ssim_values[ch_idx] = range_invariant_multiscale_ssim(tar_tmp, pred_tmp)
+        else:
+            ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(
+                data_range=tar_tmp.max() - tar_tmp.min()
+            )
+            ms_ssim_values[ch_idx] = ms_ssim(
+                torch.Tensor(pred_tmp[:, None]), torch.Tensor(tar_tmp[:, None])
+            ).item()
+
+    output = [ms_ssim_values[i] for i in range(gt_.shape[-1])]
+    return output
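
The new metrics module above is self-contained. As a rough usage sketch (not part of the release; random tensors stand in for real data in the (B, H, W) layout the assertions expect):

    import torch
    from careamics.lvae_training.metrics import PSNR, RangeInvariantPsnr, RunningPSNR

    target = torch.rand(4, 64, 64)                    # batch of target images
    pred = target + 0.05 * torch.randn_like(target)   # noisy "reconstruction"

    print(PSNR(target, pred))                # per-image PSNR, shape (4,)
    print(RangeInvariantPsnr(target, pred))  # PSNR after rescaling the prediction

    running = RunningPSNR()
    running.update(rec=pred, tar=target)     # called once per validation batch
    print(running.get())                     # PSNR over everything seen since reset()
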
@@ -0,0 +1,339 @@
+"""
+This script is meant to load data, initialize the model, and provide the logic for training it.
+"""
+
+import glob
+import os
+import socket
+import sys
+from typing import Dict
+
+import pytorch_lightning as pl
+import torch
+from absl import app, flags
+from ml_collections.config_flags import config_flags
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.loggers import WandbLogger
+from torch.utils.data import DataLoader
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+print(sys.path)
+
+from careamics.lvae_training.data_modules import LCMultiChDloader, MultiChDloader
+from careamics.lvae_training.data_utils import DataSplitType
+from careamics.lvae_training.lightning_module import LadderVAELight
+from careamics.lvae_training.train_utils import *
+
+FLAGS = flags.FLAGS
+
+config_flags.DEFINE_config_file(
+    "config", None, "Training configuration.", lock_config=False
+)
+flags.DEFINE_string("workdir", None, "Work directory.")
+flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
+flags.DEFINE_string(
+    "logdir", "/group/jug/federico/wandb_backup/", "The folder name for storing logging"
+)
+flags.DEFINE_string(
+    "datadir", "/group/jug/federico/careamics_training/data/BioSR", "Data directory."
+)
+flags.DEFINE_boolean("use_max_version", False, "Overwrite the max version of the model")
+flags.DEFINE_string(
+    "load_ckptfpath",
+    "",
+    "The path to a previous ckpt from which the weights should be loaded",
+)
+flags.mark_flags_as_required(["workdir", "config", "mode"])
+
+
+def create_dataset(
+    config,
+    datadir,
+    eval_datasplit_type=DataSplitType.Val,
+    raw_data_dict=None,
+    skip_train_dataset=False,
+    kwargs_dict=None,
+):
+
+    if kwargs_dict is None:
+        kwargs_dict = {}
+
+    datapath = datadir
+
+    # Hard-coded parameters (used to be in the config file)
+    normalized_input = True
+    use_one_mu_std = True
+    train_aug_rotate = False
+    enable_random_cropping = True
+    lowres_supervision = False
+
+    # 1) Data loader for Lateral Contextualization
+    if (
+        "multiscale_lowres_count" in config.data
+        and config.data.multiscale_lowres_count is not None
+    ):
+        # Get padding attributes
+        if "padding_kwargs" not in kwargs_dict:
+            padding_kwargs = {}
+            if "padding_mode" in config.data and config.data.padding_mode is not None:
+                padding_kwargs["mode"] = config.data.padding_mode
+            else:
+                padding_kwargs["mode"] = "reflect"
+            if "padding_value" in config.data and config.data.padding_value is not None:
+                padding_kwargs["constant_values"] = config.data.padding_value
+            else:
+                padding_kwargs["constant_values"] = None
+        else:
+            padding_kwargs = kwargs_dict.pop("padding_kwargs")
+
+        train_data = (
+            None
+            if skip_train_dataset
+            else LCMultiChDloader(
+                config.data,
+                datapath,
+                datasplit_type=DataSplitType.Train,
+                val_fraction=0.1,
+                test_fraction=0.1,
+                normalized_input=normalized_input,
+                use_one_mu_std=use_one_mu_std,
+                enable_rotation_aug=train_aug_rotate,
+                enable_random_cropping=enable_random_cropping,
+                num_scales=config.data.multiscale_lowres_count,
+                lowres_supervision=lowres_supervision,
+                padding_kwargs=padding_kwargs,
+                **kwargs_dict,
+                allow_generation=True,
+            )
+        )
+        max_val = train_data.get_max_val()
+
+        val_data = LCMultiChDloader(
+            config.data,
+            datapath,
+            datasplit_type=eval_datasplit_type,
+            val_fraction=0.1,
+            test_fraction=0.1,
+            normalized_input=normalized_input,
+            use_one_mu_std=use_one_mu_std,
+            enable_rotation_aug=False,  # No rotation aug on validation
+            enable_random_cropping=False,
+            # No random cropping on validation. Validation is evaluated on deterministic grids
+            num_scales=config.data.multiscale_lowres_count,
+            lowres_supervision=lowres_supervision,
+            padding_kwargs=padding_kwargs,
+            allow_generation=False,
+            **kwargs_dict,
+            max_val=max_val,
+        )
+    # 2) Vanilla data loader
+    else:
+        train_data_kwargs = {"allow_generation": True, **kwargs_dict}
+        val_data_kwargs = {"allow_generation": False, **kwargs_dict}
+
+        train_data_kwargs["enable_random_cropping"] = enable_random_cropping
+        val_data_kwargs["enable_random_cropping"] = False
+
+        train_data = (
+            None
+            if skip_train_dataset
+            else MultiChDloader(
+                data_config=config.data,
+                fpath=datapath,
+                datasplit_type=DataSplitType.Train,
+                val_fraction=0.1,
+                test_fraction=0.1,
+                normalized_input=normalized_input,
+                use_one_mu_std=use_one_mu_std,
+                enable_rotation_aug=train_aug_rotate,
+                **train_data_kwargs,
+            )
+        )
+
+        max_val = train_data.get_max_val()
+        val_data = MultiChDloader(
+            data_config=config.data,
+            fpath=datapath,
+            datasplit_type=eval_datasplit_type,
+            val_fraction=0.1,
+            test_fraction=0.1,
+            normalized_input=normalized_input,
+            use_one_mu_std=use_one_mu_std,
+            enable_rotation_aug=False,  # No rotation aug on validation
+            max_val=max_val,
+            **val_data_kwargs,
+        )
+
+    # For normalizing, we should be using the training data's mean and std.
+    mean_val, std_val = train_data.compute_mean_std()
+    train_data.set_mean_std(mean_val, std_val)
+    val_data.set_mean_std(mean_val, std_val)
+
+    return train_data, val_data
+
+
+def create_model_and_train(
+    config: ml_collections.ConfigDict,
+    data_mean: Dict[str, torch.Tensor],
+    data_std: Dict[str, torch.Tensor],
+    logger: WandbLogger,
+    checkpoint_callback: ModelCheckpoint,
+    train_loader: DataLoader,
+    val_loader: DataLoader,
+):
+    # Remove previous tensorboard event files.
+    for filename in glob.glob(config.workdir + "/events*"):
+        os.remove(filename)
+
+    # Remove previous checkpoints.
+    for filename in glob.glob(config.workdir + "/*.ckpt"):
+        os.remove(filename)
+
+    if "num_targets" in config.model:
+        target_ch = config.model.num_targets
+    else:
+        target_ch = config.data.get("num_channels", 2)
+
+    # Instantiate the model (lightning wrapper)
+    model = LadderVAELight(
+        data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch
+    )
+
+    # Load pre-trained weights if any
+    if config.training.pre_trained_ckpt_fpath:
+        print("Starting with pre-trained model", config.training.pre_trained_ckpt_fpath)
+        checkpoint = torch.load(config.training.pre_trained_ckpt_fpath)
+        _ = model.load_state_dict(checkpoint["state_dict"], strict=False)
+
+    estop_monitor = config.model.get("monitor", "val_loss")
+    estop_mode = MetricMonitor(estop_monitor).mode()
+
+    callbacks = [
+        EarlyStopping(
+            monitor=estop_monitor,
+            min_delta=1e-6,
+            patience=config.training.earlystop_patience,
+            verbose=True,
+            mode=estop_mode,
+        ),
+        checkpoint_callback,
+        LearningRateMonitor(logging_interval="epoch"),
+    ]
+
+    logger.experiment.config.update(config.to_dict())
+    # wandb.init(config=config)
+    trainer = pl.Trainer(
+        accelerator="gpu",
+        max_epochs=config.training.max_epochs,
+        gradient_clip_val=config.training.grad_clip_norm_value,
+        gradient_clip_algorithm=config.training.gradient_clip_algorithm,
+        logger=logger,
+        callbacks=callbacks,
+        # limit_train_batches = config.training.limit_train_batches,
+        precision=config.training.precision,
+    )
+    trainer.fit(model, train_loader, val_loader)
+
+
+def train_network(
+    train_loader: DataLoader,
+    val_loader: DataLoader,
+    data_mean: Dict[str, torch.Tensor],
+    data_std: Dict[str, torch.Tensor],
+    config: ml_collections.ConfigDict,
+    model_name: str,
+    logdir: str,
+):
+    ckpt_monitor = config.model.get("monitor", "val_loss")
+    ckpt_mode = MetricMonitor(ckpt_monitor).mode()
+    checkpoint_callback = ModelCheckpoint(
+        monitor=ckpt_monitor,
+        dirpath=config.workdir,
+        filename=model_name + "_best",
+        save_last=True,
+        save_top_k=1,
+        mode=ckpt_mode,
+    )
+    checkpoint_callback.CHECKPOINT_NAME_LAST = model_name + "_last"
+    logger = WandbLogger(
+        name=os.path.join(config.hostname, config.exptname),
+        save_dir=logdir,
+        project="Disentanglement",
+    )
+
+    create_model_and_train(
+        config=config,
+        data_mean=data_mean,
+        data_std=data_std,
+        logger=logger,
+        checkpoint_callback=checkpoint_callback,
+        train_loader=train_loader,
+        val_loader=val_loader,
+    )
+
+
+def main(argv):
+    config = FLAGS.config
+
+    assert os.path.exists(FLAGS.workdir)
+    cur_workdir, relative_path = get_workdir(
+        config, FLAGS.workdir, FLAGS.use_max_version
+    )
+    print(f"Saving training to {cur_workdir}")
+
+    config.workdir = cur_workdir
+    config.exptname = relative_path
+    config.hostname = socket.gethostname()
+    config.datadir = FLAGS.datadir
+    config.training.pre_trained_ckpt_fpath = FLAGS.load_ckptfpath
+
+    if FLAGS.mode == "train":
+        set_logger(workdir=cur_workdir)
+        raw_data_dict = None
+
+        # From now on, config cannot be changed.
+        config = ml_collections.FrozenConfigDict(config)
+        log_config(config, cur_workdir)
+
+        train_data, val_data = create_dataset(
+            config, FLAGS.datadir, raw_data_dict=raw_data_dict
+        )
+
+        mean_dict, std_dict = get_mean_std_dict_for_model(config, train_data)
+
+        batch_size = config.training.batch_size
+        shuffle = True
+        train_dloader = DataLoader(
+            train_data,
+            pin_memory=False,
+            num_workers=config.training.num_workers,
+            shuffle=shuffle,
+            batch_size=batch_size,
+        )
+        val_dloader = DataLoader(
+            val_data,
+            pin_memory=False,
+            num_workers=config.training.num_workers,
+            shuffle=False,
+            batch_size=batch_size,
+        )
+
+        train_network(
+            train_loader=train_dloader,
+            val_loader=val_dloader,
+            data_mean=mean_dict,
+            data_std=std_dict,
+            config=config,
+            model_name="BaselineVAECL",
+            logdir=FLAGS.logdir,
+        )
+
+    elif FLAGS.mode == "eval":
+        pass
+    else:
+        raise ValueError(f"Mode {FLAGS.mode} not recognized.")
+
+
+if __name__ == "__main__":
+    app.run(main)
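
The script above reads a handful of fields from the ml_collections config passed via --config; the real defaults presumably live elsewhere (possibly in the new careamics/lvae_training/get_config.py, not shown in this hunk). A hypothetical minimal config sketch, limited to the fields this script itself accesses and with placeholder values, might look like:

    # Hypothetical config-file sketch; field names are inferred from how train_lvae.py
    # reads `config`, and the values are illustrative placeholders only.
    import ml_collections


    def get_config() -> ml_collections.ConfigDict:
        config = ml_collections.ConfigDict()

        # Read by create_dataset(); multiscale_lowres_count, if set, selects the
        # Lateral Contextualization loader.
        config.data = ml_collections.ConfigDict()

        # "monitor" picks the metric used for early stopping and checkpointing.
        config.model = ml_collections.ConfigDict()
        config.model.monitor = "val_loss"

        # Read by create_model_and_train() and main().
        config.training = ml_collections.ConfigDict()
        config.training.batch_size = 32
        config.training.num_workers = 4
        config.training.max_epochs = 100
        config.training.earlystop_patience = 100
        config.training.grad_clip_norm_value = 0.5
        config.training.gradient_clip_algorithm = "norm"
        config.training.precision = 32
        return config

It would then be launched with something like `python train_lvae.py --config=path/to/config.py --workdir=... --datadir=... --mode=train`; --config, --workdir and --mode are required, and the work directory must already exist.
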
@@ -0,0 +1,121 @@
+"""
+This script contains the utility functions for training the LVAE model.
+These functions are mainly used in the `train.py` script.
+"""
+
+import logging
+import os
+import pickle
+import time
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+
+import ml_collections
+
+
+def log_config(config: ml_collections.ConfigDict, cur_workdir: str) -> None:
+    # Saving config file.
+    with open(os.path.join(cur_workdir, "config.pkl"), "wb") as f:
+        pickle.dump(config, f)
+    print(f"Saved config to {cur_workdir}/config.pkl")
+
+
+def set_logger(workdir: str) -> None:
+    os.makedirs(workdir, exist_ok=True)
+    fstream = open(os.path.join(workdir, "stdout.txt"), "w")
+    handler = logging.StreamHandler(fstream)
+    formatter = logging.Formatter(
+        "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
+    )
+    handler.setFormatter(formatter)
+    logger = logging.getLogger()
+    logger.addHandler(handler)
+    logger.setLevel("INFO")
+
+
+def get_new_model_version(model_dir: str) -> str:
+    """
+    A model will have multiple runs. Each run will have a different version.
+    """
+    versions = []
+    for version_dir in os.listdir(model_dir):
+        try:
+            versions.append(int(version_dir))
+        except:
+            print(
+                f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed"
+            )
+            exit()
+    if len(versions) == 0:
+        return "0"
+    return f"{max(versions) + 1}"
+
+
+def get_model_name(config: ml_collections.ConfigDict) -> str:
+    return "LVAE_denoiSplit"
+
+
+def get_workdir(
+    config: ml_collections.ConfigDict,
+    root_dir: str,
+    use_max_version: bool,
+    nested_call: int = 0,
+):
+    rel_path = datetime.now().strftime("%y%m")
+    cur_workdir = os.path.join(root_dir, rel_path)
+    Path(cur_workdir).mkdir(exist_ok=True)
+
+    rel_path = os.path.join(rel_path, get_model_name(config))
+    cur_workdir = os.path.join(root_dir, rel_path)
+    Path(cur_workdir).mkdir(exist_ok=True)
+
+    if use_max_version:
+        # Used for debugging.
+        version = int(get_new_model_version(cur_workdir))
+        if version > 0:
+            version = f"{version - 1}"
+
+        rel_path = os.path.join(rel_path, str(version))
+    else:
+        rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir))
+
+    cur_workdir = os.path.join(root_dir, rel_path)
+    try:
+        Path(cur_workdir).mkdir(exist_ok=False)
+    except FileExistsError:
+        print(
+            f"Workdir {cur_workdir} already exists. Probably because some other program also created the exact same directory. Trying to get a new version."
+        )
+        time.sleep(2.5)
+        if nested_call > 10:
+            raise ValueError(
+                f"Cannot create a new directory. {cur_workdir} already exists."
+            )
+
+        return get_workdir(config, root_dir, use_max_version, nested_call + 1)
+
+    return cur_workdir, rel_path
+
+
+def get_mean_std_dict_for_model(config, train_dset):
+    """
+    Computes the mean and std for the model. This will be subsequently passed to the model.
+    """
+    mean_dict, std_dict = train_dset.get_mean_std()
+
+    return deepcopy(mean_dict), deepcopy(std_dict)
+
+
+class MetricMonitor:
+    def __init__(self, metric):
+        assert metric in ["val_loss", "val_psnr"]
+        self.metric = metric
+
+    def mode(self):
+        if self.metric == "val_loss":
+            return "min"
+        elif self.metric == "val_psnr":
+            return "max"
+        else:
+            raise ValueError(f"Invalid metric:{self.metric}")
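
A minimal sketch of how these helpers compose, assuming a writable placeholder directory (/tmp/lvae_runs is not part of the package and stands in for the --workdir flag):

    from pathlib import Path

    import ml_collections

    from careamics.lvae_training.train_utils import MetricMonitor, get_workdir

    Path("/tmp/lvae_runs").mkdir(exist_ok=True)
    config = ml_collections.ConfigDict()

    # Creates <root>/<yymm>/LVAE_denoiSplit/<version> and returns the full and relative paths.
    cur_workdir, rel_path = get_workdir(config, "/tmp/lvae_runs", use_max_version=False)

    # Maps the monitored metric to the mode expected by the Lightning callbacks.
    assert MetricMonitor("val_loss").mode() == "min"
    assert MetricMonitor("val_psnr").mode() == "max"
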