careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/ms_dataset_ref.py
@@ -0,0 +1,1067 @@
1
+ """
2
+ A place for Datasets and Dataloaders.
3
+ """
4
+
5
+ from collections import defaultdict
6
+ from functools import cache
7
+ from pathlib import Path
8
+ from typing import Callable, Union
9
+
10
+ import numpy as np
11
+ from skimage.transform import resize
12
+
13
+ from .config import DatasetConfig
14
+ from .types import DataSplitType, TilingMode
15
+ from .utils.empty_patch_fetcher import EmptyPatchFetcher
16
+ from .utils.index_manager import GridIndexManagerRef
17
+
18
+
19
+ class MultiChDloaderRef:
20
+ def __init__(
21
+ self,
22
+ data_config: DatasetConfig,
23
+ fpath: str,
24
+ load_data_fn: Callable,
25
+ val_fraction: float = None,
26
+ test_fraction: float = None,
27
+ ):
28
+ """ """
29
+ self._data_type = data_config.data_type
30
+ self._fpath = Path(fpath)
31
+ self._data = None
32
+ self._3Ddata = False # TODO wtf it was 5D
33
+ self._tiling_mode = data_config.tiling_mode
34
+ # by default, if the noise is present, add it to the input and target.
35
+ self._depth3D = data_config.depth3D
36
+ self._mode_3D = data_config.mode_3D
37
+ # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
38
+ self._input_is_sum = data_config.input_is_sum
39
+ self._num_channels = data_config.num_channels
40
+ self._input_idx = data_config.input_idx
41
+ self._tar_idx_list = data_config.target_idx_list
42
+
43
+ self.load_data(
44
+ data_config,
45
+ data_config.datasplit_type,
46
+ load_data_fn=load_data_fn,
47
+ val_fraction=val_fraction,
48
+ test_fraction=test_fraction,
49
+ allow_generation=data_config.allow_generation,
50
+ )
51
+
52
+ self._data_shapes = self.get_data_shapes()
53
+ self._normalized_input = data_config.normalized_input
54
+ self._quantile = 1.0
55
+ self._channelwise_quantile = False
56
+ self._background_quantile = 0.0
57
+ self._clip_background_noise_to_zero = False
58
+ self._skip_normalization_using_mean = False
59
+ self._empty_patch_replacement_enabled = False
60
+
61
+ self._background_values = None
62
+
63
+ self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
64
+ if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
65
+ if (
66
+ self._overlapping_padding_kwargs is None
67
+ or data_config.multiscale_lowres_count is not None
68
+ ):
69
+ # raise warning
70
+ print("Padding is not used with this alignement style")
71
+ else:
72
+ assert (
73
+ self._overlapping_padding_kwargs is not None
74
+ ), "When not trimming boudnary, padding is needed."
75
+
76
+ self._is_train = data_config.datasplit_type == DataSplitType.Train
77
+
78
+ # input = alpha * ch1 + (1-alpha)*ch2.
79
+ # alpha is sampled randomly between these two extremes
80
+ self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None
81
+
82
+ self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
83
+
84
+ # changed set_img_sz because "grid_size" in data_config returns false
85
+ try:
86
+ grid_size = data_config.grid_size
87
+ except AttributeError:
88
+ grid_size = data_config.image_size
89
+
90
+ if self._is_train:
91
+ self._start_alpha_arr = data_config.start_alpha # TODO why only for train?
92
+ self._end_alpha_arr = data_config.end_alpha
93
+
94
+ self.set_img_sz(data_config.image_size, grid_size)
95
+
96
+ self._empty_patch_replacement_enabled = (
97
+ data_config.empty_patch_replacement_enabled and self._is_train
98
+ )
99
+ if self._empty_patch_replacement_enabled:
100
+ self._empty_patch_replacement_channel_idx = (
101
+ data_config.empty_patch_replacement_channel_idx
102
+ )
103
+ self._empty_patch_replacement_probab = (
104
+ data_config.empty_patch_replacement_probab
105
+ )
106
+ data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
107
+ # NOTE: This is on the raw data. So, it must be called before removing the background.
108
+ self._empty_patch_fetcher = EmptyPatchFetcher(
109
+ self.idx_manager,
110
+ self._img_sz,
111
+ data_frames,
112
+ max_val_threshold=data_config.empty_patch_max_val_threshold,
113
+ )
114
+
115
+ self.rm_bkground_set_max_val_and_upperclip_data(
116
+ data_config.max_val, data_config.datasplit_type
117
+ )
118
+
119
+ # For the overlapping dloader, image_size and repeat_factors are not related, hence a different function.
120
+
121
+ self._mean = None
122
+ self._std = None
123
+ self._use_one_mu_std = data_config.use_one_mu_std
124
+
125
+ self._target_separate_normalization = data_config.target_separate_normalization
126
+
127
+ self._enable_rotation = data_config.enable_rotation_aug
128
+ flipz_3D = data_config.random_flip_z_3D
129
+ self._flipz_3D = flipz_3D and self._enable_rotation
130
+
131
+ self._enable_random_cropping = data_config.enable_random_cropping
132
+ self._uncorrelated_channels = (
133
+ data_config.uncorrelated_channels and self._is_train
134
+ )
135
+ self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
136
+ assert self._is_train or self._uncorrelated_channels is False
137
+ assert (
138
+ self._enable_random_cropping is True or self._uncorrelated_channels is False
139
+ )
140
+ # Randomly rotate [-90,90]
141
+
142
+ self._rotation_transform = None
143
+ if self._enable_rotation:
144
+ # TODO: fix this import
145
+ import albumentations as A
146
+
147
+ self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
148
+
149
+ # TODO: remove print log messages
150
+ # if print_vars:
151
+ # msg = self._init_msg()
152
+ # print(msg)
153
+
154
+ def get_data_shapes(self):
155
+ if self._3Ddata: # TODO we assume images don't have a channel dimension
156
+ return [
157
+ [
158
+ im.shape if len(im.shape) == 4 else (1, *im.shape)
159
+ for im in self._data[ch]
160
+ ]
161
+ for ch in range(len(self._data))
162
+ ]
163
+ else:
164
+ return [
165
+ [
166
+ im.shape if len(im.shape) == 3 else (1, *im.shape)
167
+ for im in self._data[ch]
168
+ ]
169
+ for ch in range(len(self._data))
170
+ ]
171
+
172
+ def load_data(
173
+ self,
174
+ data_config,
175
+ datasplit_type,
176
+ load_data_fn: Callable,
177
+ val_fraction=None,
178
+ test_fraction=None,
179
+ allow_generation=None,
180
+ ):
181
+ self._data = load_data_fn(
182
+ data_config,
183
+ self._fpath,
184
+ datasplit_type,
185
+ val_fraction=val_fraction,
186
+ test_fraction=test_fraction,
187
+ allow_generation=allow_generation,
188
+ )
189
+
190
+ # TODO check for 2D/3D data consistency with config
191
+ # TODO check number of channels consistency with config
192
+
193
+ def save_background(self, channel_idx, frame_idx, background_value):
194
+ self._background_values[frame_idx, channel_idx] = background_value
195
+
196
+ def get_background(self, channel_idx, frame_idx):
197
+ return self._background_values[frame_idx, channel_idx]
198
+
199
+ def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
200
+ # self.remove_background() # TODO revisit
201
+ self.set_max_val(max_val, datasplit_type)
202
+ self.upperclip_data()
203
+
204
+ def upperclip_data(self):
205
+ for ch_idx, data in enumerate(self._data):
206
+ if self.max_val[ch_idx] is not None:
207
+ for idx in range(len(data)):
208
+ data[idx][data[idx] > self.max_val[ch_idx]] = self.max_val[ch_idx]
209
+
210
+ def compute_max_val(self):
211
+ # TODO add channelwise quantile ?
212
+ return [
213
+ max([np.quantile(im, self._quantile) for im in ch]) for ch in self._data
214
+ ]
215
+
216
+ def set_max_val(self, max_val, datasplit_type):
217
+ if max_val is None:
218
+ assert datasplit_type in [DataSplitType.Train, DataSplitType.All]
219
+ self.max_val = self.compute_max_val()
220
+ else:
221
+ assert max_val is not None
222
+ self.max_val = max_val
223
+
224
+ def get_max_val(self):
225
+ return self.max_val
226
+
227
+ def get_img_sz(self):
228
+ return self._img_sz
229
+
230
+ def get_num_frames(self):
231
+ """Returns the number of the longest channel."""
232
+ return max(self.idx_manager.total_grid_count()[0])
233
+
234
+ def reduce_data(
235
+ self,
236
+ t_list=None,
237
+ z_start=None,
238
+ z_end=None,
239
+ h_start=None,
240
+ h_end=None,
241
+ w_start=None,
242
+ w_end=None,
243
+ ):
244
+ raise NotImplementedError("Not implemented")
245
+
246
+ def get_idx_manager_shapes(
247
+ self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
248
+ ):
249
+ numC = len(self._data_shapes)
250
+ if self._3Ddata:
251
+ patch_shape = (1, self._depth3D, patch_size, patch_size)
252
+ if isinstance(grid_size, int):
253
+ grid_shape = (1, 1, grid_size, grid_size)
254
+ else:
255
+ assert len(grid_size) == 3
256
+ assert all(
257
+ [g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
258
+ ), f"Grid size {grid_size} must be less than patch size {patch_shape[1:-1]}"
259
+ grid_shape = (1, grid_size[0], grid_size[1], grid_size[2])
260
+ else:
261
+ assert isinstance(grid_size, int)
262
+ grid_shape = (1, grid_size, grid_size)
263
+ patch_shape = (1, patch_size, patch_size)
264
+
265
+ return patch_shape, grid_shape
266
+
267
+ def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
268
+ """
269
+ If one wants to change the image size on the go, then this can be used.
270
+ Args:
271
+ image_size: size of one patch
272
+ grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
273
+ """
274
+ # hacky way to deal with image shape from new conf
275
+ self._img_sz = image_size[-1] # TODO revisit!
276
+ self._grid_sz = grid_size
277
+ shapes = self._data_shapes
278
+
279
+ patch_shape, grid_shape = self.get_idx_manager_shapes(
280
+ self._img_sz, self._grid_sz
281
+ )
282
+ self.idx_manager = GridIndexManagerRef(
283
+ shapes, grid_shape, patch_shape, self._tiling_mode
284
+ )
285
+
286
+ def __len__(self):
287
+ # If channel length is not equal, return the longest
288
+ return max(self.idx_manager.total_grid_count()[0])
289
+
290
+ def _init_msg(
291
+ self,
292
+ ):
293
+ msg = (
294
+ f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
295
+ )
296
+ dim_sizes = [
297
+ self.idx_manager.get_individual_dim_grid_count(dim)
298
+ for dim in range(len(self._data.shape))
299
+ ]
300
+ dim_sizes = ",".join([str(x) for x in dim_sizes])
301
+ msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
302
+ msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
303
+ msg += f" TrimB:{self._tiling_mode}"
304
+ # msg += f' NormInp:{self._normalized_input}'
305
+ # msg += f' SingleNorm:{self._use_one_mu_std}'
306
+ msg += f" Rot:{self._enable_rotation}"
307
+ if self._flipz_3D:
308
+ msg += f" FlipZ:{self._flipz_3D}"
309
+
310
+ msg += f" RandCrop:{self._enable_random_cropping}"
311
+ msg += f" Channel:{self._num_channels}"
312
+ # msg += f' Q:{self._quantile}'
313
+ if self._input_is_sum:
314
+ msg += f" SummedInput:{self._input_is_sum}"
315
+
316
+ if self._empty_patch_replacement_enabled:
317
+ msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
318
+ if self._uncorrelated_channels:
319
+ msg += f" Uncorr:{self._uncorrelated_channels}"
320
+ if self._empty_patch_replacement_enabled:
321
+ msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
322
+ if self._background_quantile > 0.0:
323
+ msg += f" BckQ:{self._background_quantile}"
324
+
325
+ if self._start_alpha_arr is not None:
326
+ msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
327
+ return msg
328
+
329
+ def _crop_imgs(self, ch_idx: int, patch_idx: int, img: np.ndarray):
330
+ h, w = img.shape[-2:]
331
+ if self._img_sz is None:
332
+ return (
333
+ img,
334
+ {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False},
335
+ )
336
+
337
+ if self._enable_random_cropping:
338
+ # this parameter is ambiguous. It toggles between random/deterministic patching
339
+ patch_start_loc = self._get_random_hw(h, w)
340
+ if self._3Ddata:
341
+ patch_start_loc = (
342
+ np.random.choice(1 + img.shape[-3] - self._depth3D),
343
+ ) + patch_start_loc
344
+ else:
345
+ # Patch coordinates are calculated by the index manager.
346
+ patch_start_loc = self._get_deterministic_loc(ch_idx, patch_idx)
347
+ cropped_img = self._crop_flip_img(img, patch_start_loc, False, False)
348
+
349
+ return cropped_img
350
+
351
+ def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
352
+ if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
353
+ # In training, this is used.
354
+ # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
355
+ # The only benefit this if else loop provides is that it makes it easier to see what happens during training.
356
+ patch_end_loc = (
357
+ np.array(patch_start_loc, dtype=np.int32)
358
+ + self.idx_manager.patch_shape[1:-1]
359
+ )
360
+ if self._3Ddata:
361
+ z_start, h_start, w_start = patch_start_loc
362
+ z_end, h_end, w_end = patch_end_loc
363
+ new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
364
+ else:
365
+ h_start, w_start = patch_start_loc
366
+ h_end, w_end = patch_end_loc
367
+ new_img = img[..., h_start:h_end, w_start:w_end]
368
+
369
+ return new_img
370
+ else:
371
+ # During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame
372
+ # In these situations, we need some sort of padding. This is not needed in the LeftTop alignment.
373
+ return self._crop_img_with_padding(img, patch_start_loc)
374
+
375
+ def get_begin_end_padding(self, start_pos, end_pos, max_len):
376
+ """
377
+ The effect is that the image with size self._grid_sz is in the center of the patch with sufficient
378
+ padding on all four sides so that the final patch size is self._img_sz.
379
+ """
380
+ pad_start = 0
381
+ pad_end = 0
382
+ if start_pos < 0:
383
+ pad_start = -1 * start_pos
384
+
385
+ pad_end = max(0, end_pos - max_len)
386
+
387
+ return pad_start, pad_end
388
+
389
+ def _crop_img_with_padding(
390
+ self, img: np.ndarray, patch_start_loc, max_len_vals=None
391
+ ):
392
+ if max_len_vals is None:
393
+ max_len_vals = self.idx_manager.data_shape[1:-1]
394
+ patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
395
+ self.idx_manager.patch_shape[1:-1], dtype=int
396
+ )
397
+ boundary_crossed = []
398
+ valid_slice = []
399
+ padding = [[0, 0]]
400
+ for start_idx, end_idx, max_len in zip(
401
+ patch_start_loc, patch_end_loc, max_len_vals
402
+ ):
403
+ boundary_crossed.append(end_idx > max_len or start_idx < 0)
404
+ valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
405
+ pad = [0, 0]
406
+ if boundary_crossed[-1]:
407
+ pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
408
+ padding.append(pad)
409
+ # max() is needed since h_start could be negative.
410
+ if self._3Ddata:
411
+ new_img = img[
412
+ ...,
413
+ valid_slice[0][0] : valid_slice[0][1],
414
+ valid_slice[1][0] : valid_slice[1][1],
415
+ valid_slice[2][0] : valid_slice[2][1],
416
+ ]
417
+ else:
418
+ new_img = img[
419
+ ...,
420
+ valid_slice[0][0] : valid_slice[0][1],
421
+ valid_slice[1][0] : valid_slice[1][1],
422
+ ]
423
+
424
+ # print(np.array(padding).shape, img.shape, new_img.shape)
425
+ # print(padding)
426
+ if not np.all(padding == 0):
427
+ new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)
428
+
429
+ return new_img
430
+
431
+ def _crop_flip_img(
432
+ self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
433
+ ):
434
+ new_img = self._crop_img(img, patch_start_loc)
435
+ if h_flip:
436
+ new_img = new_img[..., ::-1, :]
437
+ if w_flip:
438
+ new_img = new_img[..., :, ::-1]
439
+
440
+ return new_img.astype(np.float32)
441
+
442
+ def _load_img(self, ch_idx: int, patch_idx: int) -> tuple[np.ndarray, np.ndarray]:
443
+ """
444
+ Returns the channels and also the respective noise channels.
445
+ """
446
+ patch_loc_list = self.idx_manager.get_patch_location_from_patch_idx(
447
+ ch_idx, patch_idx
448
+ )
449
+ # TODO we should be adding channel dim here probably
450
+ img = self._data[ch_idx][patch_loc_list[0]]
451
+ return img
452
+
453
+ def get_mean_std(self):
454
+ return self._mean, self._std
455
+
456
+ def set_mean_std(self, mean_val, std_val):
457
+ self._mean = mean_val
458
+ self._std = std_val
459
+
460
+ def normalize_target(self, target):
461
+ mean_dict, std_dict = self.get_mean_std()
462
+ mean_ = mean_dict["target"] # .squeeze(0)
463
+ std_ = std_dict["target"] # .squeeze(0)
464
+ return (target - mean_) / std_
465
+
466
+ def get_grid_size(self):
467
+ return self._grid_sz
468
+
469
+ def get_idx_manager(self):
470
+ return self.idx_manager
471
+
472
+ def per_side_overlap_pixelcount(self):
473
+ return (self._img_sz - self._grid_sz) // 2
474
+
475
+ def _get_deterministic_loc(self, ch_idx: int, patch_idx: int):
476
+ """
477
+ It returns the top-left corner of the patch corresponding to index.
478
+ """
479
+ loc_list = self.idx_manager.get_patch_location_from_patch_idx(ch_idx, patch_idx)
480
+ # last dim is channel. we need to take the third and the second last element.
481
+ return loc_list[2:]
482
+
483
+ @cache
484
+ def crop_probablities(self, ch_idx):
485
+ sizes = np.array([np.prod(x.shape) for x in self._data[ch_idx]])
486
+ return sizes / sizes.sum()
487
+
488
+ def sample_crop(self, ch_idx):
489
+ idx = None
490
+ count = 0
491
+ while idx is None:
492
+ count += 1
493
+ idx = np.random.choice(
494
+ len(self._data[ch_idx]), p=self.crop_probablities(ch_idx)
495
+ )
496
+ data = self._data[ch_idx][idx] # TODO no channel and S dim ?
497
+ # changed for ndim
498
+ if all(
499
+ d >= self._img_sz for d in data.shape[-2:]
500
+ ): # TODO dims were hardcoded
501
+ h = np.random.randint(0, data.shape[-2] - self._img_sz)
502
+ w = np.random.randint(0, data.shape[-1] - self._img_sz)
503
+
504
+ if len(data.shape) > 2 and not self._3Ddata:
505
+ s = np.random.randint(0, data.shape[0] - 1)
506
+ return data[s, h : h + self._img_sz, w : w + self._img_sz]
507
+ else:
508
+ return data[h : h + self._img_sz, w : w + self._img_sz]
509
+
510
+ elif count > 100:
511
+ raise ValueError("Cannot find a valid crop")
512
+ else:
513
+ idx = None
514
+
515
+ return None
516
+
517
+ def _l2(self, x):
518
+ return np.sqrt(np.mean(np.array(x) ** 2))
519
+
520
+ def compute_mean_std(self, allow_for_validation_data=False):
521
+ """
522
+ Note that we must compute this only for training data.
523
+ """
524
+ if self._3Ddata:
525
+ raise NotImplementedError("Not implemented for 3D data")
526
+
527
+ if self._input_is_sum:
528
+ mean_tar_dict = defaultdict(list)
529
+ std_tar_dict = defaultdict(list)
530
+ mean_inp = []
531
+ std_inp = []
532
+ for _ in range(30000):
533
+ crops = []
534
+ for ch_idx in range(len(self._data)):
535
+ crop = self.sample_crop(ch_idx)
536
+ mean_tar_dict[ch_idx].append(np.mean(crop))
537
+ std_tar_dict[ch_idx].append(np.std(crop))
538
+ crops.append(crop)
539
+
540
+ inp = 0
541
+ for img in crops:
542
+ inp += img
543
+
544
+ mean_inp.append(np.mean(inp))
545
+ std_inp.append(np.std(inp))
546
+
547
+ output_mean = defaultdict(list)
548
+ output_std = defaultdict(list)
549
+
550
+ NC = len(self._data)
551
+ for ch_idx in range(NC):
552
+ output_mean["target"].append(np.mean(mean_tar_dict[ch_idx]))
553
+ output_std["target"].append(self._l2(std_tar_dict[ch_idx]))
554
+
555
+ output_mean["target"] = np.array(output_mean["target"]).reshape(NC, 1, 1)
556
+ output_std["target"] = np.array(output_std["target"]).reshape(NC, 1, 1)
557
+
558
+ output_mean["input"] = np.array([np.mean(mean_inp)]).reshape(1, 1, 1)
559
+ output_std["input"] = np.array([self._l2(std_inp)]).reshape(1, 1, 1)
560
+ else:
561
+ raise NotImplementedError("Not implemented for non-summed input")
562
+
563
+ return dict(output_mean), dict(output_std)
564
+
565
+ def set_mean_std(self, mean_dict, std_dict):
566
+ self._data_mean = mean_dict
567
+ self._data_std = std_dict
568
+
569
+ def get_mean_std(self):
570
+ return self._data_mean, self._data_std
571
+
572
+ def _get_random_hw(self, h: int, w: int):
573
+ """
574
+ Random starting position for the crop for the img with index `index`.
575
+ """
576
+ if h != self._img_sz:
577
+ h_start = np.random.choice(h - self._img_sz)
578
+ w_start = np.random.choice(w - self._img_sz)
579
+ else:
580
+ h_start = 0
581
+ w_start = 0
582
+ return h_start, w_start
583
+
584
+ def replace_with_empty_patch(self, img_tuples):
585
+ """
586
+ Replaces the content of one of the channels with background
587
+ """
588
+ empty_index = self._empty_patch_fetcher.sample()
589
+ empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
590
+ assert (
591
+ len(empty_img_noise_tuples) == 0
592
+ ), "Noise is not supported with empty patch replacement"
593
+ final_img_tuples = []
594
+ for tuple_idx in range(len(img_tuples)):
595
+ if tuple_idx == self._empty_patch_replacement_channel_idx:
596
+ final_img_tuples.append(empty_img_tuples[tuple_idx])
597
+ else:
598
+ final_img_tuples.append(img_tuples[tuple_idx])
599
+ return tuple(final_img_tuples)
600
+
601
+ def get_mean_std_for_input(self):
602
+ mean, std = self.get_mean_std()
603
+ return mean["input"], std["input"]
604
+
605
+ def _compute_target(self, img_tuples, alpha):
606
+ if self._tar_idx_list is not None and isinstance(self._tar_idx_list, int):
607
+ target = img_tuples[self._tar_idx_list]
608
+ else:
609
+ if self._tar_idx_list is not None:
610
+ assert isinstance(self._tar_idx_list, list) or isinstance(
611
+ self._tar_idx_list, tuple
612
+ )
613
+ img_tuples = [img_tuples[i] for i in self._tar_idx_list]
614
+
615
+ target = np.stack(img_tuples, axis=0)
616
+ return target
617
+
618
+ def _compute_input_with_alpha(self, img_tuples, alpha_list):
619
+ # assert self._normalized_input is True, "normalization should happen here"
620
+ if self._input_idx is not None:
621
+ inp = img_tuples[self._input_idx]
622
+ else:
623
+ inp = 0
624
+ for alpha, img in zip(alpha_list, img_tuples):
625
+ inp += img * alpha
626
+
627
+ if self._normalized_input is False:
628
+ return inp.astype(np.float32)
629
+
630
+ mean, std = self.get_mean_std_for_input()
631
+ mean = mean.squeeze()
632
+ std = std.squeeze()
633
+ if mean.size == 1:
634
+ mean = mean.reshape(
635
+ 1,
636
+ )
637
+ std = std.reshape(
638
+ 1,
639
+ )
640
+
641
+ for i in range(len(mean)):
642
+ assert mean[0] == mean[i]
643
+ assert std[0] == std[i]
644
+
645
+ inp = (inp - mean[0]) / std[0]
646
+ return inp.astype(np.float32)
647
+
648
+ def _sample_alpha(self):
649
+ alpha_arr = []
650
+ for i in range(self._num_channels):
651
+ alpha_pos = np.random.rand()
652
+ alpha = self._start_alpha_arr[i] + alpha_pos * (
653
+ self._end_alpha_arr[i] - self._start_alpha_arr[i]
654
+ )
655
+ alpha_arr.append(alpha)
656
+ return alpha_arr
657
+
658
+ def _compute_input(self, img_tuples):
659
+ alpha = [1 / len(img_tuples) for _ in range(len(img_tuples))]
660
+ if self._start_alpha_arr is not None:
661
+ alpha = self._sample_alpha()
662
+
663
+ inp = self._compute_input_with_alpha(img_tuples, alpha)
664
+ if self._input_is_sum:
665
+ inp = len(img_tuples) * inp
666
+
667
+ # TODO instead we add channel here
668
+ if len(inp.shape) == 2 or (len(inp.shape) == 3 and self._3Ddata):
669
+ inp = inp[None, ...]
670
+
671
+ return inp, alpha
672
+
673
+ def _get_index_from_valid_target_logic(self, index):
674
+ if self._validtarget_rand_fract is not None:
675
+ if np.random.rand() < self._validtarget_rand_fract:
676
+ index = self._train_index_switcher.get_valid_target_index()
677
+ else:
678
+ index = self._train_index_switcher.get_invalid_target_index()
679
+ return index
680
+
681
+ def _rotate2D(self, img_tuples):
682
+ img_kwargs = {}
683
+ for i, img in enumerate(img_tuples):
684
+ for k in range(len(img)):
685
+ img_kwargs[f"img{i}_{k}"] = img[k]
686
+
687
+ keys = list(img_kwargs.keys())
688
+ self._rotation_transform.add_targets({k: "image" for k in keys})
689
+ rot_dic = self._rotation_transform(image=img_tuples[0][0], **img_kwargs)
690
+
691
+ rotated_img_tuples = []
692
+ for i, img in enumerate(img_tuples):
693
+ if len(img) == 1:
694
+ rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
695
+ else:
696
+ rotated_img_tuples.append(
697
+ np.concatenate(
698
+ [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
699
+ )
700
+ )
701
+
702
+ return rotated_img_tuples
703
+
704
+ def _rotate3D(self, img_tuples):
705
+ img_kwargs = {}
706
+ # random flip in z direction
707
+ flip_z = self._flipz_3D and np.random.rand() < 0.5
708
+ for i, img in enumerate(img_tuples):
709
+ for j in range(self._depth3D):
710
+ for k in range(len(img)):
711
+ if flip_z:
712
+ z_idx = self._depth3D - 1 - j
713
+ else:
714
+ z_idx = j
715
+ img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]
716
+
717
+ keys = list(img_kwargs.keys())
718
+ self._rotation_transform.add_targets({k: "image" for k in keys})
719
+ rot_dic = self._rotation_transform(image=img_tuples[0][0][0], **img_kwargs)
720
+ rotated_img_tuples = []
721
+ for i, img in enumerate(img_tuples):
722
+ if len(img) == 1:
723
+ rotated_img_tuples.append(
724
+ np.concatenate(
725
+ [
726
+ rot_dic[f"img{i}_{j}_0"][None, None]
727
+ for j in range(self._depth3D)
728
+ ],
729
+ axis=1,
730
+ )
731
+ )
732
+ else:
733
+ temp_arr = []
734
+ for k in range(len(img)):
735
+ temp_arr.append(
736
+ np.concatenate(
737
+ [
738
+ rot_dic[f"img{i}_{j}_{k}"][None, None]
739
+ for j in range(self._depth3D)
740
+ ],
741
+ axis=1,
742
+ )
743
+ )
744
+ rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
745
+
746
+ return rotated_img_tuples
747
+
748
+ def _rotate(self, img_tuples, noise_tuples):
749
+
750
+ if self._3Ddata:
751
+ return self._rotate3D(img_tuples, noise_tuples)
752
+ else:
753
+ return self._rotate2D(img_tuples, noise_tuples)
754
+
755
+ def _get_img(self, ch_idx: int, patch_idx: int):
756
+ """
757
+ Loads an image.
758
+ Crops the image such that cropped image has content.
759
+ """
760
+ img = self._load_img(ch_idx, patch_idx)
761
+ cropped_img = self._crop_imgs(ch_idx, patch_idx, img)
762
+ return cropped_img
763
+
764
+ def get_uncorrelated_img_tuples(self, index):
765
+ """
766
+ Content of channels like actin and nuclei is "correlated" in its
767
+ respective location; this function allows picking channels' content
768
+ from different patches of the image to make it "uncorrelated".
769
+ """
770
+ img_tuples = []
771
+ for ch_idx in range(len(self._data)):
772
+ if ch_idx == 0:
773
+ # dataset index becomes sample index because all channels have the same
774
+ # length
775
+ img_tuples.append(self._get_img(0, index))
776
+ else:
777
+ # get a random index from corresponding channel
778
+ sample_index = np.random.randint(
779
+ self.idx_manager.total_grid_count()[0][ch_idx]
780
+ )
781
+ img_tuples.append(self._get_img(ch_idx, sample_index))
782
+ return img_tuples
783
+
784
+ def __getitem__(
785
+ self, index: Union[int, tuple[int, int]]
786
+ ) -> tuple[np.ndarray, np.ndarray]:
787
+
788
+ # Uncorrelated channels means crops to create the input are taken from different
789
+ # spatial locations of the image.
790
+ if (
791
+ self._uncorrelated_channels
792
+ and np.random.rand() < self._uncorrelated_channel_probab
793
+ ):
794
+ input_tuples = self.get_uncorrelated_img_tuples(index)
795
+ else:
796
+ # 0 is the channel index, because in this case locations are the same for
797
+ # all channels
798
+ # tuple for compatibility with _compute_input. #TODO check
799
+ input_tuples = (self._get_img(0, index),)
800
+
801
+ if self._enable_rotation:
802
+ input_tuples = self._rotate(input_tuples)
803
+
804
+ # Weight the individual channels, typically alpha is fixed
805
+ inp, alpha = self._compute_input(input_tuples)
806
+
807
+ target = self._compute_target(input_tuples, alpha)
808
+ norm_target = self.normalize_target(target)
809
+
810
+ return inp, norm_target
811
+
812
+
813
+ class LCMultiChDloaderRef(MultiChDloaderRef):
814
+ def __init__(
815
+ self,
816
+ data_config: DatasetConfig,
817
+ fpath: str,
818
+ load_data_fn: Callable,
819
+ val_fraction=None,
820
+ test_fraction=None,
821
+ ):
822
+ self._padding_kwargs = (
823
+ data_config.padding_kwargs # mode=padding_mode, constant_values=constant_value
824
+ )
825
+ self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
826
+
827
+ super().__init__(
828
+ data_config,
829
+ fpath,
830
+ load_data_fn=load_data_fn,
831
+ val_fraction=val_fraction,
832
+ test_fraction=test_fraction,
833
+ )
834
+
835
+ if data_config.overlapping_padding_kwargs is not None:
836
+ assert (
837
+ self._padding_kwargs == data_config.overlapping_padding_kwargs
838
+ ), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
839
+ It should be so since we just use overlapping_padding_kwargs when it is not None"
840
+
841
+ else:
842
+ self._overlapping_padding_kwargs = data_config.padding_kwargs
843
+
844
+ self.multiscale_lowres_count = data_config.multiscale_lowres_count
845
+ assert self.multiscale_lowres_count is not None
846
+ self._scaled_data = [self._data]
847
+ self._scaled_noise_data = [self._noise_data]
848
+
849
+ assert (
850
+ isinstance(self.multiscale_lowres_count, int)
851
+ and self.multiscale_lowres_count >= 1
852
+ )
853
+ assert isinstance(self._padding_kwargs, dict)
854
+ assert "mode" in self._padding_kwargs
855
+
856
+ for _ in range(1, self.multiscale_lowres_count):
857
+ shape = self._scaled_data[-1].shape
858
+ assert len(shape) == 4
859
+ new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
860
+ ds_data = resize(
861
+ self._scaled_data[-1].astype(np.float32), new_shape
862
+ ).astype(self._scaled_data[-1].dtype)
863
+ # NOTE: These asserts are important. The resize method expects np.float32; otherwise, one gets weird results.
864
+ assert (
865
+ ds_data.max() / self._scaled_data[-1].max() < 5
866
+ ), "Downsampled image should not have very different values"
867
+ assert (
868
+ ds_data.max() / self._scaled_data[-1].max() > 0.2
869
+ ), "Downsampled image should not have very different values"
870
+
871
+ self._scaled_data.append(ds_data)
872
+ # do the same for noise
873
+ if self._noise_data is not None:
874
+ noise_data = resize(self._scaled_noise_data[-1], new_shape)
875
+ self._scaled_noise_data.append(noise_data)
876
+
877
+ def reduce_data(
878
+ self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
879
+ ):
880
+ assert t_list is not None
881
+ assert h_start is None
882
+ assert h_end is None
883
+ assert w_start is None
884
+ assert w_end is None
885
+
886
+ self._data = self._data[t_list].copy()
887
+ self._scaled_data = [
888
+ self._scaled_data[i][t_list].copy() for i in range(len(self._scaled_data))
889
+ ]
890
+
891
+ if self._noise_data is not None:
892
+ self._noise_data = self._noise_data[t_list].copy()
893
+ self._scaled_noise_data = [
894
+ self._scaled_noise_data[i][t_list].copy()
895
+ for i in range(len(self._scaled_noise_data))
896
+ ]
897
+
898
+ self.N = len(t_list)
899
+ # TODO where tf is self._img_sz defined?
900
+ self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
901
+ print(
902
+ f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
903
+ )
904
+
905
+ def _init_msg(self):
906
+ msg = super()._init_msg()
907
+ msg += f" Pad:{self._padding_kwargs}"
908
+ if self._uncorrelated_channels:
909
+ msg += f" UncorrChProbab:{self._uncorrelated_channel_probab}"
910
+ return msg
911
+
912
+ def _load_scaled_img(
913
+ self, scaled_index, index: Union[int, tuple[int, int]]
914
+ ) -> tuple[np.ndarray, np.ndarray]:
915
+ if isinstance(index, int):
916
+ idx = index
917
+ else:
918
+ idx, _ = index
919
+
920
+ # tidx = self.idx_manager.get_t(idx)
921
+ patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
922
+ nidx = patch_loc_list[0]
923
+
924
+ imgs = self._scaled_data[scaled_index][nidx]
925
+ imgs = tuple([imgs[None, ..., i] for i in range(imgs.shape[-1])])
926
+ if self._noise_data is not None:
927
+ noisedata = self._scaled_noise_data[scaled_index][nidx]
928
+ noise = tuple([noisedata[None, ..., i] for i in range(noisedata.shape[-1])])
929
+ factor = np.sqrt(2) if self._input_is_sum else 1.0
930
+ imgs = tuple([img + noise[0] * factor for img in imgs])
931
+ return imgs
932
+
933
+ def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
934
+ """
935
+ Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
936
+ the cropped image will be smaller than self._img_sz * self._img_sz
937
+ """
938
+ max_len_vals = list(self.idx_manager.data_shape[1:-1])
939
+ max_len_vals[-2:] = img.shape[-2:]
940
+ return self._crop_img_with_padding(
941
+ img, patch_start_loc, max_len_vals=max_len_vals
942
+ )
943
+
944
+ def _get_img(self, index: int):
945
+ """
946
+ Returns the primary patch along with low resolution patches centered on the primary patch.
947
+ """
948
+ # Noise_tuples is populated when there is synthetic noise in training
949
+ # Should have similar type of noise with the noise model
950
+ # Starting with microsplit, dump the noise and use it instead as an augmentation if necessary
951
+ img_tuples, noise_tuples = self._load_img(index)
952
+ assert self._img_sz is not None
953
+ h, w = img_tuples[0].shape[-2:]
954
+ if self._enable_random_cropping:
955
+ patch_start_loc = self._get_random_hw(h, w)
956
+ if self._3Ddata:
957
+ patch_start_loc = (
958
+ np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
959
+ ) + patch_start_loc
960
+ else:
961
+ patch_start_loc = self._get_deterministic_loc(index)
962
+
963
+ # LC logic is located here, the function crops the image of the highest resolution
964
+ cropped_img_tuples = [
965
+ self._crop_flip_img(img, patch_start_loc, False, False)
966
+ for img in img_tuples
967
+ ]
968
+ cropped_noise_tuples = [
969
+ self._crop_flip_img(noise, patch_start_loc, False, False)
970
+ for noise in noise_tuples
971
+ ]
972
+ patch_start_loc = list(patch_start_loc)
973
+ h_start, w_start = patch_start_loc[-2], patch_start_loc[-1]
974
+ h_center = h_start + self._img_sz // 2
975
+ w_center = w_start + self._img_sz // 2
976
+ allres_versions = {
977
+ i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
978
+ }
979
+ for scale_idx in range(1, self.multiscale_lowres_count):
980
+ # Returning the image of the lower resolution
981
+ scaled_img_tuples = self._load_scaled_img(scale_idx, index)
982
+
983
+ h_center = h_center // 2
984
+ w_center = w_center // 2
985
+
986
+ h_start = h_center - self._img_sz // 2
987
+ w_start = w_center - self._img_sz // 2
988
+ patch_start_loc[-2:] = [h_start, w_start]
989
+ scaled_cropped_img_tuples = [
990
+ self._crop_flip_img(img, patch_start_loc, False, False)
991
+ for img in scaled_img_tuples
992
+ ]
993
+ for ch_idx in range(len(img_tuples)):
994
+ allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
995
+
996
+ output_img_tuples = tuple(
997
+ [
998
+ np.concatenate(allres_versions[ch_idx])
999
+ for ch_idx in range(len(img_tuples))
1000
+ ]
1001
+ )
1002
+ return output_img_tuples, cropped_noise_tuples
1003
+
1004
+ def __getitem__(self, index: Union[int, tuple[int, int]]):
1005
+ img_tuples, noise_tuples = self._get_img(index)
1006
+ if self._uncorrelated_channels:
1007
+ assert (
1008
+ self._input_idx is None
1009
+ ), "Uncorrelated channels is not implemented when there is a separate input channel."
1010
+ if np.random.rand() < self._uncorrelated_channel_probab:
1011
+ img_tuples_new = [None] * len(img_tuples)
1012
+ img_tuples_new[0] = img_tuples[0]
1013
+ for i in range(1, len(img_tuples)):
1014
+ new_index = np.random.randint(len(self))
1015
+ img_tuples_tmp, _ = self._get_img(new_index)
1016
+ img_tuples_new[i] = img_tuples_tmp[i]
1017
+ img_tuples = img_tuples_new
1018
+
1019
+ if self._is_train:
1020
+ if self._empty_patch_replacement_enabled:
1021
+ if np.random.rand() < self._empty_patch_replacement_probab:
1022
+ img_tuples = self.replace_with_empty_patch(img_tuples)
1023
+
1024
+ if self._enable_rotation:
1025
+ img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
1026
+
1027
+ # add noise to input, if noise is present combine it with the image
1028
+ # factor is so that the computed input does not have too much noise, because it is the average of two Gaussians
1029
+ if len(noise_tuples) > 0:
1030
+ factor = np.sqrt(2) if self._input_is_sum else 1.0
1031
+ input_tuples = []
1032
+ for x in img_tuples:
1033
+ x = (
1034
+ x.copy()
1035
+ ) # to avoid changing the original image since it is later used for target
1036
+ # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
1037
+ x[0] = x[0] + noise_tuples[0] * factor
1038
+ input_tuples.append(x)
1039
+ else:
1040
+ input_tuples = img_tuples
1041
+
1042
+ # Compute the input by sum / average the channels
1043
+ # Alpha is an amount of weight which is applied to the channels when combining them
1044
+ # How to sample alpha is still under research
1045
+ inp, alpha = self._compute_input(input_tuples)
1046
+ target_tuples = [img[:1] for img in img_tuples]
1047
+ # add noise to target.
1048
+ if len(noise_tuples) >= 1:
1049
+ target_tuples = [
1050
+ x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
1051
+ ]
1052
+
1053
+ target = self._compute_target(target_tuples, alpha)
1054
+
1055
+ norm_target = self.normalize_target(target)
1056
+
1057
+ output = [inp, norm_target]
1058
+
1059
+ if self._return_alpha:
1060
+ output.append(alpha)
1061
+
1062
+ if isinstance(index, int):
1063
+ return tuple(output)
1064
+
1065
+ _, grid_size = index
1066
+ output.append(grid_size)
1067
+ return tuple(output)