careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1220 @@
1
+ """
2
+ A place for Datasets and Dataloaders.
3
+ """
4
+
5
+ import os
6
+ from typing import Tuple, Union
7
+
8
+ # import albumentations as A
9
+ import ml_collections
10
+ import numpy as np
11
+ from skimage.transform import resize
12
+
13
+ from .data_utils import (
14
+ DataSplitType,
15
+ DataType,
16
+ GridAlignement,
17
+ GridIndexManager,
18
+ IndexSwitcher,
19
+ get_datasplit_tuples,
20
+ get_mrc_data,
21
+ load_tiff,
22
+ )
23
+
24
+
25
def _select_datasplit(data, datasplit_type, val_fraction, test_fraction):
    """
    Return the ``datasplit_type`` portion of ``data`` (frames on axis 0) as float32.

    ``DataSplitType.All`` returns every frame. Otherwise the frame indices come from
    ``get_datasplit_tuples(val_fraction, test_fraction, len(data), starting_test=True)``.

    Raises:
        ValueError: if ``datasplit_type`` is not one of All/Train/Val/Test.
    """
    if datasplit_type == DataSplitType.All:
        return data.astype(np.float32)

    train_idx, val_idx, test_idx = get_datasplit_tuples(
        val_fraction, test_fraction, len(data), starting_test=True
    )
    if datasplit_type == DataSplitType.Train:
        return data[train_idx].astype(np.float32)
    if datasplit_type == DataSplitType.Val:
        return data[val_idx].astype(np.float32)
    if datasplit_type == DataSplitType.Test:
        return data[test_idx].astype(np.float32)
    # An unknown split used to fall through and return None, failing much later.
    raise ValueError(f"Unsupported datasplit_type: {datasplit_type}")


def get_train_val_data(
    data_config,
    fpath,
    datasplit_type: DataSplitType,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
    ignore_specific_datapoints=None,
):
    """
    Load the data from the given path and split it into training, validation and test sets.

    The returned array has shape N*H*W*C: N is the number of data points, H and W are the
    image dimensions and C is the number of channels (one file per channel, stacked on
    the last axis).

    Args:
        data_config: configuration whose ``data_type`` selects the loader.
            ``DataType.SeparateTiffData`` reads ``ch1_fname`` / ``ch2_fname`` and, if present,
            ``ch_input_fname`` (prepended as the first channel). ``DataType.BioSR_MRC`` reads
            ``num_channels`` (default 2) files named by ``ch{i}_fname``.
        fpath: directory containing the channel files.
        datasplit_type: which split to return (``All``, ``Train``, ``Val`` or ``Test``).
        val_fraction: forwarded to ``get_datasplit_tuples``.
        test_fraction: forwarded to ``get_datasplit_tuples``.
        allow_generation: accepted for interface compatibility; unused here.
        ignore_specific_datapoints: accepted for interface compatibility; unused here.

    Returns:
        The requested split as a float32 array.

    Raises:
        ValueError: if ``data_config.data_type`` or ``datasplit_type`` is not supported.
    """
    if data_config.data_type == DataType.SeparateTiffData:
        fpath1 = os.path.join(fpath, data_config.ch1_fname)
        fpath2 = os.path.join(fpath, data_config.ch2_fname)
        fpaths = [fpath1, fpath2]
        fpath0 = ""
        if "ch_input_fname" in data_config:
            fpath0 = os.path.join(fpath, data_config.ch_input_fname)
            fpaths = [fpath0] + fpaths

        print(
            f"Loading from {fpath} Channels: "
            f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
        )

        data = np.concatenate([load_tiff(path)[..., None] for path in fpaths], axis=3)
        # NOTE: synthetic noise is deliberately not added to the stacked channels here:
        # the input and target noise must be added independently.
        return _select_datasplit(data, datasplit_type, val_fraction, test_fraction)

    if data_config.data_type == DataType.BioSR_MRC:
        num_channels = data_config.get("num_channels", 2)
        fpaths = []
        data_list = []
        for i in range(num_channels):
            channel_path = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
            fpaths.append(channel_path)
            data_list.append(get_mrc_data(channel_path)[..., None])

        dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
        msg = ",".join([x[len(dirname) :] for x in fpaths])
        print(
            f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
        )

        # Channel files may hold a different number of frames; keep the common prefix.
        N = min(channel_data.shape[0] for channel_data in data_list)
        data = np.concatenate([channel_data[:N] for channel_data in data_list], axis=3)
        return _select_datasplit(data, datasplit_type, val_fraction, test_fraction)

    raise ValueError(f"Unsupported data_type: {data_config.data_type}")
122
+
123
+
124
+ class MultiChDloader:
125
+
126
def __init__(
    self,
    data_config: ml_collections.ConfigDict,
    fpath: str,
    datasplit_type: DataSplitType = None,
    val_fraction: float = None,
    test_fraction: float = None,
    normalized_input=None,
    enable_rotation_aug: bool = False,
    enable_random_cropping: bool = False,
    use_one_mu_std=None,
    allow_generation: bool = False,
    max_val: float = None,
    grid_alignment=GridAlignement.LeftTop,
    overlapping_padding_kwargs=None,
    print_vars: bool = True,
):
    """
    Load a multi-channel dataset and prepare it for patch extraction.

    Each frame is split into grids of size ``grid_size`` (defaults to ``image_size``), and
    patches of size ``image_size`` are returned around them.

    Args:
        data_config: configuration (``data_type``, ``image_size``, optional ``grid_size``,
            ``num_channels``, noise settings, ...).
        fpath: directory containing the data files.
        datasplit_type: which split to load (Train/Val/Test/All).
        val_fraction: forwarded to the data split.
        test_fraction: forwarded to the data split.
        normalized_input: stored as-is.
        enable_rotation_aug: rotation augmentation; not supported yet (raises).
        enable_random_cropping: crop patches at random positions instead of on the grid.
        use_one_mu_std: if True, one mean and stdev is used for all channels; otherwise
            per-channel statistics are used.
        allow_generation: forwarded to ``load_data``.
        max_val: upper clipping value; when None it is computed (only allowed for Train).
        grid_alignment: ``GridAlignement.LeftTop`` or ``GridAlignement.Center``.
        overlapping_padding_kwargs: ``np.pad`` kwargs; required for Center alignment.
        print_vars: print a one-line summary of the loader configuration.

    Raises:
        NotImplementedError: if ``enable_rotation_aug`` is True.
    """
    # Fail fast, before the (potentially expensive) data loading below.
    if enable_rotation_aug:
        raise NotImplementedError(
            "Augmentation by means of rotation is not supported yet."
        )

    self._data_type = data_config.data_type
    self._fpath = fpath
    self._data = self.N = self._noise_data = None
    # load_data() only sets this when Poisson noise is configured; define it up front so
    # that disable_noise() can always inspect it.
    self._poisson_noise_factor = None

    # Hardcoded params, not included in the config file.
    # By default, if synthetic noise is present, add it to the input and target.
    self._disable_noise = False
    self._train_index_switcher = None
    # NOTE: Input is the sum of the different channels, not their average.
    self._input_is_sum = data_config.get("input_is_sum", False)
    self._num_channels = data_config.get("num_channels", 2)
    self._input_idx = data_config.get("input_idx", None)
    self._tar_idx_list = data_config.get("target_idx_list", None)

    # Every split currently uses all of its data.
    self._datausage_fraction = 1.0
    if datasplit_type == DataSplitType.Train:
        self._validtarget_rand_fract = None

    self.load_data(
        data_config,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )
    self._normalized_input = normalized_input
    self._quantile = 1.0
    self._channelwise_quantile = False
    self._background_quantile = 0.0
    self._clip_background_noise_to_zero = False
    self._skip_normalization_using_mean = False
    self._empty_patch_replacement_enabled = False
    self._background_values = None

    self._grid_alignment = grid_alignment
    self._overlapping_padding_kwargs = overlapping_padding_kwargs
    if self._grid_alignment == GridAlignement.LeftTop:
        assert (
            self._overlapping_padding_kwargs is None
            or data_config.multiscale_lowres_count is not None
        ), "Padding is not used with this alignement style"
    elif self._grid_alignment == GridAlignement.Center:
        assert (
            self._overlapping_padding_kwargs is not None
        ), "With Center grid alignment, padding is needed."

    self._is_train = datasplit_type == DataSplitType.Train

    # input = alpha * ch1 + (1-alpha)*ch2; alpha would be sampled between these extremes.
    self._start_alpha_arr = None
    self._end_alpha_arr = None
    self._alpha_weighted_target = False if self._is_train else None

    grid_size = (
        data_config.grid_size if "grid_size" in data_config else data_config.image_size
    )
    self.set_img_sz(data_config.image_size, grid_size)

    self._return_alpha = False
    self._return_index = False

    self.rm_bkground_set_max_val_and_upperclip_data(max_val, datasplit_type)

    self._mean = None
    self._std = None
    self._use_one_mu_std = use_one_mu_std
    # Hardcoded
    self._target_separate_normalization = True

    self._enable_rotation = enable_rotation_aug
    self._enable_random_cropping = enable_random_cropping
    self._uncorrelated_channels = (
        data_config.get("uncorrelated_channels", False) and self._is_train
    )
    assert self._is_train or self._uncorrelated_channels is False
    assert (
        self._enable_random_cropping is True or self._uncorrelated_channels is False
    )
    self._rotation_transform = None

    if print_vars:
        print(self._init_msg())
292
+
293
def disable_noise(self):
    """
    Stop returning the synthetic Gaussian noise channels together with the images.

    Raises:
        AssertionError: if Poisson noise was configured; that noise is baked into the
            data itself and therefore cannot be disabled.
    """
    # getattr: the attribute is only defined by load_data() when Poisson noise is used.
    assert (
        getattr(self, "_poisson_noise_factor", None) is None
    ), "This is not supported. Poisson noise is added to the data itself and so the noise cannot be disabled."
    self._disable_noise = True
298
+
299
def enable_noise(self):
    """Resume returning the synthetic noise channels (undoes ``disable_noise``)."""
    setattr(self, "_disable_noise", False)

def get_data_shape(self):
    """Return the shape of the loaded data array."""
    data = self._data
    return data.shape
304
+
305
def load_data(
    self,
    data_config,
    datasplit_type,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
):
    """
    Load the requested split into ``self._data`` and set up synthetic noise.

    Steps:
      * read the split with ``get_train_val_data``;
      * if ``self._datausage_fraction < 1`` keep only that fraction of the frames;
      * if ``poisson_noise_factor > 0`` apply Poisson noise to the data itself;
      * if ``enable_gaussian_noise`` store Gaussian noise in ``self._noise_data`` with one
        extra channel (index 0: input noise, 1..: one per target channel).

    Sets ``self._data``, ``self._noise_data``, ``self._poisson_noise_factor`` and ``self.N``.
    """
    self._data = get_train_val_data(
        data_config,
        self._fpath,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )

    old_shape = self._data.shape
    if self._datausage_fraction < 1.0:
        framepixelcount = np.prod(self._data.shape[1:3])
        pixelcount = int(
            len(self._data) * framepixelcount * self._datausage_fraction
        )
        frame_count = int(np.ceil(pixelcount / framepixelcount))
        last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size(
            self._data.shape[:3], self._datausage_fraction
        )
        self._data = self._data[:frame_count].copy()
        if frame_count == 1:
            self._data = self._data[
                :, :last_frame_reduced_size, :last_frame_reduced_size
            ].copy()
        print(
            f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}"
        )

    msg = ""
    # Always define the attribute (None = no Poisson noise); disable_noise() reads it.
    self._poisson_noise_factor = None
    poisson_noise_factor = data_config.get("poisson_noise_factor", -1)
    # A config value of None means "disabled"; comparing None > 0 would raise TypeError.
    if poisson_noise_factor is not None and poisson_noise_factor > 0:
        self._poisson_noise_factor = poisson_noise_factor
        msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
        self._data = (
            np.random.poisson(self._data / self._poisson_noise_factor)
            * self._poisson_noise_factor
        )

    if data_config.get("enable_gaussian_noise", False):
        synthetic_scale = data_config.get("synthetic_gaussian_scale", 0.1)
        msg += f"Adding Gaussian noise with scale {synthetic_scale}"
        # 0 => noise for input. 1: => noise for all targets.
        shape = self._data.shape
        self._noise_data = np.random.normal(
            0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
        )
        if data_config.get("input_has_dependant_noise", False):
            msg += ". Moreover, input has dependent noise"
            self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
    print(msg)

    self.N = len(self._data)
    assert (
        self._data.shape[-1] == self._num_channels
    ), "Number of channels in data and config do not match."
367
+
368
def save_background(self, channel_idx, frame_idx, background_value):
    """Record the background value subtracted from channel ``channel_idx`` of frame ``frame_idx``."""
    table = self._background_values
    table[frame_idx, channel_idx] = background_value

def get_background(self, channel_idx, frame_idx):
    """Return the background value stored for (``frame_idx``, ``channel_idx``)."""
    table = self._background_values
    return table[frame_idx, channel_idx]
373
+
374
def remove_background(self):
    """
    Subtract a per-frame, per-channel background estimate from ``self._data`` in place.

    The background is the ``self._background_quantile`` quantile of each frame/channel,
    truncated to an integer and recorded via ``save_background``. A quantile of 0.0
    disables the step (only the zeroed ``self._background_values`` table is created).
    If ``self._clip_background_noise_to_zero`` is set, negative values are clipped to 0
    afterwards.
    """
    self._background_values = np.zeros((self._data.shape[0], self._data.shape[-1]))

    if self._background_quantile == 0.0:
        assert (
            self._clip_background_noise_to_zero is False
        ), "This operation currently happens later in this function."
        return

    if np.issubdtype(self._data.dtype, np.unsignedinteger):
        # Subtracting from unsigned data would wrap around below zero. uint8/uint16 fit in
        # int32 (the previous behaviour for uint16); wider unsigned types need int64.
        signed_dtype = np.int32 if self._data.dtype.itemsize <= 2 else np.int64
        self._data = self._data.astype(signed_dtype)

    for ch in range(self._data.shape[-1]):
        for idx in range(self._data.shape[0]):
            qval = np.quantile(self._data[idx, ..., ch], self._background_quantile)
            assert (
                np.abs(qval) > 20
            ), "We are truncating the qval to an integer which will only make sense if it is large enough"
            # NOTE: Here, there can be an issue if you work with normalized data
            qval = int(qval)
            self.save_background(ch, idx, qval)
            self._data[idx, ..., ch] -= qval

    if self._clip_background_noise_to_zero:
        self._data[self._data < 0] = 0
401
+
402
def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
    """
    Run the intensity preprocessing pipeline on ``self._data``.

    The order is significant: the max value is determined on the background-subtracted
    data and is then used as the upper clipping bound.
    """
    remove_bg = self.remove_background
    remove_bg()
    self.set_max_val(max_val, datasplit_type)
    clip = self.upperclip_data
    clip()
406
+
407
def upperclip_data(self):
    """
    Clip ``self._data`` from above, in place.

    ``self.max_val`` is either a single bound for all data or a list holding one bound per
    channel (last axis of ``self._data``).
    """
    bounds = self.max_val
    if not isinstance(bounds, list):
        self._data[self._data > bounds] = bounds
        return

    n_channels = self._data.shape[-1]
    assert n_channels == len(bounds)
    for channel, bound in enumerate(bounds):
        view = self._data[..., channel]
        view[view > bound] = bound
        self._data[..., channel] = view
418
+
419
def compute_max_val(self):
    """
    Return the ``self._quantile`` quantile of the data.

    When ``self._channelwise_quantile`` is set, a list with one value per channel (last
    axis) is returned; otherwise a single value over all data.
    """
    q = self._quantile
    if not self._channelwise_quantile:
        return np.quantile(self._data, q)
    n_channels = self._data.shape[-1]
    return [np.quantile(self._data[..., c], q) for c in range(n_channels)]
428
+
429
def set_max_val(self, max_val, datasplit_type):
    """
    Set ``self.max_val``, the upper clipping value.

    An explicit ``max_val`` is used as-is. Otherwise it is computed from the data, which is
    only allowed for the training split (other splits must reuse the training value).
    """
    if max_val is not None:
        self.max_val = max_val
        return
    assert datasplit_type == DataSplitType.Train
    self.max_val = self.compute_max_val()
437
+
438
def get_max_val(self):
    """Return the upper clipping value (scalar or per-channel list)."""
    value = self.max_val
    return value

def get_img_sz(self):
    """Return the patch size currently in use."""
    size = self._img_sz
    return size
443
+
444
def reduce_data(
    self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
    """
    Restrict the data (and noise, if any) to the given frames and spatial window.

    Omitted bounds default to the full extent. The grid/patch bookkeeping is rebuilt
    afterwards via ``set_img_sz`` with the current sizes.
    """
    t_list = list(range(self._data.shape[0])) if t_list is None else t_list
    h_start = 0 if h_start is None else h_start
    h_end = self._data.shape[1] if h_end is None else h_end
    w_start = 0 if w_start is None else w_start
    w_end = self._data.shape[2] if w_end is None else w_end

    window = (t_list, slice(h_start, h_end), slice(w_start, w_end), slice(None))
    self._data = self._data[window].copy()
    if self._noise_data is not None:
        self._noise_data = self._noise_data[window].copy()

    self.N = len(t_list)
    self.set_img_sz(self._img_sz, self._grid_sz)
    print(
        f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
    )
469
+
470
def set_img_sz(self, image_size, grid_size):
    """
    Change the patch and grid sizes, e.g. on the fly.

    Args:
        image_size: size of one patch.
        grid_size: each frame is divided into square grids of this size; a patch of size
            ``image_size`` centred on a grid is returned.
    """
    self._img_sz, self._grid_sz = image_size, grid_size
    self.idx_manager = GridIndexManager(
        self._data.shape, grid_size, image_size, self._grid_alignment
    )
    # The number of patches per frame depends on the new sizes, so refresh it.
    self.set_repeat_factor()
483
+
484
def set_repeat_factor(self):
    """
    Set ``self._repeat_factor``, the number of patches per frame.

    It is the number of grid rows times grid columns, using the grid size when it is
    larger than 1 and the patch size otherwise.
    """
    step = self._grid_sz if self._grid_sz > 1 else self._img_sz
    rows = self.idx_manager.grid_rows(step)
    cols = self.idx_manager.grid_cols(step)
    self._repeat_factor = rows * cols
493
+
494
+ def _init_msg(
495
+ self,
496
+ ):
497
+ msg = (
498
+ f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
499
+ )
500
+ msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
501
+ # msg += f' NormInp:{self._normalized_input}'
502
+ # msg += f' SingleNorm:{self._use_one_mu_std}'
503
+ msg += f" Rot:{self._enable_rotation}"
504
+ msg += f" RandCrop:{self._enable_random_cropping}"
505
+ msg += f" Channel:{self._num_channels}"
506
+ # msg += f' Q:{self._quantile}'
507
+ if self._input_is_sum:
508
+ msg += f" SummedInput:{self._input_is_sum}"
509
+
510
+ if self._empty_patch_replacement_enabled:
511
+ msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
512
+ if self._uncorrelated_channels:
513
+ msg += f" Uncorr:{self._uncorrelated_channels}"
514
+ if self._empty_patch_replacement_enabled:
515
+ msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
516
+ if self._background_quantile > 0.0:
517
+ msg += f" BckQ:{self._background_quantile}"
518
+
519
+ if self._start_alpha_arr is not None:
520
+ msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
521
+ return msg
522
+
523
+ def _crop_imgs(self, index, *img_tuples: np.ndarray):
524
+ h, w = img_tuples[0].shape[-2:]
525
+ if self._img_sz is None:
526
+ return (
527
+ *img_tuples,
528
+ {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False},
529
+ )
530
+
531
+ if self._enable_random_cropping:
532
+ h_start, w_start = self._get_random_hw(h, w)
533
+ else:
534
+ h_start, w_start = self._get_deterministic_hw(index)
535
+
536
+ cropped_imgs = []
537
+ for img in img_tuples:
538
+ img = self._crop_flip_img(img, h_start, w_start, False, False)
539
+ cropped_imgs.append(img)
540
+
541
+ return (
542
+ *tuple(cropped_imgs),
543
+ {
544
+ "h": [h_start, h_start + self._img_sz],
545
+ "w": [w_start, w_start + self._img_sz],
546
+ "hflip": False,
547
+ "wflip": False,
548
+ },
549
+ )
550
+
551
def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
    """
    Cut a ``self._img_sz`` square out of the last two axes of ``img``.

    LeftTop alignment (used in training) slices directly. Center alignment (used during
    evaluation) can produce negative starts or windows past the frame edge, so it delegates
    to ``_crop_img_with_padding``. Any other alignment yields None.
    """
    alignment = self._grid_alignment
    if alignment == GridAlignement.LeftTop:
        size = self._img_sz
        return img[..., h_start : h_start + size, w_start : w_start + size]
    if alignment == GridAlignement.Center:
        return self._crop_img_with_padding(img, h_start, w_start)
    return None
564
+
565
def get_begin_end_padding(self, start_pos, max_len):
    """
    Return the (before, after) padding needed for a patch starting at ``start_pos``.

    The patch spans ``self._img_sz`` pixels along an axis of length ``max_len``. Padding
    before is needed when the start is negative; padding after is needed when the patch
    runs past the end. The effect is that the ``self._grid_sz`` region sits in the centre
    of a padded ``self._img_sz`` patch.
    """
    pad_before = -1 * start_pos if start_pos < 0 else 0
    overshoot = start_pos + self._img_sz - max_len
    pad_after = max(0, overshoot)
    return pad_before, pad_after
578
+
579
+ def _crop_img_with_padding(self, img: np.ndarray, h_start: int, w_start: int):
580
+ _, H, W = img.shape
581
+ h_on_boundary = self.on_boundary(h_start, H)
582
+ w_on_boundary = self.on_boundary(w_start, W)
583
+
584
+ assert h_start < H
585
+ assert w_start < W
586
+
587
+ assert h_start + self._img_sz <= H or h_on_boundary
588
+ assert w_start + self._img_sz <= W or w_on_boundary
589
+ # max() is needed since h_start could be negative.
590
+ new_img = img[
591
+ ...,
592
+ max(0, h_start) : h_start + self._img_sz,
593
+ max(0, w_start) : w_start + self._img_sz,
594
+ ]
595
+ padding = np.array([[0, 0], [0, 0], [0, 0]])
596
+
597
+ if h_on_boundary:
598
+ pad = self.get_begin_end_padding(h_start, H)
599
+ padding[1] = pad
600
+ if w_on_boundary:
601
+ pad = self.get_begin_end_padding(w_start, W)
602
+ padding[2] = pad
603
+
604
+ if not np.all(padding == 0):
605
+ new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)
606
+
607
+ return new_img
608
+
609
+ def _crop_flip_img(
610
+ self, img: np.ndarray, h_start: int, w_start: int, h_flip: bool, w_flip: bool
611
+ ):
612
+ new_img = self._crop_img(img, h_start, w_start)
613
+ if h_flip:
614
+ new_img = new_img[..., ::-1, :]
615
+ if w_flip:
616
+ new_img = new_img[..., :, ::-1]
617
+
618
+ return new_img.astype(np.float32)
619
+
620
+ def __len__(self):
621
+ return self.N * self._repeat_factor
622
+
623
+ def _load_img(
624
+ self, index: Union[int, Tuple[int, int]]
625
+ ) -> Tuple[np.ndarray, np.ndarray]:
626
+ """
627
+ Returns the channels and also the respective noise channels.
628
+ """
629
+ if isinstance(index, int) or isinstance(index, np.int64):
630
+ idx = index
631
+ else:
632
+ idx = index[0]
633
+
634
+ imgs = self._data[self.idx_manager.get_t(idx)]
635
+ loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
636
+ noise = []
637
+ if self._noise_data is not None and not self._disable_noise:
638
+ noise = [
639
+ self._noise_data[self.idx_manager.get_t(idx)][None, ..., i]
640
+ for i in range(self._noise_data.shape[-1])
641
+ ]
642
+ return tuple(loaded_imgs), tuple(noise)
643
+
644
def get_mean_std(self):
    """Return the stored ``(mean, std)`` normalization statistics."""
    stats = (self._mean, self._std)
    return stats

def set_mean_std(self, mean_val, std_val):
    """Store the normalization statistics."""
    self._mean, self._std = mean_val, std_val
650
+
651
def normalize_img(self, *img_tuples):
    """
    Normalize the i-th image with the i-th channel's target mean and std.

    Returns:
        A tuple of ``(img - mean[i]) / std[i]``, one per input image.
    """
    mean, std = self.get_mean_std()
    # Flatten (e.g. from (1, C, 1, 1)) to one value per channel. reshape(-1) rather than
    # squeeze(): squeeze turns a single channel into a 0-d array that cannot be indexed.
    mean = np.asarray(mean["target"]).reshape(-1)
    std = np.asarray(std["target"]).reshape(-1)
    return tuple((img - mean[i]) / std[i] for i, img in enumerate(img_tuples))
662
+
663
def get_grid_size(self):
    """Return the grid size used to tile each frame."""
    size = self._grid_sz
    return size

def get_idx_manager(self):
    """Return the index manager that maps patch indices to frame positions."""
    manager = self.idx_manager
    return manager
668
+
669
def per_side_overlap_pixelcount(self):
    """Pixels by which a patch extends beyond its grid cell on each side."""
    overlap = self._img_sz - self._grid_sz
    return overlap // 2

def on_boundary(self, cur_loc, frame_size):
    """True when a patch starting at ``cur_loc`` does not fit inside ``frame_size``."""
    overflows_end = cur_loc + self._img_sz > frame_size
    return overflows_end or cur_loc < 0
674
+
675
def _get_deterministic_hw(self, index: Union[int, Tuple[int, int]]):
    """
    Return the top-left corner ``(h_start, w_start)`` of the patch for ``index``.

    Args:
        index: a patch index (the current grid size is used), or an
            ``(index, grid_size)`` tuple.

    For Center alignment the corner is shifted by the per-side overlap, so it can be
    negative near the frame border.

    Raises:
        ValueError: if the grid alignment is neither LeftTop nor Center.
    """
    # Accept any integer scalar, not just int/np.int64 (e.g. np.int32).
    if isinstance(index, (int, np.integer)):
        idx = index
        grid_size = self._grid_sz
    else:
        idx, grid_size = index

    h_start, w_start = self.idx_manager.get_deterministic_hw(
        idx, grid_size=grid_size
    )
    if self._grid_alignment == GridAlignement.LeftTop:
        return h_start, w_start
    if self._grid_alignment == GridAlignement.Center:
        pad = self.per_side_overlap_pixelcount()
        return h_start - pad, w_start - pad
    # Previously fell through to None, which failed later on tuple unpacking.
    raise ValueError(f"Unsupported grid alignment: {self._grid_alignment}")
693
+
694
def compute_individual_mean_std(self):
    """
    Per-channel mean and std of the data, each returned with shape (1, C, 1, 1).

    The mean is 0.0 when ``self._skip_normalization_using_mean`` is set. With synthetic
    noise, the std is computed on data + noise; the noise array's channel ``ch + 1``
    belongs to data channel ``ch`` (index 0 is the input noise).
    """
    # Channels are processed one at a time: numpy 1.19.2 has issues computing
    # np.mean/np.std over axis=(0, 1, 2) for large arrays
    # (https://github.com/numpy/numpy/issues/8869).
    channel_means = []
    channel_stds = []
    for ch in range(self._data.shape[-1]):
        channel = self._data[..., ch]
        channel_means.append(
            0.0 if self._skip_normalization_using_mean else channel.mean()
        )
        if self._noise_data is None:
            channel_stds.append(channel.std())
        else:
            channel_stds.append((channel + self._noise_data[..., ch + 1]).std())

    mean = np.array(channel_means)[None, :, None, None]
    std = np.array(channel_stds)[None, :, None, None]
    return mean, std
720
+
721
def compute_mean_std(self, allow_for_validation_data=False):
    """
    Compute the normalization statistics as ``(mean_dict, std_dict)``.

    Both dicts have ``"input"`` and ``"target"`` entries. This must normally only be
    computed on training data (``allow_for_validation_data`` overrides that check).

    * With ``input_idx`` set: the input and targets are channels of the data, and both use
      the per-channel statistics.
    * Otherwise the input statistics are global. For a summed input, the mean is the sum of
      the channel means and the std is the 2-norm of the channel stds. They are repeated
      for every channel. Targets use per-channel statistics when
      ``self._target_separate_normalization`` is set, otherwise the input ones.
    """
    assert (
        self._is_train is True or allow_for_validation_data
    ), "This is just allowed for training data"
    assert self._use_one_mu_std is True, "This is the only supported case"

    if self._input_idx is not None:
        assert (
            self._tar_idx_list is not None
        ), "tar_idx_list must be set if input_idx is set."
        assert self._noise_data is None, "This is not supported with noise"
        assert (
            self._target_separate_normalization is True
        ), "This is not supported with target_separate_normalization=False"

        ch_mean, ch_std = self.compute_individual_mean_std()
        input_sel = slice(self._input_idx, self._input_idx + 1)
        return (
            {"input": ch_mean[:, input_sel], "target": ch_mean[:, self._tar_idx_list]},
            {"input": ch_std[:, input_sel], "target": ch_std[:, self._tar_idx_list]},
        )

    channels = range(self._num_channels)
    if self._input_is_sum:
        assert self._noise_data is None, "This is not supported with noise"
        per_channel_means = [
            np.mean(self._data[..., k : k + 1], keepdims=True) for k in channels
        ]
        mean = np.sum(per_channel_means, keepdims=True)[0]
        per_channel_stds = [
            np.std(self._data[..., k : k + 1], keepdims=True) for k in channels
        ]
        std = np.linalg.norm(per_channel_stds, keepdims=True)[0]
    else:
        mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1)
        noisy_data = (
            self._data
            if self._noise_data is None
            else self._data + self._noise_data[..., 1:]
        )
        std = np.std(noisy_data, keepdims=True).reshape(1, 1, 1, 1)

    mean = np.repeat(mean, self._num_channels, axis=1)
    std = np.repeat(std, self._num_channels, axis=1)
    if self._skip_normalization_using_mean:
        mean = np.zeros_like(mean)

    if self._target_separate_normalization:
        target_mean, target_std = self.compute_individual_mean_std()
    else:
        target_mean, target_std = mean, std

    mean_dict = {"input": mean}
    std_dict = {"input": std}
    mean_dict["target"] = target_mean
    std_dict["target"] = target_std
    return mean_dict, std_dict
788
+
789
def _get_random_hw(self, h: int, w: int):
    """
    Sample a random top-left corner for a square crop of side ``self._img_sz``.

    Height and width are sampled independently and uniformly over every start
    position that keeps the crop inside the image, i.e. ``[0, dim - self._img_sz]``
    inclusive. A dimension that is not larger than ``self._img_sz`` gets start 0.

    Args:
        h: Height of the image to crop.
        w: Width of the image to crop.

    Returns:
        Tuple ``(h_start, w_start)``.
    """

    def _random_start(dim: int) -> int:
        # randint's upper bound is exclusive, so ``+ 1`` keeps the last valid
        # start (dim - img_sz) reachable. A dimension equal to the crop size
        # has a single valid start, 0, and must not be passed to the sampler.
        slack = dim - self._img_sz
        if slack <= 0:
            return 0
        return np.random.randint(0, slack + 1)

    h_start = _random_start(h)
    w_start = _random_start(w)
    return h_start, w_start
800
+
801
def _get_img(self, index: Union[int, Tuple[int, int]]):
    """
    Load the image and noise patches for ``index`` and crop them together.

    Images and noise are passed to ``self._crop_imgs`` in a single call so
    that they are cropped with the same window (the crop policy itself lives
    in ``_crop_imgs``). The last element returned by ``_crop_imgs`` is not a
    crop and is discarded.

    Returns:
        Pair ``(cropped_img_tuples, cropped_noise_tuples)``.
    """
    raw_imgs, raw_noise = self._load_img(index)
    num_imgs = len(raw_imgs)
    crops = self._crop_imgs(index, *raw_imgs, *raw_noise)[:-1]
    # The first ``num_imgs`` crops are images; whatever follows is noise.
    return crops[:num_imgs], crops[num_imgs:]
811
+
812
def replace_with_empty_patch(self, img_tuples):
    """
    Replace one channel of ``img_tuples`` with the same channel of an "empty" patch.

    An index is drawn from ``self._empty_patch_fetcher.sample()`` and loaded
    with ``self._get_img``. Only the channel at
    ``self._empty_patch_replacement_channel_idx`` is taken from that patch;
    all other channels are kept from ``img_tuples``.

    Args:
        img_tuples: Sequence with one entry per channel.

    Returns:
        Tuple with the same number of entries as ``img_tuples``.
    """
    empty_index = self._empty_patch_fetcher.sample()
    # ``_get_img`` returns ``(img_tuples, noise_tuples)``. Without unpacking,
    # indexing by channel would pick the whole image or noise tuple instead
    # of a single channel's image.
    empty_img_tuples, _ = self._get_img(empty_index)
    replace_idx = self._empty_patch_replacement_channel_idx
    return tuple(
        empty_img_tuples[idx] if idx == replace_idx else img
        for idx, img in enumerate(img_tuples)
    )
822
+
823
def get_mean_std_for_input(self):
    """Return the ``(mean, std)`` pair stored under the ``"input"`` key of ``get_mean_std()``."""
    mean_by_key, std_by_key = self.get_mean_std()
    input_mean = mean_by_key["input"]
    input_std = std_by_key["input"]
    return input_mean, input_std
826
+
827
def _compute_target(self, img_tuples, alpha):
    """
    Assemble the target array from the per-channel images.

    - If ``self._tar_idx_list`` is an int, that single channel is returned as-is.
    - If it is a list/tuple, only those channels (in that order) are used.
    - The selected channels are concatenated along axis 0, each first scaled
      by ``alpha[i]`` (``i`` = position in the selection) when
      ``self._alpha_weighted_target`` is truthy.

    NOTE(review): with a reordered/subset ``_tar_idx_list``, ``alpha`` is
    indexed by the position in the selection, not by the original channel
    index. Confirm this is intended.
    """
    tar_idx = self._tar_idx_list
    if tar_idx is not None and isinstance(tar_idx, int):
        return img_tuples[tar_idx]

    if tar_idx is not None:
        assert isinstance(tar_idx, (list, tuple))
        img_tuples = [img_tuples[i] for i in tar_idx]

    if not self._alpha_weighted_target:
        return np.concatenate(img_tuples, axis=0)

    # Alpha weighting of the target is incompatible with summed inputs.
    assert self._input_is_sum is False
    weighted = [img * alpha[i] for i, img in enumerate(img_tuples)]
    return np.concatenate(weighted, axis=0)
846
+
847
def _compute_input_with_alpha(self, img_tuples, alpha_list):
    """
    Build the network input from the channel images.

    If ``self._input_idx`` is set, that channel is used directly. Otherwise the
    input is the ``alpha_list``-weighted sum of the channels. Unless
    ``self._normalized_input`` is exactly ``False``, the result is normalized
    with the input mean/std from ``get_mean_std_for_input``, which must hold
    one shared value for all channels (asserted).

    Returns:
        The input cast to ``np.float32``.
    """
    if self._input_idx is None:
        inp = 0
        for weight, channel in zip(alpha_list, img_tuples):
            inp += channel * weight
    else:
        inp = img_tuples[self._input_idx]

    if self._normalized_input is False:
        return inp.astype(np.float32)

    mean, std = self.get_mean_std_for_input()
    mean, std = mean.squeeze(), std.squeeze()
    if mean.size == 1:
        # A single scalar statistic: make it indexable like the per-channel case.
        mean, std = mean.reshape(1), std.reshape(1)

    # Normalization below uses one mean/std, so all channels must agree.
    for i, mean_i in enumerate(mean):
        assert mean[0] == mean_i
        assert std[0] == std[i]

    normalized = (inp - mean[0]) / std[0]
    return normalized.astype(np.float32)
876
+
877
def _sample_alpha(self):
    """
    Draw one mixing weight per channel, uniformly between the configured bounds.

    Channel ``ch`` gets ``start[ch] + u * (end[ch] - start[ch])`` with a fresh
    ``u = np.random.rand()``; draws are made in channel order.

    Returns:
        List of ``self._num_channels`` weights.
    """
    starts, ends = self._start_alpha_arr, self._end_alpha_arr
    return [
        starts[ch] + np.random.rand() * (ends[ch] - starts[ch])
        for ch in range(self._num_channels)
    ]
886
+
887
def _compute_input(self, img_tuples):
    """
    Mix the channel images into one input and report the weights used.

    Weights are uniform (``1 / n_channels``) unless ``self._start_alpha_arr``
    is set, in which case they come from ``self._sample_alpha()``. When
    ``self._input_is_sum`` is truthy, the mixed input is scaled by the number
    of channels.

    Returns:
        Pair ``(inp, alpha)``.
    """
    if self._start_alpha_arr is None:
        n_channels = len(img_tuples)
        alpha = [1 / n_channels for _ in range(n_channels)]
    else:
        alpha = self._sample_alpha()

    inp = self._compute_input_with_alpha(img_tuples, alpha)
    if self._input_is_sum:
        inp = len(img_tuples) * inp
    return inp, alpha
896
+
897
def _get_index_from_valid_target_logic(self, index):
    """
    Optionally replace ``index`` with one chosen by the train index switcher.

    If ``self._validtarget_rand_fract`` is ``None``, ``index`` is returned
    unchanged. Otherwise, with probability ``_validtarget_rand_fract``, an index
    from ``get_valid_target_index()`` is returned, and otherwise one from
    ``get_invalid_target_index()``.
    """
    fract = self._validtarget_rand_fract
    if fract is None:
        return index

    switcher = self._train_index_switcher
    if np.random.rand() < fract:
        return switcher.get_valid_target_index()
    return switcher.get_invalid_target_index()
904
+
905
def _rotate(self, img_tuples, noise_tuples):
    """Apply the rotation augmentation by delegating to ``_rotate2D``."""
    rotate_fn = self._rotate2D
    return rotate_fn(img_tuples, noise_tuples)
907
+
908
def _rotate2D(self, img_tuples, noise_tuples):
    """
    Transform every 2D plane of the image and noise tuples in one transform call.

    Plane ``k`` of image tuple ``i`` is passed as keyword ``img{i}_{k}`` and
    plane ``k`` of noise tuple ``i`` as ``noise{i}_{k}``. All of these names are
    first registered with ``self._rotation_transform.add_targets(...)`` as
    ``"image"`` targets. After the call the planes are re-stacked along a new
    leading axis, keeping their original grouping.

    Returns:
        Pair ``(rotated_img_tuples, rotated_noise_tuples)`` of lists.
    """

    def _to_planes(prefix, groups):
        # One keyword argument per 2D plane, e.g. "img0_1" or "noise0_0".
        return {
            f"{prefix}{g}_{p}": group[p]
            for g, group in enumerate(groups)
            for p in range(len(group))
        }

    def _from_planes(prefix, groups, transformed):
        # Re-assemble each group from its transformed planes.
        restacked = []
        for g, group in enumerate(groups):
            planes = [
                transformed[f"{prefix}{g}_{p}"][None] for p in range(len(group))
            ]
            if len(group) == 1:
                restacked.append(planes[0])
            else:
                restacked.append(np.concatenate(planes, axis=0))
        return restacked

    img_planes = _to_planes("img", img_tuples)
    noise_planes = _to_planes("noise", noise_tuples)

    self._rotation_transform.add_targets(
        {name: "image" for name in [*img_planes, *noise_planes]}
    )
    transformed = self._rotation_transform(
        image=img_tuples[0][0], **img_planes, **noise_planes
    )
    return (
        _from_planes("img", img_tuples, transformed),
        _from_planes("noise", noise_tuples, transformed),
    )
948
+
949
def get_uncorrelated_img_tuples(self, index):
    """
    Build channel tuples whose channels come from independently drawn samples.

    Channel 0 is taken from the sample at ``index``. Every other channel ``c``
    is taken from the sample at a fresh ``np.random.randint(len(self))``, which
    decorrelates the channels.

    Args:
        index: Index passed to ``self._get_img`` for channel 0.

    Returns:
        Pair ``(img_tuples, noise_tuples)``. ``img_tuples`` is a list with one
        entry per channel; ``noise_tuples`` must be empty (asserted).
    """
    img_tuples, noise_tuples = self._get_img(index)
    assert len(noise_tuples) == 0
    # Take the channel count from the loaded sample before building the new
    # list; measuring the rebuilt one-element list would skip every channel
    # after the first.
    num_channels = len(img_tuples)
    uncorrelated = [img_tuples[0]]
    for ch_idx in range(1, num_channels):
        new_index = np.random.randint(len(self))
        other_img_tuples, _ = self._get_img(new_index)
        uncorrelated.append(other_img_tuples[ch_idx])
    return uncorrelated, noise_tuples
958
+
959
def __getitem__(
    self, index: Union[int, Tuple[int, int]]
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return one sample.

    Args:
        index: A plain integer index (Python or NumPy integer) or an
            ``(index, grid_size)`` pair.

    Returns:
        Tuple ``(input, target[, alpha][, index][, grid_size])``:
        ``alpha`` is appended if ``self._return_alpha``, ``index`` if
        ``self._return_index``, and ``grid_size`` when ``index`` is a pair.
    """
    if self._train_index_switcher is not None:
        index = self._get_index_from_valid_target_logic(index)

    if self._uncorrelated_channels:
        img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        img_tuples, noise_tuples = self._get_img(index)

    # NOTE(review): this rejects empty-patch replacement unconditionally even
    # though the message mentions noise, so the branch below cannot run.
    assert (
        self._empty_patch_replacement_enabled != True
    ), "This is not supported with noise"

    if self._empty_patch_replacement_enabled:
        if np.random.rand() < self._empty_patch_replacement_probab:
            img_tuples = self.replace_with_empty_patch(img_tuples)

    if self._enable_rotation:
        img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    # Add noise to the input: noise_tuples[0] is the input noise. New arrays
    # are built so img_tuples stays free of it for the target below.
    if len(noise_tuples) > 0:
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
    else:
        input_tuples = img_tuples
    inp, alpha = self._compute_input(input_tuples)

    # Add noise to the target: noise_tuples[1:] are paired with the channels.
    if len(noise_tuples) >= 1:
        img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

    target = self._compute_target(img_tuples, alpha)

    output = [inp, target]

    if self._return_alpha:
        output.append(alpha)

    if self._return_index:
        output.append(index)

    # Any integer scalar is a plain index. This includes NumPy integers of any
    # width, e.g. as returned by the index switcher.
    if isinstance(index, (int, np.integer)):
        return tuple(output)

    _, grid_size = index
    output.append(grid_size)
    return tuple(output)
1009
+
1010
+
1011
class LCMultiChDloader(MultiChDloader):
    """
    Multi-channel loader whose inputs stack the full-resolution patch with
    lower-resolution crops centred on the same location.

    At construction, ``num_scales - 1`` extra copies of the data (and of the
    noise data, when present) are built. Each copy halves the H and W axes of
    the previous one via ``resize``. The target is always formed at the highest
    resolution.
    """

    def __init__(
        self,
        data_config,
        fpath: str,
        datasplit_type: DataSplitType = None,
        val_fraction=None,
        test_fraction=None,
        normalized_input=None,
        enable_rotation_aug: bool = False,
        use_one_mu_std=None,
        num_scales: int = None,
        enable_random_cropping=False,
        padding_kwargs: dict = None,
        allow_generation: bool = False,
        lowres_supervision=None,
        max_val=None,
        grid_alignment=GridAlignement.LeftTop,
        overlapping_padding_kwargs=None,
        print_vars=True,
    ):
        """
        Args:
            num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
                highest resolution. Must be an int >= 1 (asserted).
            padding_kwargs: Padding options. Must be a dict containing ``"mode"``
                (asserted). Also used as ``overlapping_padding_kwargs`` when that
                argument is not given.
            overlapping_padding_kwargs: If given, must equal ``padding_kwargs``.
            lowres_supervision: Stored as-is. ``__getitem__`` requires it not to
                be ``True``.

        The remaining arguments are forwarded to ``MultiChDloader.__init__``.
        """
        self._padding_kwargs = (
            padding_kwargs  # mode=padding_mode, constant_values=constant_value
        )
        if overlapping_padding_kwargs is not None:
            assert (
                self._padding_kwargs == overlapping_padding_kwargs
            ), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
            It should be so since we just use overlapping_padding_kwargs when it is not None"

        else:
            overlapping_padding_kwargs = padding_kwargs

        super().__init__(
            data_config,
            fpath,
            datasplit_type=datasplit_type,
            val_fraction=val_fraction,
            test_fraction=test_fraction,
            normalized_input=normalized_input,
            enable_rotation_aug=enable_rotation_aug,
            enable_random_cropping=enable_random_cropping,
            use_one_mu_std=use_one_mu_std,
            allow_generation=allow_generation,
            max_val=max_val,
            grid_alignment=grid_alignment,
            overlapping_padding_kwargs=overlapping_padding_kwargs,
            print_vars=print_vars,
        )
        self.num_scales = num_scales
        assert self.num_scales is not None
        self._scaled_data = [self._data]
        self._scaled_noise_data = [self._noise_data]

        assert isinstance(self.num_scales, int) and self.num_scales >= 1
        self._lowres_supervision = lowres_supervision
        assert isinstance(self._padding_kwargs, dict)
        assert "mode" in self._padding_kwargs

        for _ in range(1, self.num_scales):
            shape = self._scaled_data[-1].shape
            assert len(shape) == 4
            # Halve the two middle (spatial) axes, keep the first and last.
            new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
            ds_data = resize(
                self._scaled_data[-1].astype(np.float32), new_shape
            ).astype(self._scaled_data[-1].dtype)
            # NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
            assert (
                ds_data.max() / self._scaled_data[-1].max() < 5
            ), "Downsampled image should not have very different values"
            assert (
                ds_data.max() / self._scaled_data[-1].max() > 0.2
            ), "Downsampled image should not have very different values"

            self._scaled_data.append(ds_data)
            # do the same for noise
            if self._noise_data is not None:
                noise_data = resize(self._scaled_noise_data[-1], new_shape)
                self._scaled_noise_data.append(noise_data)

    def _init_msg(self):
        """Extend the parent's init message with the padding configuration."""
        msg = super()._init_msg()
        msg += f" Pad:{self._padding_kwargs}"
        return msg

    def _load_scaled_img(
        self, scaled_index, index: Union[int, Tuple[int, int]]
    ) -> Tuple[np.ndarray, ...]:
        """
        Load every channel of sample ``index`` at resolution level ``scaled_index``.

        The sample is read from ``self._scaled_data[scaled_index]`` (wrapping
        with ``% self.N``) and split along the last axis into one
        ``(1, H, W)`` array per channel. When noise data exists, the input
        noise (channel 0 of the scaled noise, times ``sqrt(2)`` if
        ``self._input_is_sum``) is added to every channel, because these
        low-resolution images are used only as input.

        Args:
            scaled_index: Resolution level (0 = full resolution).
            index: Integer index (Python or NumPy) or ``(index, grid_size)`` pair.

        Returns:
            Tuple with one array per channel.
        """
        if isinstance(index, (int, np.integer)):
            idx = index
        else:
            idx, _ = index
        imgs = self._scaled_data[scaled_index][idx % self.N]
        imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])])
        if self._noise_data is not None:
            noisedata = self._scaled_noise_data[scaled_index][idx % self.N]
            noise = tuple(
                [noisedata[None, :, :, i] for i in range(noisedata.shape[-1])]
            )
            factor = np.sqrt(2) if self._input_is_sum else 1.0
            # since we are using this lowres images for just the input, we need to add the noise of the input.
            assert self._lowres_supervision is None or self._lowres_supervision is False
            imgs = tuple([img + noise[0] * factor for img in imgs])
        return imgs

    def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
        """
        Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
        the cropped image will be smaller than self._img_sz * self._img_sz
        """
        return self._crop_img_with_padding(img, h_start, w_start)

    def _get_img(self, index: int):
        """
        Returns the primary patch along with low resolution patches centered on the primary patch.

        For each channel, the full-resolution crop and the crop from every lower
        scale are concatenated along axis 0, full resolution first. At each
        scale the crop centre is halved relative to the previous scale. Noise is
        cropped only at full resolution.

        Returns:
            Pair ``(output_img_tuples, cropped_noise_tuples)``.
        """
        img_tuples, noise_tuples = self._load_img(index)
        assert self._img_sz is not None
        h, w = img_tuples[0].shape[-2:]
        if self._enable_random_cropping:
            h_start, w_start = self._get_random_hw(h, w)
        else:
            h_start, w_start = self._get_deterministic_hw(index)

        cropped_img_tuples = [
            self._crop_flip_img(img, h_start, w_start, False, False)
            for img in img_tuples
        ]
        cropped_noise_tuples = [
            self._crop_flip_img(noise, h_start, w_start, False, False)
            for noise in noise_tuples
        ]
        h_center = h_start + self._img_sz // 2
        w_center = w_start + self._img_sz // 2
        allres_versions = {
            i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
        }
        for scale_idx in range(1, self.num_scales):
            scaled_img_tuples = self._load_scaled_img(scale_idx, index)

            h_center = h_center // 2
            w_center = w_center // 2

            h_start = h_center - self._img_sz // 2
            w_start = w_center - self._img_sz // 2

            scaled_cropped_img_tuples = [
                self._crop_flip_img(img, h_start, w_start, False, False)
                for img in scaled_img_tuples
            ]
            for ch_idx in range(len(img_tuples)):
                allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])

        output_img_tuples = tuple(
            [
                np.concatenate(allres_versions[ch_idx])
                for ch_idx in range(len(img_tuples))
            ]
        )
        return output_img_tuples, cropped_noise_tuples

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """
        Return one multi-scale sample.

        Args:
            index: Integer index (Python or NumPy) or ``(index, grid_size)`` pair.

        Returns:
            Tuple ``(input, target[, alpha][, grid_size])``. The target is built
            from the full-resolution level (``img[:1]``) of each channel only.
        """
        if self._uncorrelated_channels:
            img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
        else:
            img_tuples, noise_tuples = self._get_img(index)

        if self._enable_rotation:
            img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

        assert self._lowres_supervision != True
        # add noise to input
        if len(noise_tuples) > 0:
            factor = np.sqrt(2) if self._input_is_sum else 1.0
            input_tuples = []
            for x in img_tuples:
                # Copy first: img_tuples is reused below to build the target.
                # Without the copy, the in-place write would leak the input
                # noise into the target (MultiChDloader.__getitem__ builds new
                # arrays for the same reason).
                x = x.copy()
                # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
                x[0] = x[0] + noise_tuples[0] * factor
                input_tuples.append(x)
        else:
            input_tuples = img_tuples

        inp, alpha = self._compute_input(input_tuples)
        # Target uses only the full-resolution level of each channel.
        target_tuples = [img[:1] for img in img_tuples]
        # add noise to target.
        if len(noise_tuples) >= 1:
            target_tuples = [
                x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
            ]

        target = self._compute_target(target_tuples, alpha)

        output = [inp, target]

        if self._return_alpha:
            output.append(alpha)

        # Any integer scalar (Python or NumPy) is a plain index.
        if isinstance(index, (int, np.integer)):
            return tuple(output)

        _, grid_size = index
        output.append(grid_size)
        return tuple(output)