careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,1054 @@
1
+ """
2
+ A place for Datasets and Dataloaders.
3
+ """
4
+
5
+ from typing import Tuple, Union
6
+
7
+ import numpy as np
8
+
9
+ from .data_utils import (
10
+ GridIndexManager,
11
+ IndexSwitcher,
12
+ get_train_val_data,
13
+ )
14
+ from .vae_data_config import VaeDatasetConfig, DataSplitType, GridAlignement
15
+
16
+
17
+ class MultiChDloader:
18
def __init__(
    self,
    data_config: VaeDatasetConfig,
    fpath: str,
    val_fraction: float = None,
    test_fraction: float = None,
):
    """Load one data split and set up patching, clipping and normalization state.

    Args:
        data_config: dataset configuration. Besides the split type it provides
            the data/patch geometry (``image_size``, ``grid_size``,
            ``depth3D``), channel selection (``num_channels``, ``input_idx``,
            ``target_idx_list``, ``input_is_sum``), grid alignment and padding
            settings, ``max_val`` and the augmentation switches.
        fpath: path forwarded to ``get_train_val_data`` via ``load_data``.
        val_fraction: forwarded to ``get_train_val_data``.
        test_fraction: forwarded to ``get_train_val_data``.

    Raises:
        NotImplementedError: if empty-patch replacement (training split only)
            or rotation augmentation is requested; neither is available here.
        AssertionError: if the grid alignment / padding / boundary trimming
            settings are inconsistent, or uncorrelated channels are requested
            without random cropping.
    """
    self._data_type = data_config.data_type
    self._fpath = fpath
    self._data = self.N = self._noise_data = None
    self.Z = 1
    self._trim_boundary = data_config.trim_boundary

    # Hardcoded params, not included in the config file.
    # By default, if noise is present, it is added to the input and target.
    self._disable_noise = False  # to add synthetic noise
    self._poisson_noise_factor = None
    self._train_index_switcher = None
    self._depth3D = data_config.depth3D
    # NOTE: Input is the sum of the different channels, not their average.
    self._input_is_sum = data_config.input_is_sum
    self._num_channels = data_config.num_channels
    self._input_idx = data_config.input_idx
    self._tar_idx_list = data_config.target_idx_list

    # Every split currently uses all of its data (load_data only subsamples
    # when this is < 1.0), and random valid-target switching is disabled.
    self._datausage_fraction = 1.0
    self._validtarget_rand_fract = None

    self.load_data(
        data_config,
        data_config.datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=data_config.allow_generation,
    )
    self._normalized_input = data_config.normalized_input
    self._quantile = 1.0
    self._channelwise_quantile = False
    self._background_quantile = 0.0
    self._clip_background_noise_to_zero = False
    self._skip_normalization_using_mean = False
    self._empty_patch_replacement_enabled = False

    self._background_values = None

    self._grid_alignment = data_config.grid_alignment
    self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
    if self._grid_alignment == GridAlignement.LeftTop:
        assert (
            self._overlapping_padding_kwargs is None
            or data_config.multiscale_lowres_count is not None
        ), "Padding is not used with this alignement style"
    elif self._grid_alignment == GridAlignement.Center:
        assert (
            self._overlapping_padding_kwargs is not None
        ), "With Center grid alignment, padding is needed."
    if self._trim_boundary:
        if (
            self._overlapping_padding_kwargs is None
            or data_config.multiscale_lowres_count is not None
        ):
            # raise warning
            print("Padding is not used with this alignement style")
    else:
        assert (
            self._overlapping_padding_kwargs is not None
        ), "When not trimming boudnary, padding is needed."

    self._is_train = data_config.datasplit_type == DataSplitType.Train

    # input = alpha * ch1 + (1-alpha)*ch2.
    # alpha is sampled randomly between these two extremes
    self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None

    self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
    if self._is_train:
        self._start_alpha_arr = data_config.start_alpha
        self._end_alpha_arr = data_config.end_alpha

    # The previous check, `"grid_size" in data_config`, only tests key
    # membership on dict-like configs. If VaeDatasetConfig is a pydantic model
    # (NOTE(review): confirm), iterating it yields (name, value) pairs, so that
    # test was always False and grid_size was silently ignored. getattr works
    # for both dict-like and attribute-style configs.
    grid_size = getattr(data_config, "grid_size", None)
    if grid_size is None:
        grid_size = data_config.image_size
    self.set_img_sz(data_config.image_size, grid_size)

    if self._is_train and self._validtarget_rand_fract is not None:
        self._train_index_switcher = IndexSwitcher(
            self.idx_manager, data_config, self._img_sz
        )

    self._return_alpha = False
    self._return_index = False

    self._empty_patch_replacement_enabled = (
        data_config.empty_patch_replacement_enabled and self._is_train
    )
    if self._empty_patch_replacement_enabled:
        self._empty_patch_replacement_channel_idx = (
            data_config.empty_patch_replacement_channel_idx
        )
        self._empty_patch_replacement_probab = (
            data_config.empty_patch_replacement_probab
        )
        # EmptyPatchFetcher is not imported anywhere in this module, so this
        # branch used to fail with a NameError. Fail explicitly instead.
        raise NotImplementedError(
            "Empty patch replacement is not supported yet: EmptyPatchFetcher "
            "is not available in this module."
        )

    self.rm_bkground_set_max_val_and_upperclip_data(
        data_config.max_val, data_config.datasplit_type
    )

    # For overlapping dloader, image_size and repeat_factors are not related.
    # hence a different function.

    self._mean = None
    self._std = None
    self._use_one_mu_std = data_config.use_one_mu_std
    # Hardcoded
    self._target_separate_normalization = True

    self._enable_rotation = data_config.enable_rotation_aug
    self._enable_random_cropping = data_config.enable_random_cropping
    self._uncorrelated_channels = (
        data_config.uncorrelated_channels and self._is_train
    )
    assert self._is_train or self._uncorrelated_channels is False
    assert (
        self._enable_random_cropping is True or self._uncorrelated_channels is False
    )

    self._rotation_transform = None
    if self._enable_rotation:
        # The `A.Compose(...)` line that used to follow this raise was
        # unreachable (and `A` is never imported), so it has been removed.
        raise NotImplementedError(
            "Augmentation by means of rotation is not supported yet."
        )
190
def disable_noise(self):
    """Stop adding the pre-sampled synthetic noise to loaded patches.

    Raises:
        AssertionError: if Poisson noise was applied, since it is baked into
            the data array and cannot be switched off.
    """
    poisson_applied = self._poisson_noise_factor is not None
    assert (
        not poisson_applied
    ), "This is not supported. Poisson noise is added to the data itself and so the noise cannot be disabled."
    self._disable_noise = True
196
def enable_noise(self):
    """Resume adding synthetic noise to loaded patches (undoes ``disable_noise``)."""
    noise_disabled = False
    self._disable_noise = noise_disabled
199
def get_data_shape(self):
    """Return the shape tuple of the loaded data array."""
    data = self._data
    return data.shape
202
def load_data(
    self,
    data_config,
    datasplit_type,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
):
    """Load the requested split into ``self._data`` and apply synthetic noise.

    Steps: read the split with ``get_train_val_data``; if
    ``self._datausage_fraction < 1`` keep only a leading subset of frames;
    optionally apply Poisson noise to the data in place and/or pre-sample
    Gaussian noise into ``self._noise_data``; finally record whether the data
    is 5D (and its Z size) and validate the channel count.

    Args:
        data_config: dataset configuration; reads ``poisson_noise_factor``,
            ``enable_gaussian_noise``, ``synthetic_gaussian_scale`` and
            ``input_has_dependant_noise``, and is forwarded to
            ``get_train_val_data``.
        datasplit_type: which split to load (forwarded).
        val_fraction: forwarded to ``get_train_val_data``.
        test_fraction: forwarded to ``get_train_val_data``.
        allow_generation: forwarded to ``get_train_val_data``.

    Raises:
        AssertionError: if ``depth3D > 1`` but the data is not 5D, or if the
            number of channels differs from ``self._num_channels``.
    """
    self._data = get_train_val_data(
        data_config,
        self._fpath,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )

    old_shape = self._data.shape
    if self._datausage_fraction < 1.0:
        framepixelcount = np.prod(self._data.shape[1:3])
        pixelcount = int(
            len(self._data) * framepixelcount * self._datausage_fraction
        )
        frame_count = int(np.ceil(pixelcount / framepixelcount))
        last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size(
            self._data.shape[:3], self._datausage_fraction
        )
        self._data = self._data[:frame_count].copy()
        if frame_count == 1:
            # With a single frame left, shrink it spatially instead.
            self._data = self._data[
                :, :last_frame_reduced_size, :last_frame_reduced_size
            ].copy()
        print(
            f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}"
        )

    msg = ""
    # Guard against an unset factor: `None > 0` raises TypeError in Python 3.
    poisson_noise_factor = data_config.poisson_noise_factor
    if poisson_noise_factor is not None and poisson_noise_factor > 0:
        self._poisson_noise_factor = poisson_noise_factor
        msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
        # Poisson noise replaces the data itself, so it cannot be undone later
        # (see disable_noise).
        self._data = (
            np.random.poisson(self._data / self._poisson_noise_factor)
            * self._poisson_noise_factor
        )

    if data_config.enable_gaussian_noise:
        synthetic_scale = data_config.synthetic_gaussian_scale
        msg += f"Adding Gaussian noise with scale {synthetic_scale}"
        # One extra channel: 0 => noise for input, 1: => noise for all targets.
        shape = self._data.shape
        self._noise_data = np.random.normal(
            0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
        )
        if data_config.input_has_dependant_noise:
            msg += ". Moreover, input has dependent noise"
            self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
    # Only report when some noise was actually added (avoids blank lines).
    if msg:
        print(msg)

    self._5Ddata = len(self._data.shape) == 5
    if self._5Ddata:
        self.Z = self._data.shape[1]

    if self._depth3D > 1:
        assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"

    assert (
        self._data.shape[-1] == self._num_channels
    ), "Number of channels in data and config do not match."
271
def save_background(self, channel_idx, frame_idx, background_value):
    """Record the background value subtracted from one (frame, channel) pair.

    ``self._background_values`` is indexed as ``[frame, channel]``.
    """
    key = (frame_idx, channel_idx)
    self._background_values[key] = background_value
274
def get_background(self, channel_idx, frame_idx):
    """Return the background value stored for one (frame, channel) pair."""
    key = (frame_idx, channel_idx)
    return self._background_values[key]
277
def remove_background(self):
    """Subtract a per-frame, per-channel quantile from ``self._data`` in place.

    ``self._background_values`` is first reset to zeros of shape
    ``(n_frames, n_channels)``. With ``self._background_quantile == 0`` nothing
    else happens. Otherwise, for every (frame, channel) the quantile is
    computed, truncated to an int, recorded via ``save_background`` and
    subtracted. Negative results are clipped to zero when
    ``self._clip_background_noise_to_zero`` is set.

    Raises:
        AssertionError: if clipping is requested with a zero quantile, or a
            computed quantile has magnitude <= 20 (truncation to int would be
            meaningless).
    """
    self._background_values = np.zeros((self._data.shape[0], self._data.shape[-1]))

    if self._background_quantile == 0.0:
        assert (
            self._clip_background_noise_to_zero is False
        ), "This operation currently happens later in this function."
        return

    if np.issubdtype(self._data.dtype, np.unsignedinteger):
        # Subtracting from unsigned integers wraps around instead of going
        # negative. Only uint16 used to be handled; widen every unsigned type
        # to a signed one that can hold its values.
        signed_dtype = np.int32 if self._data.dtype.itemsize <= 2 else np.int64
        self._data = self._data.astype(signed_dtype)

    for ch in range(self._data.shape[-1]):
        for idx in range(self._data.shape[0]):
            qval = np.quantile(self._data[idx, ..., ch], self._background_quantile)
            assert (
                np.abs(qval) > 20
            ), "We are truncating the qval to an integer which will only make sense if it is large enough"
            # NOTE: Here, there can be an issue if you work with normalized data
            qval = int(qval)
            self.save_background(ch, idx, qval)
            self._data[idx, ..., ch] -= qval

    if self._clip_background_noise_to_zero:
        self._data[self._data < 0] = 0
305
def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
    """Prepare intensities: remove background, pick the max value, clip to it.

    The three steps must run in this order, because the max value is computed
    on background-subtracted data and clipping uses that max value.
    """
    self.remove_background()
    self.set_max_val(max_val, datasplit_type)
    self.upperclip_data()
310
def upperclip_data(self):
    """Clip ``self._data`` from above at ``self.max_val``, in place.

    ``self.max_val`` is either a scalar applied to every channel, or a
    per-channel sequence (list, tuple or 1-D array) whose length equals the
    size of the last (channel) axis.

    Raises:
        AssertionError: if a per-channel sequence has the wrong length.
    """
    max_val = self.max_val
    if np.ndim(max_val) > 0:
        # Per-channel thresholds. Only `list` used to be recognised; a tuple or
        # array fell through to the scalar branch and failed there.
        n_channels = self._data.shape[-1]
        assert n_channels == len(max_val)
        for ch, ch_max in enumerate(max_val):
            # Basic indexing returns a view, so this writes into self._data.
            ch_data = self._data[..., ch]
            ch_data[ch_data > ch_max] = ch_max
    else:
        self._data[self._data > max_val] = max_val
322
def compute_max_val(self):
    """Return the ``self._quantile`` quantile of the data.

    Returns:
        A single value over all data, or (when ``self._channelwise_quantile``
        is set) a list with one value per channel of the last axis.
    """
    if not self._channelwise_quantile:
        return np.quantile(self._data, self._quantile)
    n_channels = self._data.shape[-1]
    return [np.quantile(self._data[..., c], self._quantile) for c in range(n_channels)]
332
def set_max_val(self, max_val, datasplit_type):
    """Set ``self.max_val``, computing it from the data when not provided.

    Raises:
        AssertionError: if ``max_val`` is None for a split other than the
            training split (only training data may define the threshold).
    """
    if max_val is not None:
        self.max_val = max_val
        return
    assert datasplit_type == DataSplitType.Train
    self.max_val = self.compute_max_val()
341
def get_max_val(self):
    """Return the clipping threshold set by ``set_max_val``."""
    threshold = self.max_val
    return threshold
344
def get_img_sz(self):
    """Return the patch size currently configured via ``set_img_sz``."""
    patch_size = self._img_sz
    return patch_size
347
def get_num_frames(self):
    """Return the number of frames (size of the first data axis)."""
    n_frames, *_ = self._data.shape
    return n_frames
350
def reduce_data(
    self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
    """Keep only selected frames and a spatial window of the (4D) data.

    The same selection is applied to ``self._noise_data`` when present, and
    the index manager is rebuilt for the new shape.

    Args:
        t_list: frame indices to keep; defaults to all frames.
        h_start, h_end: row range to keep; defaults to the full height.
        w_start, w_end: column range to keep; defaults to the full width.

    Raises:
        AssertionError: for 5D (3D) data, which is not supported.
    """
    assert not self._5Ddata, "This function is not supported for 3D data."
    n_frames, height, width = self._data.shape[:3]
    frames = list(range(n_frames)) if t_list is None else t_list
    h_lo = 0 if h_start is None else h_start
    h_hi = height if h_end is None else h_end
    w_lo = 0 if w_start is None else w_start
    w_hi = width if w_end is None else w_end
    window = (frames, slice(h_lo, h_hi), slice(w_lo, w_hi), slice(None))

    self._data = self._data[window].copy()
    if self._noise_data is not None:
        self._noise_data = self._noise_data[window].copy()

    self.set_img_sz(self._img_sz, self._grid_sz)
    print(
        f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
    )
376
def get_idx_manager_shapes(self, patch_size: int, grid_size: int):
    """Build the patch and grid shapes handed to the grid index manager.

    Both shapes follow the data layout: ``(1, H, W, C)`` for 4D data and
    ``(1, Z, H, W, C)`` for 5D data. In 5D the patch spans ``depth3D`` slices
    and the grid steps one slice at a time.

    Returns:
        ``(patch_shape, grid_shape)``.
    """
    n_channels = self._data.shape[-1]
    spatial_patch = (patch_size, patch_size)
    spatial_grid = (grid_size, grid_size)
    if not self._5Ddata:
        return (1, *spatial_patch, n_channels), (1, *spatial_grid, n_channels)
    return (
        (1, self._depth3D, *spatial_patch, n_channels),
        (1, 1, *spatial_grid, n_channels),
    )
387
def set_img_sz(self, image_size, grid_size):
    """Change the patch/grid geometry and rebuild the grid index manager.

    Args:
        image_size: side length of one patch.
        grid_size: the frame is split into square grid cells of this size; the
            patch returned for a cell is centred on it and has size
            ``image_size``.
    """
    self._img_sz, self._grid_sz = image_size, grid_size
    patch_shape, grid_shape = self.get_idx_manager_shapes(image_size, grid_size)
    self.idx_manager = GridIndexManager(
        self._data.shape, grid_shape, patch_shape, self._trim_boundary
    )
407
def __len__(self):
    """Number of patches in the dataset: the index manager's total grid count."""
    manager = self.idx_manager
    return manager.total_grid_count()
412
def set_repeat_factor(self):
    """Store the number of grid cells per frame in ``self._repeat_factor``.

    The cell size is the grid size when it is larger than 1, otherwise the
    patch size.
    """
    cell = self._grid_sz if self._grid_sz > 1 else self._img_sz
    n_rows = self.idx_manager.grid_rows(cell)
    n_cols = self.idx_manager.grid_cols(cell)
    self._repeat_factor = n_rows * n_cols
422
def _init_msg(
    self,
):
    """Build a one-line summary of the dataset configuration.

    Optional fields (summed input, empty-patch replacement, uncorrelated
    channels, background quantile, alpha range) are only appended when they
    are enabled or set.

    Returns:
        The summary string.
    """
    msg = (
        f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
    )
    dim_sizes = ",".join(
        str(self.idx_manager.get_individual_dim_grid_count(dim))
        for dim in range(len(self._data.shape))
    )
    msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
    # A separator/label was missing here, so the total count was glued onto
    # the previous value (e.g. "NumPatchPerN:None12").
    msg += f" Total:{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
    msg += f" TrimB:{self._trim_boundary}"
    msg += f" Rot:{self._enable_rotation}"
    msg += f" RandCrop:{self._enable_random_cropping}"
    msg += f" Channel:{self._num_channels}"
    if self._input_is_sum:
        msg += f" SummedInput:{self._input_is_sum}"

    if self._empty_patch_replacement_enabled:
        msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
    if self._uncorrelated_channels:
        msg += f" Uncorr:{self._uncorrelated_channels}"
    if self._empty_patch_replacement_enabled:
        msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
    if self._background_quantile > 0.0:
        msg += f" BckQ:{self._background_quantile}"

    if self._start_alpha_arr is not None:
        msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
    return msg
458
def _crop_imgs(self, index, *img_tuples: np.ndarray):
    """Crop every image in ``img_tuples`` at one shared patch location.

    With random cropping enabled the location comes from ``_get_random_hw``
    (plus a random z start for 5D data); otherwise it is the deterministic
    location for ``index``. Images are not flipped.

    Returns:
        The cropped images followed by a metadata dict. When ``self._img_sz``
        is None the images are returned uncropped and the dict also carries
        the full ``"h"``/``"w"`` ranges.
    """
    h, w = img_tuples[0].shape[-2:]
    if self._img_sz is None:
        return (
            *img_tuples,
            {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False},
        )

    if self._enable_random_cropping:
        patch_start_loc = self._get_random_hw(h, w)
        if self._5Ddata:
            # Valid z starts are 0 .. Z - depth3D inclusive, i.e. Z - depth3D + 1
            # choices. Sampling from range(Z - depth3D) skipped the last slab
            # and raised ValueError when Z == depth3D.
            n_z_starts = img_tuples[0].shape[-3] - self._depth3D + 1
            patch_start_loc = (np.random.choice(n_z_starts),) + patch_start_loc
    else:
        patch_start_loc = self._get_deterministic_loc(index)

    cropped_imgs = [
        self._crop_flip_img(img, patch_start_loc, False, False) for img in img_tuples
    ]
    return (*cropped_imgs, {"hflip": False, "wflip": False})
488
def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
    """Cut the patch starting at ``patch_start_loc`` out of ``img``.

    With boundary trimming (the training path) every patch lies inside the
    frame, so plain slicing suffices. Otherwise (the evaluation path) the patch
    may start at a negative position or extend past the frame, and the padded
    cropper is used instead.
    """
    if not self._trim_boundary:
        return self._crop_img_with_padding(img, patch_start_loc)

    patch_end_loc = (
        np.array(patch_start_loc, dtype=np.int32)
        + self.idx_manager.patch_shape[1:-1]
    )
    if self._5Ddata:
        (z0, h0, w0), (z1, h1, w1) = patch_start_loc, patch_end_loc
        return img[..., z0:z1, h0:h1, w0:w1]
    (h0, w0), (h1, w1) = patch_start_loc, patch_end_loc
    return img[..., h0:h1, w0:w1]
512
def get_begin_end_padding(self, start_pos, end_pos, max_len):
    """Return the padding needed before and after a range along one axis.

    The range ``[start_pos, end_pos)`` may start before 0 or end past
    ``max_len``. The returned ``(pad_start, pad_end)`` counts the out-of-frame
    pixels on each side, so that a patch centred on a grid cell can be padded
    back to full patch size.
    """
    pad_start = -start_pos if start_pos < 0 else 0
    pad_end = max(0, end_pos - max_len)
    return pad_start, pad_end
526
def _crop_img_with_padding(
    self, img: np.ndarray, patch_start_loc, max_len_vals=None
):
    """Crop a patch that may extend past the frame, padding the missing part.

    The in-frame part of the patch is sliced out of ``img``. It is then padded
    with ``np.pad(**self._overlapping_padding_kwargs)`` only if the patch
    actually crosses a frame boundary. The leading axis of ``img`` is never
    padded.

    Args:
        img: image whose trailing axes are the spatial axes (Z, H, W or H, W).
        patch_start_loc: start of the patch along each spatial axis; entries
            may be negative.
        max_len_vals: frame extent per spatial axis; defaults to the index
            manager's data shape without its first and last entries.

    Returns:
        The cropped, and if necessary padded, patch.
    """
    if max_len_vals is None:
        max_len_vals = self.idx_manager.data_shape[1:-1]
    patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
        self.idx_manager.patch_shape[1:-1], dtype=int
    )
    valid_slices = []
    padding = [[0, 0]]
    for start_idx, end_idx, max_len in zip(
        patch_start_loc, patch_end_loc, max_len_vals
    ):
        # max()/min() keep the slice inside the frame; start_idx can be negative.
        valid_slices.append(slice(max(0, start_idx), min(max_len, end_idx)))
        if end_idx > max_len or start_idx < 0:
            padding.append(list(self.get_begin_end_padding(start_idx, end_idx, max_len)))
        else:
            padding.append([0, 0])
    new_img = img[(Ellipsis, *valid_slices)]

    # `padding` is a Python list, so the former `padding == 0` was always
    # False and np.pad ran unconditionally. That crashed on `**None` when no
    # padding kwargs were configured, even for in-frame patches.
    if np.any(np.asarray(padding) != 0):
        new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)

    return new_img
568
def _crop_flip_img(
    self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
):
    """Crop a patch, optionally mirror it along H and/or W, and cast to float32.

    H and W are taken to be the last two axes of the cropped patch.
    """
    patch = self._crop_img(img, patch_start_loc)
    flip_axes = tuple(axis for axis, flag in ((-2, h_flip), (-1, w_flip)) if flag)
    if flip_axes:
        patch = np.flip(patch, axis=flip_axes)
    return patch.astype(np.float32)
579
def _load_img(
    self, index: Union[int, Tuple[int, int]]
) -> Tuple[np.ndarray, np.ndarray]:
    """Load all channels of the frame that contains patch ``index``.

    Args:
        index: a dataset index (any Python or NumPy integer), or a tuple whose
            first element is the dataset index.

    Returns:
        ``(images, noise)``: ``images`` holds one array per data channel, each
        with a new leading axis of size 1. ``noise`` holds the matching arrays
        for every synthetic-noise channel, or is empty when there is no noise
        or noise is disabled.
    """
    # Accept any integer scalar. The previous check recognised only `int` and
    # `np.int64`, so e.g. `np.int32` fell through to `index[0]` and failed.
    idx = index if isinstance(index, (int, np.integer)) else index[0]

    patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
    frame_idx = patch_loc_list[0]
    imgs = self._data[frame_idx]
    loaded_imgs = tuple(imgs[None, ..., ch] for ch in range(imgs.shape[-1]))

    noise = ()
    if self._noise_data is not None and not self._disable_noise:
        noise_frame = self._noise_data[frame_idx]
        noise = tuple(
            noise_frame[None, ..., ch] for ch in range(self._noise_data.shape[-1])
        )
    return loaded_imgs, noise
609
def get_mean_std(self):
    """Return the stored normalization statistics as ``(mean, std)``."""
    stats = (self._mean, self._std)
    return stats
612
def set_mean_std(self, mean_val, std_val):
    """Store the normalization statistics later returned by ``get_mean_std``."""
    self._mean, self._std = mean_val, std_val
616
def normalize_img(self, *img_tuples):
    """Standardise each image with the target statistics of its channel.

    The i-th image is normalised with the i-th entry of the squeezed
    ``"target"`` mean and std returned by ``get_mean_std``.

    Returns:
        A tuple of normalised images, in the input order.
    """
    mean_dict, std_dict = self.get_mean_std()
    target_mean = mean_dict["target"].squeeze()
    target_std = std_dict["target"].squeeze()
    return tuple(
        (img - target_mean[i]) / target_std[i] for i, img in enumerate(img_tuples)
    )
628
def get_grid_size(self):
    """Return the grid cell size configured via ``set_img_sz``."""
    grid_size = self._grid_sz
    return grid_size
631
def get_idx_manager(self):
    """Return the grid index manager built by ``set_img_sz``."""
    manager = self.idx_manager
    return manager
634
def per_side_overlap_pixelcount(self):
    """Return the pixels of context on each side of the central grid cell.

    Computed as half (floored) of the difference between patch and grid size.
    """
    total_overlap = self._img_sz - self._grid_sz
    return total_overlap // 2
637
+ # def on_boundary(self, cur_loc, frame_size):
638
+ # return cur_loc + self._img_sz > frame_size or cur_loc < 0
639
+
640
def _get_deterministic_loc(self, index: int):
    """Return the top-left (start) corner of the patch for dataset ``index``.

    The index manager's location has the frame first and the channel last;
    only the spatial entries in between are returned.
    """
    location = self.idx_manager.get_patch_location_from_dataset_idx(index)
    spatial_start = location[1:-1]
    return spatial_start
648
def compute_individual_mean_std(self):
    """Compute the mean and standard deviation of each data channel.

    Channels are reduced one at a time rather than in one call over all axes,
    to avoid numpy/numpy#8869 with large arrays. When synthetic noise exists,
    the std of channel ``c`` is taken over data plus noise channel ``c + 1``.
    The mean is 0 when ``self._skip_normalization_using_mean`` is set.

    Returns:
        ``(mean, std)`` arrays shaped ``(1, C, 1, 1)``, or ``(1, C, 1, 1, 1)``
        for 5D data, ready for broadcasting.
    """
    n_channels = self._data.shape[-1]

    def channel_stats(c):
        channel = self._data[..., c]
        mu = 0.0 if self._skip_normalization_using_mean else channel.mean()
        if self._noise_data is None:
            sigma = channel.std()
        else:
            sigma = (channel + self._noise_data[..., c + 1]).std()
        return mu, sigma

    stats = [channel_stats(c) for c in range(n_channels)]
    mean = np.array([mu for mu, _ in stats])
    std = np.array([sigma for _, sigma in stats])

    # NOTE: ideally the 5D layout should only be used when the model expects 3D data.
    n_trailing = 3 if self._5Ddata else 2
    target_shape = (1, -1) + (1,) * n_trailing
    return mean.reshape(target_shape), std.reshape(target_shape)
679
def compute_mean_std(self, allow_for_validation_data=False):
    """Compute the normalization statistics for the input and the target.

    Note that we must compute this only for training data.

    Three cases are handled:

    * ``self._input_idx`` is set: the input and target statistics are the
      per-channel values from ``compute_individual_mean_std`` taken at
      ``self._input_idx`` and ``self._tar_idx_list`` respectively.
    * ``self._input_is_sum``: one input statistic whose mean is the sum of
      the per-channel means and whose std is the L2 norm of the per-channel
      stds (presumably the std of a sum of independent channels), repeated
      once per channel.
    * otherwise: the global mean/std over all of ``self._data`` (the std is
      taken over ``data + noise[..., 1:]`` when ``self._noise_data`` is set),
      repeated once per channel.

    In the last two cases the mean is zeroed when
    ``self._skip_normalization_using_mean`` is set. The target statistics are
    the per-channel values from ``compute_individual_mean_std`` when
    ``self._target_separate_normalization`` is set; otherwise they are the
    same arrays as the input statistics.

    Parameters
    ----------
    allow_for_validation_data : bool, optional
        Skip the training-only check on ``self._is_train``. Default False.

    Returns
    -------
    tuple of dict
        ``(mean_dict, std_dict)``, each with keys ``"input"`` and ``"target"``.

    Raises
    ------
    AssertionError
        On unsupported configurations: non-training data without
        ``allow_for_validation_data``, ``self._use_one_mu_std`` not True,
        noise combined with ``_input_idx`` or ``_input_is_sum``, and so on.
    """
    assert (
        self._is_train is True or allow_for_validation_data
    ), "This is just allowed for training data"
    assert self._use_one_mu_std is True, "This is the only supported case"

    # Case 1: a dedicated input channel; statistics come straight from the
    # per-channel computation.
    if self._input_idx is not None:
        assert (
            self._tar_idx_list is not None
        ), "tar_idx_list must be set if input_idx is set."
        assert self._noise_data is None, "This is not supported with noise"
        assert (
            self._target_separate_normalization is True
        ), "This is not supported with target_separate_normalization=False"

        mean, std = self.compute_individual_mean_std()
        mean_dict = {
            # Slice (not integer index) so the channel axis is kept.
            "input": mean[:, self._input_idx : self._input_idx + 1],
            "target": mean[:, self._tar_idx_list],
        }
        std_dict = {
            "input": std[:, self._input_idx : self._input_idx + 1],
            "target": std[:, self._tar_idx_list],
        }
        return mean_dict, std_dict

    if self._input_is_sum:
        # Case 2: the input is a sum of channels.
        assert self._noise_data is None, "This is not supported with noise"
        mean = [
            np.mean(self._data[..., k : k + 1], keepdims=True)
            for k in range(self._num_channels)
        ]
        # ``[0]`` drops the axis introduced by stacking the per-channel list.
        mean = np.sum(mean, keepdims=True)[0]
        std = np.linalg.norm(
            [
                np.std(self._data[..., k : k + 1], keepdims=True)
                for k in range(self._num_channels)
            ],
            keepdims=True,
        )[0]
    else:
        # Case 3: one global statistic over all channels.
        mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1)
        if self._noise_data is not None:
            # Noise channel 0 is excluded; channels 1.. line up with the data.
            std = np.std(
                self._data + self._noise_data[..., 1:], keepdims=True
            ).reshape(1, 1, 1, 1)
        else:
            std = np.std(self._data, keepdims=True).reshape(1, 1, 1, 1)

    # Repeat the single statistic along axis 1, once per channel.
    mean = np.repeat(mean, self._num_channels, axis=1)
    std = np.repeat(std, self._num_channels, axis=1)

    if self._skip_normalization_using_mean:
        mean = np.zeros_like(mean)

    # NOTE(review): in the ``_input_is_sum`` branch the arrays keep the rank
    # of ``self._data``; if the data is itself 5-D there, this extra axis
    # would yield 6-D statistics — confirm that combination is not used.
    if self._5Ddata:
        mean = mean[:, :, None]
        std = std[:, :, None]

    mean_dict = {"input": mean}  # "target" is filled in below
    std_dict = {"input": std}  # "target" is filled in below

    # Without separate target normalization the target reuses the input arrays.
    if self._target_separate_normalization:
        mean, std = self.compute_individual_mean_std()

    mean_dict["target"] = mean
    std_dict["target"] = std
    return mean_dict, std_dict
750
+
751
def _get_random_hw(self, h: int, w: int):
    """Draw a random top-left corner for a ``self._img_sz`` square crop.

    Each axis is sampled independently and uniformly from every start that
    keeps the crop inside the frame, i.e. ``[0, extent - self._img_sz]``
    inclusive. An axis whose extent does not exceed the patch size always
    starts at 0.

    Parameters
    ----------
    h : int
        Frame height.
    w : int
        Frame width.

    Returns
    -------
    tuple of int
        ``(h_start, w_start)``.
    """

    def _random_start(extent: int) -> int:
        slack = extent - self._img_sz
        if slack <= 0:
            return 0
        # ``randint``'s upper bound is exclusive; ``+ 1`` makes the last valid
        # start (crop ending exactly at the frame border) reachable.
        return int(np.random.randint(slack + 1))

    # Height and width are handled independently so non-square frames work.
    return _random_start(h), _random_start(w)
762
+
763
def _get_img(self, index: Union[int, Tuple[int, int]]):
    """Load the images (and noise arrays) at ``index`` and crop them.

    Images and noise arrays are cropped together in a single ``_crop_imgs``
    call, whose last returned element is discarded. The result is then split
    back into the image group and the noise group. Per the original note, the
    crop is chosen such that the cropped image has content.

    Returns
    -------
    tuple
        ``(cropped_img_tuples, cropped_noise_tuples)``.
    """
    img_tuples, noise_tuples = self._load_img(index)
    num_imgs = len(img_tuples)
    cropped = self._crop_imgs(index, *img_tuples, *noise_tuples)[:-1]
    return cropped[:num_imgs], cropped[num_imgs:]
773
+
774
def replace_with_empty_patch(self, img_tuples):
    """Swap one channel for the same channel of an empty (background) patch.

    A location is drawn from ``self._empty_patch_fetcher.sample()``. The
    channel at ``self._empty_patch_replacement_channel_idx`` is taken from
    that patch; all other channels are kept.

    Returns
    -------
    tuple
        The per-channel images with the one channel replaced.

    Raises
    ------
    AssertionError
        If the fetched patch carries noise arrays (unsupported).
    """
    empty_index = self._empty_patch_fetcher.sample()
    empty_imgs, empty_noise = self._get_img(empty_index)
    assert (
        len(empty_noise) == 0
    ), "Noise is not supported with empty patch replacement"
    replaced_idx = self._empty_patch_replacement_channel_idx
    return tuple(
        empty_imgs[pos] if pos == replaced_idx else original
        for pos, original in enumerate(img_tuples)
    )
790
+
791
def get_mean_std_for_input(self):
    """Return only the ``"input"`` entries of the stored mean/std dicts."""
    mean_dict, std_dict = self.get_mean_std()
    input_mean = mean_dict["input"]
    input_std = std_dict["input"]
    return input_mean, input_std
794
+
795
def _compute_target(self, img_tuples, alpha):
    """Build the training target from the per-channel images.

    * ``self._tar_idx_list`` is an int: that single image is returned as-is.
    * it is a list/tuple: the selected images are concatenated along axis 0,
      in the listed order.
    * it is ``None``: all images are concatenated along axis 0.

    ``alpha`` is accepted for interface compatibility but is not used.
    """
    tar_idx = self._tar_idx_list
    if isinstance(tar_idx, int):
        return img_tuples[tar_idx]

    selected = img_tuples
    if tar_idx is not None:
        assert isinstance(tar_idx, (list, tuple))
        selected = [img_tuples[i] for i in tar_idx]
    return np.concatenate(selected, axis=0)
807
+
808
def _compute_input_with_alpha(self, img_tuples, alpha_list):
    """Build the network input from the per-channel images.

    If ``self._input_idx`` is set, that entry of ``img_tuples`` is used
    directly; otherwise the input is the ``alpha_list``-weighted sum of the
    images. Unless ``self._normalized_input`` is ``False``, the result is
    standardized with the statistics from ``get_mean_std_for_input``. Those
    statistics must hold the same value for every channel.

    Returns
    -------
    numpy.ndarray
        The input, cast to ``float32``.

    Raises
    ------
    AssertionError
        If the per-channel input means or stds are not all equal.
    """
    if self._input_idx is None:
        inp = 0
        for weight, channel_img in zip(alpha_list, img_tuples):
            inp += channel_img * weight
    else:
        inp = img_tuples[self._input_idx]

    # Identity check on purpose: only an explicit ``False`` skips normalization.
    if self._normalized_input is False:
        return inp.astype(np.float32)

    mean, std = self.get_mean_std_for_input()
    mean, std = mean.squeeze(), std.squeeze()
    if mean.size == 1:
        # ``squeeze`` produced 0-d arrays; make them indexable.
        mean, std = mean.reshape(1), std.reshape(1)

    first_mean, first_std = mean[0], std[0]
    for i, value in enumerate(mean):
        assert value == first_mean
        assert std[i] == first_std

    return ((inp - first_mean) / first_std).astype(np.float32)
837
+
838
def _sample_alpha(self):
    """Sample one mixing weight per channel.

    Weight ``c`` is ``start[c] + u * (end[c] - start[c])``, where ``u`` comes
    from ``np.random.rand()`` (uniform on [0, 1)) and start/end come from
    ``self._start_alpha_arr`` / ``self._end_alpha_arr``. One draw is made per
    channel, in channel order.
    """
    starts, ends = self._start_alpha_arr, self._end_alpha_arr

    def _draw(ch):
        u = np.random.rand()
        return starts[ch] + u * (ends[ch] - starts[ch])

    return [_draw(ch) for ch in range(self._num_channels)]
847
+
848
def _compute_input(self, img_tuples):
    """Combine the per-channel images into the network input.

    Uses equal weights ``1 / n`` unless ``self._start_alpha_arr`` is set, in
    which case the weights come from ``_sample_alpha``. When
    ``self._input_is_sum`` is set, the weighted input is multiplied by the
    number of images.

    Returns
    -------
    tuple
        ``(input, alpha)``, where ``alpha`` is the list of weights used.
    """
    n_imgs = len(img_tuples)
    if self._start_alpha_arr is None:
        alpha = [1 / n_imgs for _ in range(n_imgs)]
    else:
        alpha = self._sample_alpha()

    inp = self._compute_input_with_alpha(img_tuples, alpha)
    if self._input_is_sum:
        inp = n_imgs * inp
    return inp, alpha
857
+
858
def _get_index_from_valid_target_logic(self, index):
    """Optionally replace ``index`` with one from the train index switcher.

    When ``self._validtarget_rand_fract`` is ``None``, ``index`` is returned
    unchanged. Otherwise the incoming index is ignored:

    * with probability ``_validtarget_rand_fract``, the result of
      ``get_valid_target_index()`` is returned;
    * otherwise, the result of ``get_invalid_target_index()`` is returned.
    """
    fract = self._validtarget_rand_fract
    if fract is None:
        return index

    draw = np.random.rand()
    switcher = self._train_index_switcher
    if draw < fract:
        return switcher.get_valid_target_index()
    return switcher.get_invalid_target_index()
865
+
866
def _rotate2D(self, img_tuples, noise_tuples):
    """Rotate every 2D plane of the images and noise arrays in one call.

    How it works:

    * Plane ``k`` of image ``i`` is passed to ``self._rotation_transform`` as
      keyword ``img{i}_{k}`` (``noise{i}_{k}`` for noise arrays).
    * Each of those names is first registered as an ``"image"`` target via
      ``add_targets``.
    * ``image`` is set to the first plane of the first image.
    * The rotated planes are restacked along a leading axis, giving one array
      per input array.

    Presumably ``self._rotation_transform`` is an albumentations-style
    transform, so that all registered targets receive the same random
    rotation (TODO confirm).

    Returns
    -------
    tuple of list
        ``(rotated_img_tuples, rotated_noise_tuples)``.
    """

    def _to_planes(prefix, arrays):
        # Name every 2D plane, ordered by (array, plane).
        return {
            f"{prefix}{arr_idx}_{plane_idx}": arr[plane_idx]
            for arr_idx, arr in enumerate(arrays)
            for plane_idx in range(len(arr))
        }

    def _from_planes(prefix, arrays, rotated):
        # Restack each array's rotated planes along a new leading axis.
        restacked = []
        for arr_idx, arr in enumerate(arrays):
            planes = [
                rotated[f"{prefix}{arr_idx}_{plane_idx}"][None]
                for plane_idx in range(len(arr))
            ]
            if len(arr) == 1:
                restacked.append(planes[0])
            else:
                restacked.append(np.concatenate(planes, axis=0))
        return restacked

    img_planes = _to_planes("img", img_tuples)
    noise_planes = _to_planes("noise", noise_tuples)
    self._rotation_transform.add_targets(
        {name: "image" for name in [*img_planes, *noise_planes]}
    )
    rotated = self._rotation_transform(
        image=img_tuples[0][0], **img_planes, **noise_planes
    )
    return (
        _from_planes("img", img_tuples, rotated),
        _from_planes("noise", noise_tuples, rotated),
    )
907
+
908
def _rotate(self, img_tuples, noise_tuples):
    """Rotate the images and noise arrays.

    Dispatches to ``_rotate3D`` when ``self._depth3D > 1`` and to
    ``_rotate2D`` otherwise.
    """
    rotate_fn = self._rotate3D if self._depth3D > 1 else self._rotate2D
    return rotate_fn(img_tuples, noise_tuples)
913
+
914
def _rotate3D(self, img_tuples, noise_tuples):
    """Rotate every 2D slice of the 3D images and noise arrays in one call.

    How it works:

    * Slice ``[k, j]`` (channel ``k``, depth ``j``) of array ``i`` is passed
      to ``self._rotation_transform`` as keyword ``img{i}_{j}_{k}``
      (``noise{i}_{j}_{k}`` for noise arrays).
    * Each of those names is first registered as an ``"image"`` target.
    * ``image`` is set to ``img_tuples[0][0]``.
    * The rotated slices are reassembled into ``(channels, self._depth3D,
      ...)`` arrays, one per input array.

    Presumably an albumentations-style transform is used so that all targets
    share one random rotation (TODO confirm).

    Returns
    -------
    tuple of list
        ``(rotated_img_tuples, rotated_noise_tuples)``.
    """
    depth = self._depth3D

    def _to_planes(prefix, arrays):
        # Name every 2D slice, ordered by (array, depth, channel).
        return {
            f"{prefix}{arr_idx}_{z}_{ch}": arr[ch, z]
            for arr_idx, arr in enumerate(arrays)
            for z in range(depth)
            for ch in range(len(arr))
        }

    def _from_planes(prefix, arrays, rotated):
        # Stack depth slices on axis 1 per channel, then channels on axis 0.
        restacked = []
        for arr_idx, arr in enumerate(arrays):
            channels = [
                np.concatenate(
                    [
                        rotated[f"{prefix}{arr_idx}_{z}_{ch}"][None, None]
                        for z in range(depth)
                    ],
                    axis=1,
                )
                for ch in range(len(arr))
            ]
            if len(arr) == 1:
                restacked.append(channels[0])
            else:
                restacked.append(np.concatenate(channels, axis=0))
        return restacked

    img_planes = _to_planes("img", img_tuples)
    noise_planes = _to_planes("noise", noise_tuples)
    self._rotation_transform.add_targets(
        {name: "image" for name in [*img_planes, *noise_planes]}
    )
    rotated = self._rotation_transform(
        image=img_tuples[0][0], **img_planes, **noise_planes
    )
    return (
        _from_planes("img", img_tuples, rotated),
        _from_planes("noise", noise_tuples, rotated),
    )
985
+
986
def get_uncorrelated_img_tuples(self, index):
    """Assemble channels taken from different, randomly chosen patches.

    Content of channels like actin and nuclei is "correlated" in its
    respective location. This function breaks that correlation:

    * channel 0 comes from the patch at ``index``;
    * every other channel ``c`` is channel ``c`` of the patch at a uniformly
      random dataset index (``np.random.randint(len(self))``).

    Returns
    -------
    tuple
        ``(img_tuples, noise_tuples)``. ``img_tuples`` is a list with one
        array per channel; ``noise_tuples`` is the (empty) noise group of the
        patch at ``index``.

    Raises
    ------
    AssertionError
        If the patch at ``index`` carries noise arrays (unsupported).
    """
    img_tuples, noise_tuples = self._get_img(index)
    assert len(noise_tuples) == 0
    # Capture the channel count before building the new list; iterating over
    # the length of the rebuilt (single-element) list would drop every
    # channel after the first.
    num_channels = len(img_tuples)
    uncorrelated = [img_tuples[0]]
    for ch_idx in range(1, num_channels):
        new_index = np.random.randint(len(self))
        other_img_tuples, _ = self._get_img(new_index)
        uncorrelated.append(other_img_tuples[ch_idx])
    return uncorrelated, noise_tuples
1000
+
1001
def __getitem__(
    self, index: Union[int, Tuple[int, int]]
) -> Tuple[np.ndarray, np.ndarray]:
    """Return one sample as ``(input, target[, alpha][, index])``.

    Steps, in order:

    1. Optional index switching via ``_train_index_switcher``.
    2. Loading, optionally with uncorrelated channels.
    3. Optional replacement of one channel with an empty patch (not
       supported together with noise).
    4. Optional joint rotation.
    5. Adding noise array 0 to every image, scaled by ``sqrt(2)`` when
       ``_input_is_sum``, and building the input via ``_compute_input``.
    6. Adding noise arrays ``1..`` to the images and building the target
       via ``_compute_target``.

    ``alpha`` and the (possibly switched) ``index`` are appended when
    ``_return_alpha`` / ``_return_index`` are set.
    """
    # The input can be either a real microscopy image or separate channels
    # that are summed in the code.
    if self._train_index_switcher is not None:
        index = self._get_index_from_valid_target_logic(index)

    if self._uncorrelated_channels:
        img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        img_tuples, noise_tuples = self._get_img(index)

    # Replace the content of one of the channels with background, with the
    # given probability. Only the combination with noise is unsupported; an
    # unconditional assert on the flag would make this branch unreachable.
    if self._empty_patch_replacement_enabled:
        assert len(noise_tuples) == 0, "This is not supported with noise"
        if np.random.rand() < self._empty_patch_replacement_probab:
            img_tuples = self.replace_with_empty_patch(img_tuples)

    # Noise tuples are not needed for the paper; the image tuples are noisy
    # by default.
    # TODO: remove noise tuples completely?
    if self._enable_rotation:
        img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    # Add noise array 0 to every image to create the input.
    if len(noise_tuples) > 0:
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
    else:
        input_tuples = img_tuples

    # Weight the individual channels; typically alpha is fixed.
    inp, alpha = self._compute_input(input_tuples)

    # Pair noise arrays 1.. with the images to create the target.
    if len(noise_tuples) >= 1:
        img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

    target = self._compute_target(img_tuples, alpha)

    output = [inp, target]
    if self._return_alpha:
        output.append(alpha)
    if self._return_index:
        output.append(index)
    return tuple(output)