careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
@@ -2,189 +2,66 @@
  A place for Datasets and Dataloaders.
  """

- import os
- from typing import Tuple, Union
+ from typing import Tuple, Union, Callable

- # import albumentations as A
- import ml_collections
  import numpy as np
- from skimage.transform import resize
-
- from .data_utils import (
- DataSplitType,
- DataType,
- GridAlignement,
- GridIndexManager,
- IndexSwitcher,
- get_datasplit_tuples,
- get_mrc_data,
- load_tiff,
- )
-
-
- def get_train_val_data(
- data_config,
- fpath,
- datasplit_type: DataSplitType,
- val_fraction=None,
- test_fraction=None,
- allow_generation=None,
- ignore_specific_datapoints=None,
- ):
- """
- Load the data from the given path and split them in training, validation and test sets.
-
- Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
- C is the number of channels.
- """
- if data_config.data_type == DataType.SeparateTiffData:
- fpath1 = os.path.join(fpath, data_config.ch1_fname)
- fpath2 = os.path.join(fpath, data_config.ch2_fname)
- fpaths = [fpath1, fpath2]
- fpath0 = ""
- if "ch_input_fname" in data_config:
- fpath0 = os.path.join(fpath, data_config.ch_input_fname)
- fpaths = [fpath0] + fpaths

- print(
- f"Loading from {fpath} Channels: "
- f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
- )
-
- data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
- if data_config.data_type == DataType.PredictedTiffData:
- assert len(data.shape) == 5 and data.shape[-1] == 1
- data = data[..., 0].copy()
- # data = data[::3].copy()
- # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
- # to the noise present in the channels and so this is not the way we would get the data.
- # We need to add the noise independently to the input and the target.
-
- # if data_config.get('poisson_noise_factor', False):
- # data = np.random.poisson(data)
- # if data_config.get('enable_gaussian_noise', False):
- # synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
- # print('Adding Gaussian noise with scale', synthetic_scale)
- # noise = np.random.normal(0, synthetic_scale, data.shape)
- # data = data + noise
-
- if datasplit_type == DataSplitType.All:
- return data.astype(np.float32)
-
- train_idx, val_idx, test_idx = get_datasplit_tuples(
- val_fraction, test_fraction, len(data), starting_test=True
- )
- if datasplit_type == DataSplitType.Train:
- return data[train_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Val:
- return data[val_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Test:
- return data[test_idx].astype(np.float32)
-
- elif data_config.data_type == DataType.BioSR_MRC:
- num_channels = data_config.get("num_channels", 2)
- fpaths = []
- data_list = []
- for i in range(num_channels):
- fpath1 = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
- fpaths.append(fpath1)
- data = get_mrc_data(fpath1)[..., None]
- data_list.append(data)
-
- dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
-
- msg = ",".join([x[len(dirname) :] for x in fpaths])
- print(
- f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
- )
- N = data_list[0].shape[0]
- for data in data_list:
- N = min(N, data.shape[0])
-
- cropped_data = []
- for data in data_list:
- cropped_data.append(data[:N])
-
- data = np.concatenate(cropped_data, axis=3)
-
- if datasplit_type == DataSplitType.All:
- return data.astype(np.float32)
-
- train_idx, val_idx, test_idx = get_datasplit_tuples(
- val_fraction, test_fraction, len(data), starting_test=True
- )
- if datasplit_type == DataSplitType.Train:
- return data[train_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Val:
- return data[val_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Test:
- return data[test_idx].astype(np.float32)
+ from .utils.empty_patch_fetcher import EmptyPatchFetcher
+ from .utils.index_manager import GridIndexManager
+ from .utils.index_switcher import IndexSwitcher
+ from .config import DatasetConfig
+ from .types import DataSplitType, TilingMode


  class MultiChDloader:
-
  def __init__(
  self,
- data_config: ml_collections.ConfigDict,
+ data_config: DatasetConfig,
  fpath: str,
- datasplit_type: DataSplitType = None,
+ load_data_fn: Callable,
  val_fraction: float = None,
  test_fraction: float = None,
- normalized_input=None,
- enable_rotation_aug: bool = False,
- enable_random_cropping: bool = False,
- use_one_mu_std=None,
- allow_generation: bool = False,
- max_val: float = None,
- grid_alignment=GridAlignement.LeftTop,
- overlapping_padding_kwargs=None,
- print_vars: bool = True,
  ):
- """
- Here, an image is split into grids of size img_sz.
- Args:
- repeat_factor: Since we are doing a random crop, repeat_factor is
- given which can repeatedly sample from the same image. If self.N=12
- and repeat_factor is 5, then index upto 12*5 = 60 is allowed.
- use_one_mu_std: If this is set to true, then one mean and stdev is used
- for both channels. Otherwise, two different meean and stdev are used.
-
- """
+ """ """
  self._data_type = data_config.data_type
  self._fpath = fpath
- self._data = self.N = self._noise_data = None
-
- # Hardcoded params, not included in the config file.
-
+ self._data = self._noise_data = None
+ self.Z = 1
+ self._5Ddata = False
+ self._tiling_mode = data_config.tiling_mode
  # by default, if the noise is present, add it to the input and target.
  self._disable_noise = False # to add synthetic noise
+ self._poisson_noise_factor = None
  self._train_index_switcher = None
+ self._depth3D = data_config.depth3D
+ self._mode_3D = data_config.mode_3D
  # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
- self._input_is_sum = data_config.get("input_is_sum", False)
- self._num_channels = data_config.get("num_channels", 2)
- self._input_idx = data_config.get("input_idx", None)
- self._tar_idx_list = data_config.get("target_idx_list", None)
+ self._input_is_sum = data_config.input_is_sum
+ self._num_channels = data_config.num_channels
+ self._input_idx = data_config.input_idx
+ self._tar_idx_list = data_config.target_idx_list

- if datasplit_type == DataSplitType.Train:
- self._datausage_fraction = 1.0
+ if data_config.datasplit_type == DataSplitType.Train:
+ self._datausage_fraction = data_config.trainig_datausage_fraction
  # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
- self._validtarget_rand_fract = None
+ self._validtarget_rand_fract = data_config.validtarget_random_fraction
  # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
  # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
  # self._idx_count = 0
- elif datasplit_type == DataSplitType.Val:
- self._datausage_fraction = 1.0
+ elif data_config.datasplit_type == DataSplitType.Val:
+ self._datausage_fraction = data_config.validation_datausage_fraction
  else:
  self._datausage_fraction = 1.0

  self.load_data(
  data_config,
- datasplit_type,
+ data_config.datasplit_type,
+ load_data_fn=load_data_fn,
  val_fraction=val_fraction,
  test_fraction=test_fraction,
- allow_generation=allow_generation,
+ allow_generation=data_config.allow_generation,
  )
- self._normalized_input = normalized_input
+ self._normalized_input = data_config.normalized_input
  self._quantile = 1.0
  self._channelwise_quantile = False
  self._background_quantile = 0.0
@@ -194,31 +71,29 @@ class MultiChDloader:

  self._background_values = None

- self._grid_alignment = grid_alignment
- self._overlapping_padding_kwargs = overlapping_padding_kwargs
- if self._grid_alignment == GridAlignement.LeftTop:
- assert (
+ self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
+ if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
+ if (
  self._overlapping_padding_kwargs is None
  or data_config.multiscale_lowres_count is not None
- ), "Padding is not used with this alignement style"
- elif self._grid_alignment == GridAlignement.Center:
+ ):
+ # raise warning
+ print("Padding is not used with this alignement style")
+ else:
  assert (
  self._overlapping_padding_kwargs is not None
- ), "With Center grid alignment, padding is needed."
+ ), "When not trimming boudnary, padding is needed."

- self._is_train = datasplit_type == DataSplitType.Train
+ self._is_train = data_config.datasplit_type == DataSplitType.Train

  # input = alpha * ch1 + (1-alpha)*ch2.
  # alpha is sampled randomly between these two extremes
- self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = (
- self._alpha_weighted_target
- ) = None
+ self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None

  self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
  if self._is_train:
- self._start_alpha_arr = None
- self._end_alpha_arr = None
- self._alpha_weighted_target = False
+ self._start_alpha_arr = data_config.start_alpha
+ self._end_alpha_arr = data_config.end_alpha

  self.set_img_sz(
  data_config.image_size,
@@ -229,11 +104,13 @@ class MultiChDloader:
  ),
  )

- # if self._validtarget_rand_fract is not None:
- # self._train_index_switcher = IndexSwitcher(self.idx_manager, data_config, self._img_sz)
- # self._std_background_arr = None
+ if self._validtarget_rand_fract is not None:
+ self._train_index_switcher = IndexSwitcher(
+ self.idx_manager, data_config, self._img_sz
+ )

  else:
+
  self.set_img_sz(
  data_config.image_size,
  (
@@ -246,33 +123,46 @@ class MultiChDloader:
  self._return_alpha = False
  self._return_index = False

- # self._empty_patch_replacement_enabled = data_config.get("empty_patch_replacement_enabled",
- # False) and self._is_train
- # if self._empty_patch_replacement_enabled:
- # self._empty_patch_replacement_channel_idx = data_config.empty_patch_replacement_channel_idx
- # self._empty_patch_replacement_probab = data_config.empty_patch_replacement_probab
- # data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
- # # NOTE: This is on the raw data. So, it must be called before removing the background.
- # self._empty_patch_fetcher = EmptyPatchFetcher(self.idx_manager,
- # self._img_sz,
- # data_frames,
- # max_val_threshold=data_config.empty_patch_max_val_threshold)
+ self._empty_patch_replacement_enabled = (
+ data_config.empty_patch_replacement_enabled and self._is_train
+ )
+ if self._empty_patch_replacement_enabled:
+ self._empty_patch_replacement_channel_idx = (
+ data_config.empty_patch_replacement_channel_idx
+ )
+ self._empty_patch_replacement_probab = (
+ data_config.empty_patch_replacement_probab
+ )
+ data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
+ # NOTE: This is on the raw data. So, it must be called before removing the background.
+ self._empty_patch_fetcher = EmptyPatchFetcher(
+ self.idx_manager,
+ self._img_sz,
+ data_frames,
+ max_val_threshold=data_config.empty_patch_max_val_threshold,
+ )

- self.rm_bkground_set_max_val_and_upperclip_data(max_val, datasplit_type)
+ self.rm_bkground_set_max_val_and_upperclip_data(
+ data_config.max_val, data_config.datasplit_type
+ )

  # For overlapping dloader, image_size and repeat_factors are not related. hence a different function.

  self._mean = None
  self._std = None
- self._use_one_mu_std = use_one_mu_std
- # Hardcoded
- self._target_separate_normalization = True
+ self._use_one_mu_std = data_config.use_one_mu_std
+
+ self._target_separate_normalization = data_config.target_separate_normalization
+
+ self._enable_rotation = data_config.enable_rotation_aug
+ flipz_3D = data_config.random_flip_z_3D
+ self._flipz_3D = flipz_3D and self._enable_rotation

- self._enable_rotation = enable_rotation_aug
- self._enable_random_cropping = enable_random_cropping
+ self._enable_random_cropping = data_config.enable_random_cropping
  self._uncorrelated_channels = (
- data_config.get("uncorrelated_channels", False) and self._is_train
+ data_config.uncorrelated_channels and self._is_train
  )
+ self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
  assert self._is_train or self._uncorrelated_channels is False
  assert (
  self._enable_random_cropping is True or self._uncorrelated_channels is False
@@ -281,14 +171,15 @@ class MultiChDloader:

  self._rotation_transform = None
  if self._enable_rotation:
- raise NotImplementedError(
- "Augmentation by means of rotation is not supported yet."
- )
+ # TODO: fix this import
+ import albumentations as A
+
  self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])

- if print_vars:
- msg = self._init_msg()
- print(msg)
+ # TODO: remove print log messages
+ # if print_vars:
+ # msg = self._init_msg()
+ # print(msg)

  def disable_noise(self):
  assert (
@@ -306,11 +197,12 @@ class MultiChDloader:
  self,
  data_config,
  datasplit_type,
+ load_data_fn: Callable,
  val_fraction=None,
  test_fraction=None,
  allow_generation=None,
  ):
- self._data = get_train_val_data(
+ self._data = load_data_fn(
  data_config,
  self._fpath,
  datasplit_type,
@@ -318,7 +210,9 @@ class MultiChDloader:
  test_fraction=test_fraction,
  allow_generation=allow_generation,
  )
+ self._loaded_data_preprocessing(data_config)

+ def _loaded_data_preprocessing(self, data_config):
  old_shape = self._data.shape
  if self._datausage_fraction < 1.0:
  framepixelcount = np.prod(self._data.shape[1:3])
@@ -339,28 +233,37 @@ class MultiChDloader:
  )

  msg = ""
- if data_config.get("poisson_noise_factor", -1) > 0:
+ if data_config.poisson_noise_factor > 0:
  self._poisson_noise_factor = data_config.poisson_noise_factor
  msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
- self._data = (
- np.random.poisson(self._data / self._poisson_noise_factor)
- * self._poisson_noise_factor
- )
+ self._data = np.random.poisson(self._data / self._poisson_noise_factor)

- if data_config.get("enable_gaussian_noise", False):
- synthetic_scale = data_config.get("synthetic_gaussian_scale", 0.1)
+ if data_config.enable_gaussian_noise:
+ synthetic_scale = data_config.synthetic_gaussian_scale
  msg += f"Adding Gaussian noise with scale {synthetic_scale}"
  # 0 => noise for input. 1: => noise for all targets.
  shape = self._data.shape
  self._noise_data = np.random.normal(
  0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
  )
- if data_config.get("input_has_dependant_noise", False):
+ if data_config.input_has_dependant_noise:
  msg += ". Moreover, input has dependent noise"
  self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
  print(msg)

- self.N = len(self._data)
+ if len(self._data.shape) == 5:
+ if self._mode_3D:
+ self._5Ddata = True
+ else:
+ assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
+ self._data = self._data.reshape(-1, *self._data.shape[2:])
+
+ if self._5Ddata:
+ self.Z = self._data.shape[1]
+
+ if self._depth3D > 1:
+ assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"
+
  assert (
  self._data.shape[-1] == self._num_channels
  ), "Number of channels in data and config do not match."
@@ -441,9 +344,13 @@ class MultiChDloader:
  def get_img_sz(self):
  return self._img_sz

+ def get_num_frames(self):
+ return self._data.shape[0]
+
  def reduce_data(
  self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
  ):
+ assert not self._5Ddata, "This function is not supported for 3D data."
  if t_list is None:
  t_list = list(range(self._data.shape[0]))
  if h_start is None:
@@ -461,25 +368,56 @@ class MultiChDloader:
  t_list, h_start:h_end, w_start:w_end, :
  ].copy()

- self.N = len(t_list)
  self.set_img_sz(self._img_sz, self._grid_sz)
  print(
  f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
  )

- def set_img_sz(self, image_size, grid_size):
+ def get_idx_manager_shapes(
+ self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
+ ):
+ numC = self._data.shape[-1]
+ if self._5Ddata:
+ patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
+ if isinstance(grid_size, int):
+ grid_shape = (1, 1, grid_size, grid_size, numC)
+ else:
+ assert len(grid_size) == 3
+ assert all(
+ [g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
+ ), f"Grid size {grid_size} must be less than patch size {patch_shape[1:-1]}"
+ grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], numC)
+ else:
+ assert isinstance(grid_size, int)
+ grid_shape = (1, grid_size, grid_size, numC)
+ patch_shape = (1, patch_size, patch_size, numC)
+
+ return patch_shape, grid_shape
+
+ def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
  """
  If one wants to change the image size on the go, then this can be used.
  Args:
  image_size: size of one patch
  grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
  """
+
  self._img_sz = image_size
  self._grid_sz = grid_size
+ shape = self._data.shape
+
+ patch_shape, grid_shape = self.get_idx_manager_shapes(
+ self._img_sz, self._grid_sz
+ )
  self.idx_manager = GridIndexManager(
- self._data.shape, self._grid_sz, self._img_sz, self._grid_alignment
+ shape, grid_shape, patch_shape, self._tiling_mode
  )
- self.set_repeat_factor()
+ # self.set_repeat_factor()
+
+ def __len__(self):
+ # Vera: N is the number of frames in Z stack
+ # Repeat factor is n_rows * n_cols
+ return self.idx_manager.total_grid_count()

  def set_repeat_factor(self):
  if self._grid_sz > 1:
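(The hunk above replaces the old N * repeat_factor dataset length with GridIndexManager bookkeeping: the length becomes the total number of grid cells over frames and spatial dimensions, and a flat dataset index is mapped back to a per-axis patch start location. The standalone NumPy sketch below illustrates that index arithmetic only; it ignores the TilingMode boundary handling and is not the actual GridIndexManager implementation, which lives in careamics/lvae_training/dataset/utils/index_manager.py and is not shown in this diff.)

import numpy as np

def grid_counts(data_shape, grid_shape):
    # Grid cells per axis (frames and spatial dims); the trailing channel axis is ignored.
    return [int(np.ceil(d / g)) for d, g in zip(data_shape[:-1], grid_shape[:-1])]

def total_grid_count(data_shape, grid_shape):
    # Role of idx_manager.total_grid_count(): one dataset entry per grid cell.
    return int(np.prod(grid_counts(data_shape, grid_shape)))

def patch_location_from_index(idx, data_shape, grid_shape):
    # Role of get_patch_location_from_dataset_idx(): flat index -> (frame, *per-axis start).
    counts = grid_counts(data_shape, grid_shape)
    multi = np.unravel_index(idx, counts)
    return tuple(int(m) * g for m, g in zip(multi, grid_shape[:-1]))

# 12 frames of 256x256 with 2 channels, tiled into 64x64 grid cells:
data_shape, grid_shape = (12, 256, 256, 2), (1, 64, 64, 2)
print(total_grid_count(data_shape, grid_shape))                # 12 * 4 * 4 = 192 patches
print(patch_location_from_index(37, data_shape, grid_shape))   # (2, 64, 64)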
@@ -497,10 +435,20 @@ class MultiChDloader:
  msg = (
  f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
  )
+ dim_sizes = [
+ self.idx_manager.get_individual_dim_grid_count(dim)
+ for dim in range(len(self._data.shape))
+ ]
+ dim_sizes = ",".join([str(x) for x in dim_sizes])
  msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
+ msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
+ msg += f" TrimB:{self._tiling_mode}"
  # msg += f' NormInp:{self._normalized_input}'
  # msg += f' SingleNorm:{self._use_one_mu_std}'
  msg += f" Rot:{self._enable_rotation}"
+ if self._flipz_3D:
+ msg += f" FlipZ:{self._flipz_3D}"
+
  msg += f" RandCrop:{self._enable_random_cropping}"
  msg += f" Channel:{self._num_channels}"
  # msg += f' Q:{self._quantile}'
@@ -529,40 +477,52 @@ class MultiChDloader:
  )

  if self._enable_random_cropping:
- h_start, w_start = self._get_random_hw(h, w)
+ patch_start_loc = self._get_random_hw(h, w)
+ if self._5Ddata:
+ patch_start_loc = (
+ np.random.choice(1 + img_tuples[0].shape[-3] - self._depth3D),
+ ) + patch_start_loc
  else:
- h_start, w_start = self._get_deterministic_hw(index)
+ patch_start_loc = self._get_deterministic_loc(index)

  cropped_imgs = []
  for img in img_tuples:
- img = self._crop_flip_img(img, h_start, w_start, False, False)
+ img = self._crop_flip_img(img, patch_start_loc, False, False)
  cropped_imgs.append(img)

  return (
  *tuple(cropped_imgs),
  {
- "h": [h_start, h_start + self._img_sz],
- "w": [w_start, w_start + self._img_sz],
  "hflip": False,
  "wflip": False,
  },
  )

- def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
- if self._grid_alignment == GridAlignement.LeftTop:
+ def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+ if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
  # In training, this is used.
  # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
  # The only benefit this if else loop provides is that it makes it easier to see what happens during training.
- new_img = img[
- ..., h_start : h_start + self._img_sz, w_start : w_start + self._img_sz
- ]
+ patch_end_loc = (
+ np.array(patch_start_loc, dtype=np.int32)
+ + self.idx_manager.patch_shape[1:-1]
+ )
+ if self._5Ddata:
+ z_start, h_start, w_start = patch_start_loc
+ z_end, h_end, w_end = patch_end_loc
+ new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
+ else:
+ h_start, w_start = patch_start_loc
+ h_end, w_end = patch_end_loc
+ new_img = img[..., h_start:h_end, w_start:w_end]
+
  return new_img
- elif self._grid_alignment == GridAlignement.Center:
+ else:
  # During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame
  # In these situations, we need some sort of padding. This is not needed in the LeftTop alignement.
- return self._crop_img_with_padding(img, h_start, w_start)
+ return self._crop_img_with_padding(img, patch_start_loc)

- def get_begin_end_padding(self, start_pos, max_len):
+ def get_begin_end_padding(self, start_pos, end_pos, max_len):
  """
  The effect is that the image with size self._grid_sz is in the center of the patch with sufficient
  padding on all four sides so that the final patch size is self._img_sz.
@@ -572,44 +532,56 @@ class MultiChDloader:
  if start_pos < 0:
  pad_start = -1 * start_pos

- pad_end = max(0, start_pos + self._img_sz - max_len)
+ pad_end = max(0, end_pos - max_len)

  return pad_start, pad_end

- def _crop_img_with_padding(self, img: np.ndarray, h_start: int, w_start: int):
- _, H, W = img.shape
- h_on_boundary = self.on_boundary(h_start, H)
- w_on_boundary = self.on_boundary(w_start, W)
-
- assert h_start < H
- assert w_start < W
-
- assert h_start + self._img_sz <= H or h_on_boundary
- assert w_start + self._img_sz <= W or w_on_boundary
+ def _crop_img_with_padding(
+ self, img: np.ndarray, patch_start_loc, max_len_vals=None
+ ):
+ if max_len_vals is None:
+ max_len_vals = self.idx_manager.data_shape[1:-1]
+ patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
+ self.idx_manager.patch_shape[1:-1], dtype=int
+ )
+ boundary_crossed = []
+ valid_slice = []
+ padding = [[0, 0]]
+ for start_idx, end_idx, max_len in zip(
+ patch_start_loc, patch_end_loc, max_len_vals
+ ):
+ boundary_crossed.append(end_idx > max_len or start_idx < 0)
+ valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
+ pad = [0, 0]
+ if boundary_crossed[-1]:
+ pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
+ padding.append(pad)
  # max() is needed since h_start could be negative.
- new_img = img[
- ...,
- max(0, h_start) : h_start + self._img_sz,
- max(0, w_start) : w_start + self._img_sz,
- ]
- padding = np.array([[0, 0], [0, 0], [0, 0]])
-
- if h_on_boundary:
- pad = self.get_begin_end_padding(h_start, H)
- padding[1] = pad
- if w_on_boundary:
- pad = self.get_begin_end_padding(w_start, W)
- padding[2] = pad
+ if self._5Ddata:
+ new_img = img[
+ ...,
+ valid_slice[0][0] : valid_slice[0][1],
+ valid_slice[1][0] : valid_slice[1][1],
+ valid_slice[2][0] : valid_slice[2][1],
+ ]
+ else:
+ new_img = img[
+ ...,
+ valid_slice[0][0] : valid_slice[0][1],
+ valid_slice[1][0] : valid_slice[1][1],
+ ]

+ # print(np.array(padding).shape, img.shape, new_img.shape)
+ # print(padding)
  if not np.all(padding == 0):
  new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)

  return new_img

  def _crop_flip_img(
- self, img: np.ndarray, h_start: int, w_start: int, h_flip: bool, w_flip: bool
+ self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
  ):
- new_img = self._crop_img(img, h_start, w_start)
+ new_img = self._crop_img(img, patch_start_loc)
  if h_flip:
  new_img = new_img[..., ::-1, :]
  if w_flip:
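(In the hunk above, _crop_img_with_padding changes from dedicated h/w handling to a per-axis loop that clips the requested patch window to the frame and records how much padding is needed on each side, so the same code serves 2D and 3D patches. The sketch below mirrors that bookkeeping with plain NumPy; it is an illustration only, not the careamics implementation, and uses constant padding where the dataset would pass its own overlapping_padding_kwargs.)

import numpy as np

def crop_with_padding(img, patch_start, patch_shape, pad_kwargs=None):
    # Crop a window that may cross the frame boundary, then pad it back to patch_shape.
    # img has shape (..., *spatial); patch_start and patch_shape are given per spatial axis.
    if pad_kwargs is None:
        pad_kwargs = {"mode": "constant"}
    n_lead = img.ndim - len(patch_shape)           # leading (e.g. channel) axes are untouched
    slices = [slice(None)] * n_lead
    padding = [[0, 0] for _ in range(n_lead)]
    for start, size, max_len in zip(patch_start, patch_shape, img.shape[n_lead:]):
        end = start + size
        slices.append(slice(max(0, start), min(max_len, end)))   # valid part inside the frame
        padding.append([max(0, -start), max(0, end - max_len)])  # pad back what was clipped
    patch = img[tuple(slices)]
    if np.any(np.array(padding) != 0):
        patch = np.pad(patch, padding, **pad_kwargs)
    return patch

# A 64x64 patch whose top-left corner is at (-8, 96) in a 1x128x128 frame comes back as
# 1x64x64: 8 padded rows at the top and 32 padded columns on the right.
frame = np.random.rand(1, 128, 128)
assert crop_with_padding(frame, (-8, 96), (64, 64)).shape == (1, 64, 64)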
@@ -617,9 +589,6 @@ class MultiChDloader:
  return new_img.astype(np.float32)


- def __len__(self):
- return self.N * self._repeat_factor
-
  def _load_img(
  self, index: Union[int, Tuple[int, int]]
  ) -> Tuple[np.ndarray, np.ndarray]:
@@ -631,12 +600,21 @@ class MultiChDloader:
  else:
  idx = index[0]

- imgs = self._data[self.idx_manager.get_t(idx)]
+ patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
+ imgs = self._data[patch_loc_list[0]]
+ # if self._5Ddata:
+ # assert self._noise_data is None, 'Noise is not supported for 5D data'
+ # n_loc, z_loc = patch_loc_list[:2]
+ # z_loc_interval = range(z_loc, z_loc + self._depth3D)
+ # imgs = self._data[n_loc, z_loc_interval]
+ # else:
+ # imgs = self._data[patch_loc_list[0]]
+
  loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
  noise = []
  if self._noise_data is not None and not self._disable_noise:
  noise = [
- self._noise_data[self.idx_manager.get_t(idx)][None, ..., i]
+ self._noise_data[patch_loc_list[0]][None, ..., i]
  for i in range(self._noise_data.shape[-1])
  ]
  return tuple(loaded_imgs), tuple(noise)
@@ -660,6 +638,18 @@ class MultiChDloader:
  normalized_imgs.append(img)
  return tuple(normalized_imgs)

+ def normalize_input(self, x):
+ mean_dict, std_dict = self.get_mean_std()
+ mean_ = mean_dict["input"].mean()
+ std_ = std_dict["input"].mean()
+ return (x - mean_) / std_
+
+ def normalize_target(self, target):
+ mean_dict, std_dict = self.get_mean_std()
+ mean_ = mean_dict["target"].squeeze(0)
+ std_ = std_dict["target"].squeeze(0)
+ return (target - mean_) / std_
+
  def get_grid_size(self):
  return self._grid_sz

@@ -669,27 +659,16 @@ class MultiChDloader:
  def per_side_overlap_pixelcount(self):
  return (self._img_sz - self._grid_sz) // 2

- def on_boundary(self, cur_loc, frame_size):
- return cur_loc + self._img_sz > frame_size or cur_loc < 0
+ # def on_boundary(self, cur_loc, frame_size):
+ # return cur_loc + self._img_sz > frame_size or cur_loc < 0

- def _get_deterministic_hw(self, index: Union[int, Tuple[int, int]]):
+ def _get_deterministic_loc(self, index: int):
  """
  It returns the top-left corner of the patch corresponding to index.
  """
- if isinstance(index, int) or isinstance(index, np.int64):
- idx = index
- grid_size = self._grid_sz
- else:
- idx, grid_size = index
-
- h_start, w_start = self.idx_manager.get_deterministic_hw(
- idx, grid_size=grid_size
- )
- if self._grid_alignment == GridAlignement.LeftTop:
- return h_start, w_start
- elif self._grid_alignment == GridAlignement.Center:
- pad = self.per_side_overlap_pixelcount()
- return h_start - pad, w_start - pad
+ loc_list = self.idx_manager.get_patch_location_from_dataset_idx(index)
+ # last dim is channel. we need to take the third and the second last element.
+ return loc_list[1:-1]

  def compute_individual_mean_std(self):
  # numpy 1.19.2 has issues in computing for large arrays. https://github.com/numpy/numpy/issues/8869
@@ -715,6 +694,10 @@ class MultiChDloader:

  mean = np.array(mean_arr)
  std = np.array(std_arr)
+ if (
+ self._5Ddata
+ ): # NOTE: IDEALLY this should be only when the model expects 3D data.
+ return mean[None, :, None, None, None], std[None, :, None, None, None]

  return mean[None, :, None, None], std[None, :, None, None]

@@ -776,6 +759,10 @@ class MultiChDloader:
  if self._skip_normalization_using_mean:
  mean = np.zeros_like(mean)

+ if self._5Ddata:
+ mean = mean[:, :, None]
+ std = std[:, :, None]
+
  mean_dict = {"input": mean} # , 'target':mean}
  std_dict = {"input": std} # , 'target':std}

@@ -810,8 +797,14 @@ class MultiChDloader:
  return cropped_img_tuples, cropped_noise_tuples

  def replace_with_empty_patch(self, img_tuples):
+ """
+ Replaces the content of one of the channels with background
+ """
  empty_index = self._empty_patch_fetcher.sample()
- empty_img_tuples = self._get_img(empty_index)
+ empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
+ assert (
+ len(empty_img_noise_tuples) == 0
+ ), "Noise is not supported with empty patch replacement"
  final_img_tuples = []
  for tuple_idx in range(len(img_tuples)):
  if tuple_idx == self._empty_patch_replacement_channel_idx:
@@ -834,14 +827,7 @@ class MultiChDloader:
  )
  img_tuples = [img_tuples[i] for i in self._tar_idx_list]

- if self._alpha_weighted_target:
- assert self._input_is_sum is False
- target = []
- for i in range(len(img_tuples)):
- target.append(img_tuples[i] * alpha[i])
- target = np.concatenate(target, axis=0)
- else:
- target = np.concatenate(img_tuples, axis=0)
+ target = np.concatenate(img_tuples, axis=0)
  return target

  def _compute_input_with_alpha(self, img_tuples, alpha_list):
@@ -902,9 +888,6 @@ class MultiChDloader:
  index = self._train_index_switcher.get_invalid_target_index()
  return index

- def _rotate(self, img_tuples, noise_tuples):
- return self._rotate2D(img_tuples, noise_tuples)
-
  def _rotate2D(self, img_tuples, noise_tuples):
  img_kwargs = {}
  for i, img in enumerate(img_tuples):
@@ -921,6 +904,7 @@ class MultiChDloader:
  rot_dic = self._rotation_transform(
  image=img_tuples[0][0], **img_kwargs, **noise_kwargs
  )
+
  rotated_img_tuples = []
  for i, img in enumerate(img_tuples):
  if len(img) == 1:
@@ -946,7 +930,101 @@ class MultiChDloader:

  return rotated_img_tuples, rotated_noise_tuples

+ def _rotate(self, img_tuples, noise_tuples):
+
+ if self._5Ddata:
+ return self._rotate3D(img_tuples, noise_tuples)
+ else:
+ return self._rotate2D(img_tuples, noise_tuples)
+
+ def _rotate3D(self, img_tuples, noise_tuples):
+ img_kwargs = {}
+ # random flip in z direction
+ flip_z = self._flipz_3D and np.random.rand() < 0.5
+ for i, img in enumerate(img_tuples):
+ for j in range(self._depth3D):
+ for k in range(len(img)):
+ if flip_z:
+ z_idx = self._depth3D - 1 - j
+ else:
+ z_idx = j
+ img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]
+
+ noise_kwargs = {}
+ for i, nimg in enumerate(noise_tuples):
+ for j in range(self._depth3D):
+ for k in range(len(nimg)):
+ if flip_z:
+ z_idx = self._depth3D - 1 - j
+ else:
+ z_idx = j
+ noise_kwargs[f"noise{i}_{z_idx}_{k}"] = nimg[k, j]
+
+ keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
+ self._rotation_transform.add_targets({k: "image" for k in keys})
+ rot_dic = self._rotation_transform(
+ image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
+ )
+ rotated_img_tuples = []
+ for i, img in enumerate(img_tuples):
+ if len(img) == 1:
+ rotated_img_tuples.append(
+ np.concatenate(
+ [
+ rot_dic[f"img{i}_{j}_0"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ else:
+ temp_arr = []
+ for k in range(len(img)):
+ temp_arr.append(
+ np.concatenate(
+ [
+ rot_dic[f"img{i}_{j}_{k}"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
+
+ rotated_noise_tuples = []
+ for i, nimg in enumerate(noise_tuples):
+ if len(nimg) == 1:
+ rotated_noise_tuples.append(
+ np.concatenate(
+ [
+ rot_dic[f"noise{i}_{j}_0"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ else:
+ temp_arr = []
+ for k in range(len(nimg)):
+ temp_arr.append(
+ np.concatenate(
+ [
+ rot_dic[f"noise{i}_{j}_{k}"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ rotated_noise_tuples.append(np.concatenate(temp_arr, axis=0))
+
+ return rotated_img_tuples, rotated_noise_tuples
+
  def get_uncorrelated_img_tuples(self, index):
+ """
+ Content of channels like actin and nuclei is "correlated" in its
+ respective location, this function allows to pick channels' content
+ from different patches of the image to make it "uncorrelated".
+ """
  img_tuples, noise_tuples = self._get_img(index)
  assert len(noise_tuples) == 0
  img_tuples = [img_tuples[0]]
@@ -959,10 +1037,15 @@ class MultiChDloader:
  def __getitem__(
  self, index: Union[int, Tuple[int, int]]
  ) -> Tuple[np.ndarray, np.ndarray]:
+ # Vera: input can be both real microscopic image and two separate channels that are summed in the code
+
  if self._train_index_switcher is not None:
  index = self._get_index_from_valid_target_logic(index)

- if self._uncorrelated_channels:
+ if (
+ self._uncorrelated_channels
+ and np.random.rand() < self._uncorrelated_channel_probab
+ ):
  img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
  else:
  img_tuples, noise_tuples = self._get_img(index)
@@ -971,28 +1054,36 @@ class MultiChDloader:
  self._empty_patch_replacement_enabled != True
  ), "This is not supported with noise"

+ # Replace the content of one of the channels
+ # with background with given probability
  if self._empty_patch_replacement_enabled:
  if np.random.rand() < self._empty_patch_replacement_probab:
  img_tuples = self.replace_with_empty_patch(img_tuples)

+ # Noise tuples are not needed for the paper
+ # the image tuples are noisy by default
+ # TODO: remove noise tuples completely?
  if self._enable_rotation:
  img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

- # add noise to input
+ # Add noise tuples with image tuples to create the input
  if len(noise_tuples) > 0:
  factor = np.sqrt(2) if self._input_is_sum else 1.0
  input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
  else:
  input_tuples = img_tuples
+
+ # Weight the individual channels, typically alpha is fixed
  inp, alpha = self._compute_input(input_tuples)

- # add noise to target.
+ # Add noise tuples to the image tuples to create the target
  if len(noise_tuples) >= 1:
  img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

  target = self._compute_target(img_tuples, alpha)
+ norm_target = self.normalize_target(target)

- output = [inp, target]
+ output = [inp, norm_target]

  if self._return_alpha:
  output.append(alpha)
@@ -1000,221 +1091,4 @@ class MultiChDloader:
  if self._return_index:
  output.append(index)

- if isinstance(index, int) or isinstance(index, np.int64):
- return tuple(output)
-
- _, grid_size = index
- output.append(grid_size)
- return tuple(output)
-
-
- class LCMultiChDloader(MultiChDloader):
-
- def __init__(
- self,
- data_config,
- fpath: str,
- datasplit_type: DataSplitType = None,
- val_fraction=None,
- test_fraction=None,
- normalized_input=None,
- enable_rotation_aug: bool = False,
- use_one_mu_std=None,
- num_scales: int = None,
- enable_random_cropping=False,
- padding_kwargs: dict = None,
- allow_generation: bool = False,
- lowres_supervision=None,
- max_val=None,
- grid_alignment=GridAlignement.LeftTop,
- overlapping_padding_kwargs=None,
- print_vars=True,
- ):
- """
- Args:
- num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
- highest resolution.
- """
- self._padding_kwargs = (
- padding_kwargs # mode=padding_mode, constant_values=constant_value
- )
- if overlapping_padding_kwargs is not None:
- assert (
- self._padding_kwargs == overlapping_padding_kwargs
- ), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
- It should be so since we just use overlapping_padding_kwargs when it is not None"
-
- else:
- overlapping_padding_kwargs = padding_kwargs
-
- super().__init__(
- data_config,
- fpath,
- datasplit_type=datasplit_type,
- val_fraction=val_fraction,
- test_fraction=test_fraction,
- normalized_input=normalized_input,
- enable_rotation_aug=enable_rotation_aug,
- enable_random_cropping=enable_random_cropping,
- use_one_mu_std=use_one_mu_std,
- allow_generation=allow_generation,
- max_val=max_val,
- grid_alignment=grid_alignment,
- overlapping_padding_kwargs=overlapping_padding_kwargs,
- print_vars=print_vars,
- )
- self.num_scales = num_scales
- assert self.num_scales is not None
- self._scaled_data = [self._data]
- self._scaled_noise_data = [self._noise_data]
-
- assert isinstance(self.num_scales, int) and self.num_scales >= 1
- self._lowres_supervision = lowres_supervision
- assert isinstance(self._padding_kwargs, dict)
- assert "mode" in self._padding_kwargs
-
- for _ in range(1, self.num_scales):
- shape = self._scaled_data[-1].shape
- assert len(shape) == 4
- new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
- ds_data = resize(
- self._scaled_data[-1].astype(np.float32), new_shape
- ).astype(self._scaled_data[-1].dtype)
- # NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
- assert (
- ds_data.max() / self._scaled_data[-1].max() < 5
- ), "Downsampled image should not have very different values"
- assert (
- ds_data.max() / self._scaled_data[-1].max() > 0.2
- ), "Downsampled image should not have very different values"
-
- self._scaled_data.append(ds_data)
- # do the same for noise
- if self._noise_data is not None:
- noise_data = resize(self._scaled_noise_data[-1], new_shape)
- self._scaled_noise_data.append(noise_data)
-
- def _init_msg(self):
- msg = super()._init_msg()
- msg += f" Pad:{self._padding_kwargs}"
- return msg
-
- def _load_scaled_img(
- self, scaled_index, index: Union[int, Tuple[int, int]]
- ) -> Tuple[np.ndarray, np.ndarray]:
- if isinstance(index, int):
- idx = index
- else:
- idx, _ = index
- imgs = self._scaled_data[scaled_index][idx % self.N]
- imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])])
- if self._noise_data is not None:
- noisedata = self._scaled_noise_data[scaled_index][idx % self.N]
- noise = tuple(
- [noisedata[None, :, :, i] for i in range(noisedata.shape[-1])]
- )
- factor = np.sqrt(2) if self._input_is_sum else 1.0
- # since we are using this lowres images for just the input, we need to add the noise of the input.
- assert self._lowres_supervision is None or self._lowres_supervision is False
- imgs = tuple([img + noise[0] * factor for img in imgs])
- return imgs
-
- def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
- """
- Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
- the cropped image will be smaller than self._img_sz * self._img_sz
- """
- return self._crop_img_with_padding(img, h_start, w_start)
-
- def _get_img(self, index: int):
- """
- Returns the primary patch along with low resolution patches centered on the primary patch.
- """
- img_tuples, noise_tuples = self._load_img(index)
- assert self._img_sz is not None
- h, w = img_tuples[0].shape[-2:]
- if self._enable_random_cropping:
- h_start, w_start = self._get_random_hw(h, w)
- else:
- h_start, w_start = self._get_deterministic_hw(index)
-
- cropped_img_tuples = [
- self._crop_flip_img(img, h_start, w_start, False, False)
- for img in img_tuples
- ]
- cropped_noise_tuples = [
- self._crop_flip_img(noise, h_start, w_start, False, False)
- for noise in noise_tuples
- ]
- h_center = h_start + self._img_sz // 2
- w_center = w_start + self._img_sz // 2
- allres_versions = {
- i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
- }
- for scale_idx in range(1, self.num_scales):
- scaled_img_tuples = self._load_scaled_img(scale_idx, index)
-
- h_center = h_center // 2
- w_center = w_center // 2
-
- h_start = h_center - self._img_sz // 2
- w_start = w_center - self._img_sz // 2
-
- scaled_cropped_img_tuples = [
- self._crop_flip_img(img, h_start, w_start, False, False)
- for img in scaled_img_tuples
- ]
- for ch_idx in range(len(img_tuples)):
- allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
-
- output_img_tuples = tuple(
- [
- np.concatenate(allres_versions[ch_idx])
- for ch_idx in range(len(img_tuples))
- ]
- )
- return output_img_tuples, cropped_noise_tuples
-
- def __getitem__(self, index: Union[int, Tuple[int, int]]):
- if self._uncorrelated_channels:
- img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
- else:
- img_tuples, noise_tuples = self._get_img(index)
-
- if self._enable_rotation:
- img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
-
- assert self._lowres_supervision != True
- # add noise to input
- if len(noise_tuples) > 0:
- factor = np.sqrt(2) if self._input_is_sum else 1.0
- input_tuples = []
- for x in img_tuples:
- # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
- x[0] = x[0] + noise_tuples[0] * factor
- input_tuples.append(x)
- else:
- input_tuples = img_tuples
-
- inp, alpha = self._compute_input(input_tuples)
- # assert self._alpha_weighted_target in [False, None]
- target_tuples = [img[:1] for img in img_tuples]
- # add noise to target.
- if len(noise_tuples) >= 1:
- target_tuples = [
- x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
- ]
-
- target = self._compute_target(target_tuples, alpha)
-
- output = [inp, target]
-
- if self._return_alpha:
- output.append(alpha)
-
- if isinstance(index, int):
- return tuple(output)
-
- _, grid_size = index
- output.append(grid_size)
  return tuple(output)
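(The LCMultiChDloader subclass removed at the end of this file is not dropped from the package: the files-changed list above adds careamics/lvae_training/dataset/lc_dataset.py (+267 lines) alongside the new dataset package, so the lateral-context loader most likely lives there now. A hypothetical sketch of the corresponding import moves follows; it assumes the classes keep their names and says nothing about what careamics/lvae_training/dataset/__init__.py re-exports, since neither is shown in this diff.)

# before (0.0.2): both loaders lived in a single module
# from careamics.lvae_training.data_modules import MultiChDloader, LCMultiChDloader

# after (0.0.4): a dataset package with one module per loader (module paths taken from the
# file list above; the LCMultiChDloader class name inside lc_dataset.py is an assumption)
from careamics.lvae_training.dataset.multich_dataset import MultiChDloader
from careamics.lvae_training.dataset.lc_dataset import LCMultiChDloader
from careamics.lvae_training.dataset.config import DatasetConfig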