careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,214 @@
1
+ """
2
+ This script contains the functions/classes to compute loss and metrics used to train and evaluate the performance of the model.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ from skimage.metrics import structural_similarity
8
+ from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
9
+
10
+ from careamics.models.lvae.utils import allow_numpy
11
+
12
+
13
class RunningPSNR:
    """
    Running PSNR accumulator used during the validation step in training.

    Per-sample MSE values and the global min/max of all targets are accumulated
    batch by batch. This makes it possible to compute the PSNR on the entire
    validation set one batch at a time.
    """

    def __init__(self):
        # number of (non-NaN) samples seen so far during the epoch
        self.N = None
        # running sum of the per-sample MSE over the self.N samples seen so far
        self.mse_sum = None
        # running max and min values of all target images seen so far
        self.max = self.min = None
        self.reset()

    def reset(self):
        """
        Reset the running statistics (usually called at the end of each epoch).
        """
        self.mse_sum = 0
        self.N = 0
        self.max = self.min = None

    def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
        """
        Update the running MSE sum, sample count and target min/max with a batch.

        Samples whose MSE is NaN are excluded from both the MSE sum and the count.

        Parameters
        ----------
        rec : torch.Tensor
            Batch of reconstructed images (B, H, W).
        tar : torch.Tensor
            Batch of target images (B, H, W).
        """
        # Statistics are accumulated outside autograd, so no graph is kept
        # alive across batches if the inputs happen to require grad.
        rec = rec.detach()
        tar = tar.detach()

        ins_max = torch.max(tar).item()
        ins_min = torch.min(tar).item()
        if self.max is None:
            assert self.min is None
            self.max = ins_max
            self.min = ins_min
        else:
            self.max = max(self.max, ins_max)
            self.min = min(self.min, ins_min)

        mse = (rec - tar) ** 2
        # reshape (not view): also works when the inputs are non-contiguous.
        elementwise_mse = torch.mean(mse.reshape(len(mse), -1), dim=1)
        self.mse_sum += torch.nansum(elementwise_mse)
        self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))

    def get(self):
        """
        Return the PSNR computed from the running statistics.

        Returns
        -------
        torch.Tensor or None
            ``20 * log10((max - min) / sqrt(mse_sum / N))``, or None when no
            (non-NaN) sample has been seen since the last reset.
        """
        if self.N is None or self.N == 0:
            return None
        rmse = torch.sqrt(self.mse_sum / self.N)
        return 20 * torch.log10((self.max - self.min) / rmse)
70
+
71
+
72
def zero_mean(x):
    """Return ``x`` with its mean along dim 1 (kept as a broadcastable axis) subtracted."""
    row_mean = x.mean(dim=1, keepdim=True)
    centered = x - row_mean
    return centered
74
+
75
+
76
def fix_range(gt, x):
    """
    Rescale each row of ``x`` by the least-squares factor that best matches ``gt``.

    For every row the factor is ``sum(gt * x) / sum(x * x)`` (reduced over dim 1).
    """
    numerator = (gt * x).sum(dim=1, keepdim=True)
    denominator = (x * x).sum(dim=1, keepdim=True)
    scale = numerator / denominator
    return x * scale
79
+
80
+
81
def fix(gt, x):
    """
    Center ``gt`` and ``x`` row-wise, then rescale the centered ``x`` to best fit the centered ``gt``.
    """
    centered_gt = zero_mean(gt)
    centered_x = zero_mean(x)
    return fix_range(centered_gt, centered_x)
84
+
85
+
86
+ def _PSNR_internal(gt, pred, range_=None):
87
+ if range_ is None:
88
+ range_ = torch.max(gt, dim=1).values - torch.min(gt, dim=1).values
89
+
90
+ mse = torch.mean((gt - pred) ** 2, dim=1)
91
+ return 20 * torch.log10(range_ / torch.sqrt(mse))
92
+
93
+
94
@allow_numpy
def PSNR(gt, pred, range_=None):
    """
    Compute the per-image PSNR.

    Parameters
    ----------
    gt : array
        Ground truth images of shape (batch, H, W).
    pred : array
        Predicted images with the same number of elements per image as ``gt``.
    range_ : optional
        Data range used in the PSNR numerator. If None, the per-image
        ``max - min`` of ``gt`` is used.

    Returns
    -------
    array
        One PSNR value per image. Any type conversion of the result is done by
        the ``allow_numpy`` decorator (not visible here).
    """
    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"

    # reshape (not view): also works for non-contiguous inputs such as strided
    # slices, and has the intended meaning for both torch tensors and numpy arrays.
    gt = gt.reshape(len(gt), -1)
    pred = pred.reshape(len(gt), -1)
    return _PSNR_internal(gt, pred, range_=range_)
111
+
112
+
113
@allow_numpy
def RangeInvariantPsnr(gt: torch.Tensor, pred: torch.Tensor):
    """
    Range-invariant PSNR, one value per image.

    NOTE: Works only for grayscale images.
    Adapted from https://github.com/juglab/ScaleInvPSNR/blob/master/psnr.py

    The ground truth is standardized per image (zero mean, unit std). The
    prediction is centered and rescaled by the least-squares factor that best
    matches the standardized ground truth. As a result, the score does not
    depend on an offset or a scale factor applied to the prediction.

    Parameters
    ----------
    gt : torch.Tensor
        Ground truth images of shape (batch, H, W).
    pred : torch.Tensor
        Predicted images with the same number of elements per image as ``gt``.
    """
    assert len(gt.shape) == 3, "Images must be in shape: (batch,H,W)"
    # reshape (not view): also works for non-contiguous inputs.
    gt = gt.reshape(len(gt), -1)
    pred = pred.reshape(len(gt), -1)

    gt_std = torch.std(gt, dim=1, keepdim=True)
    # Dynamic range of the standardized ground truth, one value per image.
    ra = (torch.max(gt, dim=1).values - torch.min(gt, dim=1).values) / gt_std[:, 0]
    gt_ = zero_mean(gt) / gt_std
    return _PSNR_internal(zero_mean(gt_), fix(gt_, pred), ra)
128
+
129
+
130
+ def _avg_psnr(target, prediction, psnr_fn):
131
+ output = np.mean(
132
+ [
133
+ psnr_fn(target[i : i + 1], prediction[i : i + 1]).item()
134
+ for i in range(len(prediction))
135
+ ]
136
+ )
137
+ return round(output, 2)
138
+
139
+
140
def avg_range_inv_psnr(target, prediction):
    """Mean range-invariant PSNR over the samples, rounded to 2 decimals."""
    return _avg_psnr(target, prediction, psnr_fn=RangeInvariantPsnr)
142
+
143
+
144
def avg_psnr(target, prediction):
    """Mean PSNR over the samples, rounded to 2 decimals."""
    return _avg_psnr(target, prediction, psnr_fn=PSNR)
146
+
147
+
148
def compute_masked_psnr(mask, tar1, tar2, pred1, pred2):
    """
    Range-invariant PSNR of two target/prediction pairs restricted to a mask.

    ``mask`` is cast to bool and only its last-axis index 0 is used. The masked
    pixels of each array are regrouped per sample as ``(len(tar), -1, 1)``. This
    assumes every sample has the same number of masked pixels.

    Returns
    -------
    tuple
        ``(psnr1, psnr2)``, each the rounded average from ``avg_range_inv_psnr``.
    """
    bool_mask = mask.astype(bool)[..., 0]
    n_first = len(tar1)
    n_second = len(tar2)
    masked_tar1 = tar1[bool_mask].reshape((n_first, -1, 1))
    masked_pred1 = pred1[bool_mask].reshape((n_first, -1, 1))
    masked_tar2 = tar2[bool_mask].reshape((n_second, -1, 1))
    masked_pred2 = pred2[bool_mask].reshape((n_second, -1, 1))
    return (
        avg_range_inv_psnr(masked_tar1, masked_pred1),
        avg_range_inv_psnr(masked_tar2, masked_pred2),
    )
158
+
159
+
160
def avg_ssim(target, prediction):
    """
    Per-image SSIM (scikit-image ``structural_similarity``) between target and prediction.

    The ``data_range`` of each image is the ``max - min`` of its target.

    Returns
    -------
    tuple
        Mean and standard deviation of the per-image SSIM values.
    """
    scores = []
    for idx, tar_img in enumerate(target):
        pred_img = prediction[idx]
        value_range = tar_img.max() - tar_img.min()
        scores.append(structural_similarity(tar_img, pred_img, data_range=value_range))
    return np.mean(scores), np.std(scores)
168
+
169
+
170
@allow_numpy
def range_invariant_multiscale_ssim(gt_, pred_):
    """
    Compute range-invariant multiscale SSIM for one channel.

    Both inputs are flattened per sample and centered. The prediction is then
    rescaled by the least-squares factor that best matches the ground truth
    (see ``fix``). This makes the score invariant to scalar multiplications of
    the prediction. MS-SSIM is computed with ``data_range`` set to the
    ``max - min`` of the centered ground truth.

    Parameters
    ----------
    gt_ : array
        Ground truth images; axis 0 is the sample axis. Presumably (N, H, W),
        since a channel axis is inserted at position 1 below. TODO confirm.
    pred_ : array
        Predicted images with the same shape as ``gt_``.

    Returns
    -------
    float
        The scalar returned by torchmetrics' MS-SSIM for the whole batch.
    """
    shape = gt_.shape
    # Flatten each sample to one row so the row-wise helpers act per image.
    gt_ = torch.Tensor(gt_.reshape((shape[0], -1)))
    pred_ = torch.Tensor(pred_.reshape((shape[0], -1)))
    gt_ = zero_mean(gt_)
    pred_ = zero_mean(pred_)
    pred_ = fix(gt_, pred_)
    # Restore the original per-sample image shape.
    pred_ = pred_.reshape(shape)
    gt_ = gt_.reshape(shape)

    ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=gt_.max() - gt_.min()
    )
    # [:, None] inserts a singleton channel axis at position 1.
    return ms_ssim(torch.Tensor(pred_[:, None]), torch.Tensor(gt_[:, None])).item()
189
+
190
+
191
def compute_multiscale_ssim(gt_, pred_, range_invariant=True):
    """
    Computes multiscale ssim for each channel.

    Args:
        gt_: ground truth image with shape (N, H, W, C)
        pred_: predicted image with shape (N, H, W, C)
        range_invariant: whether to use range invariant multiscale ssim

    Returns:
        A list with one MS-SSIM value per channel, in channel order.
    """
    scores = []
    for ch in range(gt_.shape[-1]):
        gt_ch = gt_[..., ch]
        pred_ch = pred_[..., ch]
        if not range_invariant:
            metric = MultiScaleStructuralSimilarityIndexMeasure(
                data_range=gt_ch.max() - gt_ch.min()
            )
            score = metric(
                torch.Tensor(pred_ch[:, None]), torch.Tensor(gt_ch[:, None])
            ).item()
        else:
            score = range_invariant_multiscale_ssim(gt_ch, pred_ch)
        scores.append(score)
    return scores
@@ -0,0 +1,342 @@
1
+ """
2
+ This script is meant to load data, initialize the model, and provide the logic for training it.
3
+ """
4
+
5
+ import glob
6
+ import os
7
+ import socket
8
+ import sys
9
+ from typing import Dict
10
+
11
+ import pytorch_lightning as pl
12
+ import torch
13
+ from absl import app, flags
14
+ from ml_collections.config_flags import config_flags
15
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
16
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
17
+ from pytorch_lightning.loggers import WandbLogger
18
+ from torch.utils.data import DataLoader
19
+
20
# Make the directory two levels above this file (the one containing the
# `careamics` package) importable when this file is run directly as a script.
# NOTE(review): `careamics.lvae_training.dataset.data_modules`, imported below,
# does not appear in this package's file list; confirm the dataloaders' module.
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
22
+
23
+ from careamics.lvae_training.dataset.data_modules import (
24
+ LCMultiChDloader,
25
+ MultiChDloader,
26
+ )
27
+ from careamics.lvae_training.dataset.data_utils import DataSplitType
28
+ from careamics.lvae_training.lightning_module import LadderVAELight
29
+ from careamics.lvae_training.train_utils import *
30
+
31
# Command-line flags (absl); their values are read through `FLAGS.<name>` in `main`.
FLAGS = flags.FLAGS

# ml_collections config file, available as FLAGS.config (mutable: lock_config=False).
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False
)
# Root directory under which the versioned run directory is created (see get_workdir).
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
# NOTE(review): the two defaults below are user-specific cluster paths; override
# them when running elsewhere.
flags.DEFINE_string(
    "logdir", "/group/jug/federico/wandb_backup/", "The folder name for storing logging"
)
flags.DEFINE_string(
    "datadir", "/group/jug/federico/careamics_training/data/BioSR", "Data directory."
)
# Passed to get_workdir: when True, the latest existing version directory is
# selected instead of a new one.
flags.DEFINE_boolean("use_max_version", False, "Overwrite the max version of the model")
# Copied into config.training.pre_trained_ckpt_fpath in `main`; empty means no pre-loading.
flags.DEFINE_string(
    "load_ckptfpath",
    "",
    "The path to a previous ckpt from which the weights should be loaded",
)
flags.mark_flags_as_required(["workdir", "config", "mode"])
51
+
52
+
53
def _get_padding_kwargs(data_config):
    """Build the padding kwargs of the lateral-contextualization dataloader from ``data_config``."""
    if "padding_mode" in data_config and data_config.padding_mode is not None:
        mode = data_config.padding_mode
    else:
        mode = "reflect"
    if "padding_value" in data_config and data_config.padding_value is not None:
        constant_values = data_config.padding_value
    else:
        constant_values = None
    return {"mode": mode, "constant_values": constant_values}


def create_dataset(
    config,
    datadir,
    eval_datasplit_type=DataSplitType.Val,
    raw_data_dict=None,
    skip_train_dataset=False,
    kwargs_dict=None,
):
    """
    Create the training and evaluation datasets.

    ``LCMultiChDloader`` (lateral contextualization) is used when
    ``config.data.multiscale_lowres_count`` is set; otherwise ``MultiChDloader``
    is used. Both datasets use val/test fractions of 0.1. The evaluation dataset
    uses neither rotation augmentation nor random cropping. When the training
    dataset is built, its ``max_val`` is forwarded to the evaluation dataset.
    Its mean/std are also computed and set on both datasets.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Experiment configuration; only ``config.data`` is read.
    datadir : str
        Path to the data.
    eval_datasplit_type : DataSplitType
        Split used for the evaluation dataset (default: validation).
    raw_data_dict : optional
        Unused; kept for backward compatibility.
    skip_train_dataset : bool
        If True, no training dataset is built and ``None`` is returned in its
        place. In that case ``max_val=None`` is passed to the evaluation dataset
        and mean/std are NOT set; the caller must set them.
    kwargs_dict : dict, optional
        Extra keyword arguments forwarded to both dataloaders. For the
        lateral-contextualization loader, a ``"padding_kwargs"`` entry overrides
        the padding derived from the config. The caller's dict is not modified.

    Returns
    -------
    tuple
        ``(train_data, val_data)``.
    """
    # Shallow copy: popping "padding_kwargs" below must not mutate the caller's dict.
    kwargs_dict = {} if kwargs_dict is None else dict(kwargs_dict)

    datapath = datadir

    # Hard-coded parameters (used to be in the config file)
    normalized_input = True
    use_one_mu_std = True
    train_aug_rotate = False
    enable_random_cropping = True
    lowres_supervision = False

    # 1) Data loader for Lateral Contextualization
    if (
        "multiscale_lowres_count" in config.data
        and config.data.multiscale_lowres_count is not None
    ):
        if "padding_kwargs" in kwargs_dict:
            padding_kwargs = kwargs_dict.pop("padding_kwargs")
        else:
            padding_kwargs = _get_padding_kwargs(config.data)

        train_data = (
            None
            if skip_train_dataset
            else LCMultiChDloader(
                config.data,
                datapath,
                datasplit_type=DataSplitType.Train,
                val_fraction=0.1,
                test_fraction=0.1,
                normalized_input=normalized_input,
                use_one_mu_std=use_one_mu_std,
                enable_rotation_aug=train_aug_rotate,
                enable_random_cropping=enable_random_cropping,
                num_scales=config.data.multiscale_lowres_count,
                lowres_supervision=lowres_supervision,
                padding_kwargs=padding_kwargs,
                **kwargs_dict,
                allow_generation=True,
            )
        )
        # Without a training dataset there is no max_val to share.
        # NOTE(review): assumes the dataloader treats max_val=None as "not provided".
        max_val = None if train_data is None else train_data.get_max_val()

        val_data = LCMultiChDloader(
            config.data,
            datapath,
            datasplit_type=eval_datasplit_type,
            val_fraction=0.1,
            test_fraction=0.1,
            normalized_input=normalized_input,
            use_one_mu_std=use_one_mu_std,
            enable_rotation_aug=False,  # No rotation aug on validation
            # No random cropping on validation. Validation is evaluated on deterministic grids
            enable_random_cropping=False,
            num_scales=config.data.multiscale_lowres_count,
            lowres_supervision=lowres_supervision,
            padding_kwargs=padding_kwargs,
            allow_generation=False,
            **kwargs_dict,
            max_val=max_val,
        )
    # 2) Vanilla data loader
    else:
        train_data_kwargs = {"allow_generation": True, **kwargs_dict}
        val_data_kwargs = {"allow_generation": False, **kwargs_dict}

        train_data_kwargs["enable_random_cropping"] = enable_random_cropping
        val_data_kwargs["enable_random_cropping"] = False

        train_data = (
            None
            if skip_train_dataset
            else MultiChDloader(
                data_config=config.data,
                fpath=datapath,
                datasplit_type=DataSplitType.Train,
                val_fraction=0.1,
                test_fraction=0.1,
                normalized_input=normalized_input,
                use_one_mu_std=use_one_mu_std,
                enable_rotation_aug=train_aug_rotate,
                **train_data_kwargs,
            )
        )

        # NOTE(review): assumes the dataloader treats max_val=None as "not provided".
        max_val = None if train_data is None else train_data.get_max_val()
        val_data = MultiChDloader(
            data_config=config.data,
            fpath=datapath,
            datasplit_type=eval_datasplit_type,
            val_fraction=0.1,
            test_fraction=0.1,
            normalized_input=normalized_input,
            use_one_mu_std=use_one_mu_std,
            enable_rotation_aug=False,  # No rotation aug on validation
            max_val=max_val,
            **val_data_kwargs,
        )

    # For normalizing, we should be using the training data's mean and std.
    if train_data is not None:
        mean_val, std_val = train_data.compute_mean_std()
        train_data.set_mean_std(mean_val, std_val)
        val_data.set_mean_std(mean_val, std_val)

    return train_data, val_data
177
+
178
+
179
def create_model_and_train(
    config: ml_collections.ConfigDict,
    data_mean: Dict[str, torch.Tensor],
    data_std: Dict[str, torch.Tensor],
    logger: WandbLogger,
    checkpoint_callback: ModelCheckpoint,
    train_loader: DataLoader,
    val_loader: DataLoader,
):
    """
    Build the LadderVAELight model and fit it with a PyTorch Lightning trainer.

    Old TensorBoard event files and checkpoints in ``config.workdir`` are deleted
    first. Pre-trained weights are loaded non-strictly when
    ``config.training.pre_trained_ckpt_fpath`` is set.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Experiment configuration (``workdir``, ``model``, ``data`` and
        ``training`` entries are read).
    data_mean, data_std : dict
        Normalization statistics forwarded to ``LadderVAELight``.
    logger : WandbLogger
        Trainer logger; the config (as a dict) is pushed to its experiment config.
    checkpoint_callback : ModelCheckpoint
        Checkpoint callback added to the trainer callbacks.
    train_loader, val_loader : DataLoader
        Training and validation data loaders.
    """
    # Remove TensorBoard event files, then checkpoints, left over in the workdir.
    for pattern in ("events*", "*.ckpt"):
        for filename in glob.glob(os.path.join(config.workdir, pattern)):
            os.remove(filename)

    if "num_targets" in config.model:
        target_ch = config.model.num_targets
    else:
        target_ch = config.data.get("num_channels", 2)

    # Instantiate the model (lightning wrapper)
    model = LadderVAELight(
        data_mean=data_mean, data_std=data_std, config=config, target_ch=target_ch
    )

    # Load pre-trained weights if any
    if config.training.pre_trained_ckpt_fpath:
        print("Starting with pre-trained model", config.training.pre_trained_ckpt_fpath)
        # map_location="cpu" lets a checkpoint saved on a GPU load on any machine;
        # load_state_dict then copies the weights into the model's own tensors.
        checkpoint = torch.load(
            config.training.pre_trained_ckpt_fpath, map_location="cpu"
        )
        # strict=False: missing/unexpected keys in the checkpoint are tolerated.
        _ = model.load_state_dict(checkpoint["state_dict"], strict=False)

    estop_monitor = config.model.get("monitor", "val_loss")
    estop_mode = MetricMonitor(estop_monitor).mode()

    callbacks = [
        EarlyStopping(
            monitor=estop_monitor,
            min_delta=1e-6,
            patience=config.training.earlystop_patience,
            verbose=True,
            mode=estop_mode,
        ),
        checkpoint_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ]

    logger.experiment.config.update(config.to_dict())
    trainer = pl.Trainer(
        accelerator="gpu",
        max_epochs=config.training.max_epochs,
        gradient_clip_val=config.training.grad_clip_norm_value,
        gradient_clip_algorithm=config.training.gradient_clip_algorithm,
        logger=logger,
        callbacks=callbacks,
        precision=config.training.precision,
    )
    trainer.fit(model, train_loader, val_loader)
240
+
241
+
242
def train_network(
    train_loader: DataLoader,
    val_loader: DataLoader,
    data_mean: Dict[str, torch.Tensor],
    data_std: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
    model_name: str,
    logdir: str,
):
    """
    Configure checkpointing and W&B logging, then train via ``create_model_and_train``.

    The best checkpoint is saved as ``<model_name>_best`` and the last one as
    ``<model_name>_last`` in ``config.workdir``. Both checkpointing and W&B
    logging monitor ``config.model.monitor`` (default ``"val_loss"``).
    """
    monitored_metric = config.model.get("monitor", "val_loss")
    monitor_mode = MetricMonitor(monitored_metric).mode()
    ckpt_callback = ModelCheckpoint(
        monitor=monitored_metric,
        dirpath=config.workdir,
        filename=model_name + "_best",
        save_last=True,
        save_top_k=1,
        mode=monitor_mode,
    )
    # File name Lightning uses for the `save_last=True` checkpoint.
    ckpt_callback.CHECKPOINT_NAME_LAST = model_name + "_last"

    run_name = os.path.join(config.hostname, config.exptname)
    wandb_logger = WandbLogger(
        name=run_name,
        save_dir=logdir,
        project="Disentanglement",
    )

    create_model_and_train(
        config=config,
        data_mean=data_mean,
        data_std=data_std,
        logger=wandb_logger,
        checkpoint_callback=ckpt_callback,
        train_loader=train_loader,
        val_loader=val_loader,
    )
277
+
278
+
279
def main(argv):
    """
    Script entry point passed to ``absl.app.run``.

    Resolves the versioned run directory under ``FLAGS.workdir`` (see
    ``get_workdir``) and stores the run metadata in the config. In ``train``
    mode it then freezes and pickles the config, builds the datasets and data
    loaders, and trains the model. ``eval`` mode currently does nothing.

    Parameters
    ----------
    argv : list of str
        Remaining command-line arguments from absl (unused).

    Raises
    ------
    FileNotFoundError
        If ``FLAGS.workdir`` does not exist.
    ValueError
        If ``FLAGS.mode`` is not recognized.
    """
    config = FLAGS.config

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(FLAGS.workdir):
        raise FileNotFoundError(f"Work directory {FLAGS.workdir} does not exist.")
    cur_workdir, relative_path = get_workdir(
        config, FLAGS.workdir, FLAGS.use_max_version
    )
    print(f"Saving training to {cur_workdir}")

    config.workdir = cur_workdir
    config.exptname = relative_path
    config.hostname = socket.gethostname()
    config.datadir = FLAGS.datadir
    config.training.pre_trained_ckpt_fpath = FLAGS.load_ckptfpath

    if FLAGS.mode == "train":
        set_logger(workdir=cur_workdir)

        # From now on, config cannot be changed.
        config = ml_collections.FrozenConfigDict(config)
        log_config(config, cur_workdir)

        train_data, val_data = create_dataset(config, FLAGS.datadir, raw_data_dict=None)

        mean_dict, std_dict = get_mean_std_dict_for_model(config, train_data)

        batch_size = config.training.batch_size
        train_dloader = DataLoader(
            train_data,
            pin_memory=False,
            num_workers=config.training.num_workers,
            shuffle=True,
            batch_size=batch_size,
        )
        val_dloader = DataLoader(
            val_data,
            pin_memory=False,
            num_workers=config.training.num_workers,
            shuffle=False,
            batch_size=batch_size,
        )

        train_network(
            train_loader=train_dloader,
            val_loader=val_dloader,
            data_mean=mean_dict,
            data_std=std_dict,
            config=config,
            model_name="BaselineVAECL",
            logdir=FLAGS.logdir,
        )

    elif FLAGS.mode == "eval":
        pass
    else:
        raise ValueError(f"Mode {FLAGS.mode} not recognized.")
339
+
340
+
341
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__":
    app.run(main)
@@ -0,0 +1,121 @@
1
+ """
2
+ This script contains the utility functions for training the LVAE model.
3
+ These functions are mainly used in `train.py` script.
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import pickle
9
+ import time
10
+ from copy import deepcopy
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ import ml_collections
15
+
16
+
17
def log_config(config: ml_collections.ConfigDict, cur_workdir: str) -> None:
    """
    Pickle ``config`` to ``<cur_workdir>/config.pkl`` and print where it was saved.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Configuration to persist.
    cur_workdir : str
        Existing directory that receives ``config.pkl``.
    """
    target_path = os.path.join(cur_workdir, "config.pkl")
    with open(target_path, "wb") as handle:
        pickle.dump(config, handle)
    print(f"Saved config to {cur_workdir}/config.pkl")
22
+
23
+
24
def set_logger(workdir: str) -> None:
    """
    Send root-logger records to ``<workdir>/stdout.txt`` at INFO level.

    The workdir is created if needed, and the log file is truncated when its
    handler is first attached. Calling this again for the same workdir does
    not attach a second handler.

    Parameters
    ----------
    workdir : str
        Directory in which ``stdout.txt`` is written.
    """
    os.makedirs(workdir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(workdir, "stdout.txt"))
    logger = logging.getLogger()

    # FileHandler.baseFilename is the absolute path, so this detects a previous call.
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not already_attached:
        # FileHandler owns its stream and closes it on handler.close() / logging
        # shutdown, unlike a StreamHandler wrapped around a bare open().
        handler = logging.FileHandler(log_path, mode="w")
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel("INFO")
35
+
36
+
37
def get_new_model_version(model_dir: str) -> str:
    """
    Return the next free integer run version in ``model_dir``, as a string.

    A model will have multiple runs, and each run gets its own version
    directory. Every entry of ``model_dir`` must be named by an integer.

    Returns
    -------
    str
        ``"0"`` if ``model_dir`` is empty, otherwise ``max(versions) + 1``.

    Raises
    ------
    SystemExit
        With status 1 (after printing a message) if an entry is not an integer.
    """
    versions = []
    for version_dir in os.listdir(model_dir):
        try:
            versions.append(int(version_dir))
        except ValueError:
            print(
                f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed"
            )
            # Non-zero status: the previous bare `exit()` reported success (status 0).
            raise SystemExit(1)
    if not versions:
        return "0"
    return f"{max(versions) + 1}"
53
+
54
+
55
def get_model_name(config: ml_collections.ConfigDict) -> str:
    """Return the fixed model name (used as a workdir sub-folder); ``config`` is ignored."""
    model_name = "LVAE_denoiSplit"
    return model_name
57
+
58
+
59
def get_workdir(
    config: ml_collections.ConfigDict,
    root_dir: str,
    use_max_version: bool,
    nested_call: int = 0,
):
    """
    Create (or, with ``use_max_version``, reuse) the versioned run directory.

    The layout is ``<root_dir>/<yymm>/<model_name>/<version>``. Intermediate
    directories are created when missing. ``root_dir`` itself must exist.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        Passed to ``get_model_name``.
    root_dir : str
        Existing root directory.
    use_max_version : bool
        If True, reuse the most recent existing version directory, or ``"0"``
        if there is none (intended for debugging). Otherwise, a new version
        directory is created.
    nested_call : int
        Internal retry counter for concurrent directory creation.

    Returns
    -------
    tuple of (str, str)
        The run directory and its path relative to ``root_dir``.

    Raises
    ------
    ValueError
        If a new version directory still cannot be created after repeated retries.
    """
    rel_path = datetime.now().strftime("%y%m")
    cur_workdir = os.path.join(root_dir, rel_path)
    Path(cur_workdir).mkdir(exist_ok=True)

    rel_path = os.path.join(rel_path, get_model_name(config))
    cur_workdir = os.path.join(root_dir, rel_path)
    Path(cur_workdir).mkdir(exist_ok=True)

    if use_max_version:
        # Used for debugging: select the latest existing version.
        version = int(get_new_model_version(cur_workdir))
        if version > 0:
            version -= 1
        rel_path = os.path.join(rel_path, str(version))
    else:
        rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir))

    cur_workdir = os.path.join(root_dir, rel_path)
    try:
        # With use_max_version the directory usually already exists; reusing it
        # is the point. Previously exist_ok=False made this mode always fail.
        Path(cur_workdir).mkdir(exist_ok=use_max_version)
    except FileExistsError:
        # Another process created the same version concurrently: retry.
        print(
            f"Workdir {cur_workdir} already exists. Probably because someother program also created the exact same directory. Trying to get a new version."
        )
        if nested_call > 10:
            raise ValueError(
                f"Cannot create a new directory. {cur_workdir} already exists."
            )
        time.sleep(2.5)
        return get_workdir(config, root_dir, use_max_version, nested_call + 1)

    return cur_workdir, rel_path
99
+
100
+
101
def get_mean_std_dict_for_model(config, train_dset):
    """
    Return deep copies of the mean and std statistics of ``train_dset``.

    The statistics come from ``train_dset.get_mean_std()`` and are later passed
    to the model. Because they are copied, mutating them does not affect the
    dataset. ``config`` is currently unused.
    """
    mean_stats, std_stats = train_dset.get_mean_std()
    mean_copy = deepcopy(mean_stats)
    std_copy = deepcopy(std_stats)
    return mean_copy, std_copy
108
+
109
+
110
class MetricMonitor:
    """
    Map a monitored metric name to the optimization mode used by Lightning callbacks.

    Supported metrics are ``"val_loss"`` (mode ``"min"``) and ``"val_psnr"``
    (mode ``"max"``).
    """

    def __init__(self, metric):
        assert metric in ["val_loss", "val_psnr"]
        self.metric = metric

    def mode(self):
        """Return ``"min"`` or ``"max"``; raise ValueError for any other metric."""
        directions = (("val_loss", "min"), ("val_psnr", "max"))
        for name, direction in directions:
            if self.metric == name:
                return direction
        raise ValueError(f"Invalid metric:{self.metric}")
@@ -0,0 +1,7 @@
1
+ """Model I/O utilities."""
2
+
3
# Public API of the package; both names are provided by the submodule imports below.
__all__ = ["load_pretrained", "export_to_bmz"]
4
+
5
+
6
+ from .bmz_io import export_to_bmz
7
+ from .model_io_utils import load_pretrained
@@ -0,0 +1,11 @@
1
+ """Bioimage Model Zoo format functions."""
2
+
3
# Public API, re-exported from the `bioimage_utils` and `model_description` imports below.
__all__ = [
    "create_model_description",
    "extract_model_path",
    "get_unzip_path",
    "create_env_text",
]
9
+
10
+ from .bioimage_utils import create_env_text, get_unzip_path
11
+ from .model_description import create_model_description, extract_model_path