careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
@@ -2,189 +2,65 @@
  A place for Datasets and Dataloaders.
  """

- import os
  from typing import Tuple, Union

- # import albumentations as A
- import ml_collections
  import numpy as np
- from skimage.transform import resize

  from .data_utils import (
- DataSplitType,
- DataType,
- GridAlignement,
  GridIndexManager,
  IndexSwitcher,
- get_datasplit_tuples,
- get_mrc_data,
- load_tiff,
+ get_train_val_data,
  )
-
-
- def get_train_val_data(
- data_config,
- fpath,
- datasplit_type: DataSplitType,
- val_fraction=None,
- test_fraction=None,
- allow_generation=None,
- ignore_specific_datapoints=None,
- ):
- """
- Load the data from the given path and split them in training, validation and test sets.
-
- Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
- C is the number of channels.
- """
- if data_config.data_type == DataType.SeparateTiffData:
- fpath1 = os.path.join(fpath, data_config.ch1_fname)
- fpath2 = os.path.join(fpath, data_config.ch2_fname)
- fpaths = [fpath1, fpath2]
- fpath0 = ""
- if "ch_input_fname" in data_config:
- fpath0 = os.path.join(fpath, data_config.ch_input_fname)
- fpaths = [fpath0] + fpaths
-
- print(
- f"Loading from {fpath} Channels: "
- f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
- )
-
- data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
- if data_config.data_type == DataType.PredictedTiffData:
- assert len(data.shape) == 5 and data.shape[-1] == 1
- data = data[..., 0].copy()
- # data = data[::3].copy()
- # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
- # to the noise present in the channels and so this is not the way we would get the data.
- # We need to add the noise independently to the input and the target.
-
- # if data_config.get('poisson_noise_factor', False):
- # data = np.random.poisson(data)
- # if data_config.get('enable_gaussian_noise', False):
- # synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
- # print('Adding Gaussian noise with scale', synthetic_scale)
- # noise = np.random.normal(0, synthetic_scale, data.shape)
- # data = data + noise
-
- if datasplit_type == DataSplitType.All:
- return data.astype(np.float32)
-
- train_idx, val_idx, test_idx = get_datasplit_tuples(
- val_fraction, test_fraction, len(data), starting_test=True
- )
- if datasplit_type == DataSplitType.Train:
- return data[train_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Val:
- return data[val_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Test:
- return data[test_idx].astype(np.float32)
-
- elif data_config.data_type == DataType.BioSR_MRC:
- num_channels = data_config.get("num_channels", 2)
- fpaths = []
- data_list = []
- for i in range(num_channels):
- fpath1 = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
- fpaths.append(fpath1)
- data = get_mrc_data(fpath1)[..., None]
- data_list.append(data)
-
- dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
-
- msg = ",".join([x[len(dirname) :] for x in fpaths])
- print(
- f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
- )
- N = data_list[0].shape[0]
- for data in data_list:
- N = min(N, data.shape[0])
-
- cropped_data = []
- for data in data_list:
- cropped_data.append(data[:N])
-
- data = np.concatenate(cropped_data, axis=3)
-
- if datasplit_type == DataSplitType.All:
- return data.astype(np.float32)
-
- train_idx, val_idx, test_idx = get_datasplit_tuples(
- val_fraction, test_fraction, len(data), starting_test=True
- )
- if datasplit_type == DataSplitType.Train:
- return data[train_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Val:
- return data[val_idx].astype(np.float32)
- elif datasplit_type == DataSplitType.Test:
- return data[test_idx].astype(np.float32)
+ from .vae_data_config import VaeDatasetConfig, DataSplitType, GridAlignement


  class MultiChDloader:
-
  def __init__(
  self,
- data_config: ml_collections.ConfigDict,
+ data_config: VaeDatasetConfig,
  fpath: str,
- datasplit_type: DataSplitType = None,
  val_fraction: float = None,
  test_fraction: float = None,
- normalized_input=None,
- enable_rotation_aug: bool = False,
- enable_random_cropping: bool = False,
- use_one_mu_std=None,
- allow_generation: bool = False,
- max_val: float = None,
- grid_alignment=GridAlignement.LeftTop,
- overlapping_padding_kwargs=None,
- print_vars: bool = True,
  ):
- """
- Here, an image is split into grids of size img_sz.
- Args:
- repeat_factor: Since we are doing a random crop, repeat_factor is
- given which can repeatedly sample from the same image. If self.N=12
- and repeat_factor is 5, then index upto 12*5 = 60 is allowed.
- use_one_mu_std: If this is set to true, then one mean and stdev is used
- for both channels. Otherwise, two different meean and stdev are used.
-
- """
+ """ """
  self._data_type = data_config.data_type
  self._fpath = fpath
  self._data = self.N = self._noise_data = None
-
+ self.Z = 1
+ self._trim_boundary = data_config.trim_boundary
  # Hardcoded params, not included in the config file.

  # by default, if the noise is present, add it to the input and target.
  self._disable_noise = False # to add synthetic noise
+ self._poisson_noise_factor = None
  self._train_index_switcher = None
+ self._depth3D = data_config.depth3D
  # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
- self._input_is_sum = data_config.get("input_is_sum", False)
- self._num_channels = data_config.get("num_channels", 2)
- self._input_idx = data_config.get("input_idx", None)
- self._tar_idx_list = data_config.get("target_idx_list", None)
+ self._input_is_sum = data_config.input_is_sum
+ self._num_channels = data_config.num_channels
+ self._input_idx = data_config.input_idx
+ self._tar_idx_list = data_config.target_idx_list

- if datasplit_type == DataSplitType.Train:
+ if data_config.datasplit_type == DataSplitType.Train:
  self._datausage_fraction = 1.0
  # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
  self._validtarget_rand_fract = None
  # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
  # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
  # self._idx_count = 0
- elif datasplit_type == DataSplitType.Val:
+ elif data_config.datasplit_type == DataSplitType.Val:
  self._datausage_fraction = 1.0
  else:
  self._datausage_fraction = 1.0

  self.load_data(
  data_config,
- datasplit_type,
+ data_config.datasplit_type,
  val_fraction=val_fraction,
  test_fraction=test_fraction,
- allow_generation=allow_generation,
+ allow_generation=data_config.allow_generation,
  )
- self._normalized_input = normalized_input
+ self._normalized_input = data_config.normalized_input
  self._quantile = 1.0
  self._channelwise_quantile = False
  self._background_quantile = 0.0
@@ -194,8 +70,8 @@ class MultiChDloader:

  self._background_values = None

- self._grid_alignment = grid_alignment
- self._overlapping_padding_kwargs = overlapping_padding_kwargs
+ self._grid_alignment = data_config.grid_alignment
+ self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
  if self._grid_alignment == GridAlignement.LeftTop:
  assert (
  self._overlapping_padding_kwargs is None
@@ -205,20 +81,28 @@ class MultiChDloader:
  assert (
  self._overlapping_padding_kwargs is not None
  ), "With Center grid alignment, padding is needed."
+ if self._trim_boundary:
+ if (
+ self._overlapping_padding_kwargs is None
+ or data_config.multiscale_lowres_count is not None
+ ):
+ # raise warning
+ print("Padding is not used with this alignement style")
+ else:
+ assert (
+ self._overlapping_padding_kwargs is not None
+ ), "When not trimming boudnary, padding is needed."

- self._is_train = datasplit_type == DataSplitType.Train
+ self._is_train = data_config.datasplit_type == DataSplitType.Train

  # input = alpha * ch1 + (1-alpha)*ch2.
  # alpha is sampled randomly between these two extremes
- self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = (
- self._alpha_weighted_target
- ) = None
+ self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None

  self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
  if self._is_train:
- self._start_alpha_arr = None
- self._end_alpha_arr = None
- self._alpha_weighted_target = False
+ self._start_alpha_arr = data_config.start_alpha
+ self._end_alpha_arr = data_config.end_alpha

  self.set_img_sz(
  data_config.image_size,
@@ -229,11 +113,13 @@ class MultiChDloader:
  ),
  )

- # if self._validtarget_rand_fract is not None:
- # self._train_index_switcher = IndexSwitcher(self.idx_manager, data_config, self._img_sz)
- # self._std_background_arr = None
+ if self._validtarget_rand_fract is not None:
+ self._train_index_switcher = IndexSwitcher(
+ self.idx_manager, data_config, self._img_sz
+ )

  else:
+
  self.set_img_sz(
  data_config.image_size,
  (
@@ -246,32 +132,42 @@ class MultiChDloader:
  self._return_alpha = False
  self._return_index = False

- # self._empty_patch_replacement_enabled = data_config.get("empty_patch_replacement_enabled",
- # False) and self._is_train
- # if self._empty_patch_replacement_enabled:
- # self._empty_patch_replacement_channel_idx = data_config.empty_patch_replacement_channel_idx
- # self._empty_patch_replacement_probab = data_config.empty_patch_replacement_probab
- # data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
- # # NOTE: This is on the raw data. So, it must be called before removing the background.
- # self._empty_patch_fetcher = EmptyPatchFetcher(self.idx_manager,
- # self._img_sz,
- # data_frames,
- # max_val_threshold=data_config.empty_patch_max_val_threshold)
+ self._empty_patch_replacement_enabled = (
+ data_config.empty_patch_replacement_enabled and self._is_train
+ )
+ if self._empty_patch_replacement_enabled:
+ self._empty_patch_replacement_channel_idx = (
+ data_config.empty_patch_replacement_channel_idx
+ )
+ self._empty_patch_replacement_probab = (
+ data_config.empty_patch_replacement_probab
+ )
+ data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
+ # NOTE: This is on the raw data. So, it must be called before removing the background.
+ # TODO: missing import, needs fixing asap!
+ self._empty_patch_fetcher = EmptyPatchFetcher(
+ self.idx_manager,
+ self._img_sz,
+ data_frames,
+ max_val_threshold=data_config.empty_patch_max_val_threshold,
+ )

- self.rm_bkground_set_max_val_and_upperclip_data(max_val, datasplit_type)
+ self.rm_bkground_set_max_val_and_upperclip_data(
+ data_config.max_val, data_config.datasplit_type
+ )

  # For overlapping dloader, image_size and repeat_factors are not related. hence a different function.

  self._mean = None
  self._std = None
- self._use_one_mu_std = use_one_mu_std
+ self._use_one_mu_std = data_config.use_one_mu_std
  # Hardcoded
  self._target_separate_normalization = True

- self._enable_rotation = enable_rotation_aug
- self._enable_random_cropping = enable_random_cropping
+ self._enable_rotation = data_config.enable_rotation_aug
+ self._enable_random_cropping = data_config.enable_random_cropping
  self._uncorrelated_channels = (
- data_config.get("uncorrelated_channels", False) and self._is_train
+ data_config.uncorrelated_channels and self._is_train
  )
  assert self._is_train or self._uncorrelated_channels is False
  assert (
@@ -286,9 +182,10 @@ class MultiChDloader:
  )
  self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])

- if print_vars:
- msg = self._init_msg()
- print(msg)
+ # TODO: remove print log messages
+ # if print_vars:
+ # msg = self._init_msg()
+ # print(msg)

  def disable_noise(self):
  assert (
@@ -339,7 +236,7 @@ class MultiChDloader:
  )

  msg = ""
- if data_config.get("poisson_noise_factor", -1) > 0:
+ if data_config.poisson_noise_factor > 0:
  self._poisson_noise_factor = data_config.poisson_noise_factor
  msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
  self._data = (
@@ -347,20 +244,26 @@ class MultiChDloader:
  * self._poisson_noise_factor
  )

- if data_config.get("enable_gaussian_noise", False):
- synthetic_scale = data_config.get("synthetic_gaussian_scale", 0.1)
+ if data_config.enable_gaussian_noise:
+ synthetic_scale = data_config.synthetic_gaussian_scale
  msg += f"Adding Gaussian noise with scale {synthetic_scale}"
  # 0 => noise for input. 1: => noise for all targets.
  shape = self._data.shape
  self._noise_data = np.random.normal(
  0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
  )
- if data_config.get("input_has_dependant_noise", False):
+ if data_config.input_has_dependant_noise:
  msg += ". Moreover, input has dependent noise"
  self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
  print(msg)

- self.N = len(self._data)
+ self._5Ddata = len(self._data.shape) == 5
+ if self._5Ddata:
+ self.Z = self._data.shape[1]
+
+ if self._depth3D > 1:
+ assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"
+
  assert (
  self._data.shape[-1] == self._num_channels
  ), "Number of channels in data and config do not match."
@@ -441,9 +344,13 @@ class MultiChDloader:
  def get_img_sz(self):
  return self._img_sz

+ def get_num_frames(self):
+ return self._data.shape[0]
+
  def reduce_data(
  self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
  ):
+ assert not self._5Ddata, "This function is not supported for 3D data."
  if t_list is None:
  t_list = list(range(self._data.shape[0]))
  if h_start is None:
@@ -461,12 +368,22 @@ class MultiChDloader:
  t_list, h_start:h_end, w_start:w_end, :
  ].copy()

- self.N = len(t_list)
  self.set_img_sz(self._img_sz, self._grid_sz)
  print(
  f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
  )

+ def get_idx_manager_shapes(self, patch_size: int, grid_size: int):
+ numC = self._data.shape[-1]
+ if self._5Ddata:
+ grid_shape = (1, 1, grid_size, grid_size, numC)
+ patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
+ else:
+ grid_shape = (1, grid_size, grid_size, numC)
+ patch_shape = (1, patch_size, patch_size, numC)
+
+ return patch_shape, grid_shape
+
  def set_img_sz(self, image_size, grid_size):
  """
  If one wants to change the image size on the go, then this can be used.
@@ -474,12 +391,23 @@ class MultiChDloader:
  image_size: size of one patch
  grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
  """
+
  self._img_sz = image_size
  self._grid_sz = grid_size
+ shape = self._data.shape
+
+ patch_shape, grid_shape = self.get_idx_manager_shapes(
+ self._img_sz, self._grid_sz
+ )
  self.idx_manager = GridIndexManager(
- self._data.shape, self._grid_sz, self._img_sz, self._grid_alignment
+ shape, grid_shape, patch_shape, self._trim_boundary
  )
- self.set_repeat_factor()
+ # self.set_repeat_factor()
+
+ def __len__(self):
+ # Vera: N is the number of frames in Z stack
+ # Repeat factor is n_rows * n_cols
+ return self.idx_manager.total_grid_count()

  def set_repeat_factor(self):
  if self._grid_sz > 1:
@@ -497,7 +425,14 @@ class MultiChDloader:
  msg = (
  f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
  )
+ dim_sizes = [
+ self.idx_manager.get_individual_dim_grid_count(dim)
+ for dim in range(len(self._data.shape))
+ ]
+ dim_sizes = ",".join([str(x) for x in dim_sizes])
  msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
+ msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
+ msg += f" TrimB:{self._trim_boundary}"
  # msg += f' NormInp:{self._normalized_input}'
  # msg += f' SingleNorm:{self._use_one_mu_std}'
  msg += f" Rot:{self._enable_rotation}"
@@ -529,40 +464,52 @@ class MultiChDloader:
  )

  if self._enable_random_cropping:
- h_start, w_start = self._get_random_hw(h, w)
+ patch_start_loc = self._get_random_hw(h, w)
+ if self._5Ddata:
+ patch_start_loc = (
+ np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
+ ) + patch_start_loc
  else:
- h_start, w_start = self._get_deterministic_hw(index)
+ patch_start_loc = self._get_deterministic_loc(index)

  cropped_imgs = []
  for img in img_tuples:
- img = self._crop_flip_img(img, h_start, w_start, False, False)
+ img = self._crop_flip_img(img, patch_start_loc, False, False)
  cropped_imgs.append(img)

  return (
  *tuple(cropped_imgs),
  {
- "h": [h_start, h_start + self._img_sz],
- "w": [w_start, w_start + self._img_sz],
  "hflip": False,
  "wflip": False,
  },
  )

- def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
- if self._grid_alignment == GridAlignement.LeftTop:
+ def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
+ if self._trim_boundary:
  # In training, this is used.
  # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
  # The only benefit this if else loop provides is that it makes it easier to see what happens during training.
- new_img = img[
- ..., h_start : h_start + self._img_sz, w_start : w_start + self._img_sz
- ]
+ patch_end_loc = (
+ np.array(patch_start_loc, dtype=np.int32)
+ + self.idx_manager.patch_shape[1:-1]
+ )
+ if self._5Ddata:
+ z_start, h_start, w_start = patch_start_loc
+ z_end, h_end, w_end = patch_end_loc
+ new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
+ else:
+ h_start, w_start = patch_start_loc
+ h_end, w_end = patch_end_loc
+ new_img = img[..., h_start:h_end, w_start:w_end]
+
  return new_img
- elif self._grid_alignment == GridAlignement.Center:
+ else:
  # During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame
  # In these situations, we need some sort of padding. This is not needed in the LeftTop alignement.
- return self._crop_img_with_padding(img, h_start, w_start)
+ return self._crop_img_with_padding(img, patch_start_loc)

- def get_begin_end_padding(self, start_pos, max_len):
+ def get_begin_end_padding(self, start_pos, end_pos, max_len):
  """
  The effect is that the image with size self._grid_sz is in the center of the patch with sufficient
  padding on all four sides so that the final patch size is self._img_sz.
@@ -572,44 +519,56 @@ class MultiChDloader:
  if start_pos < 0:
  pad_start = -1 * start_pos

- pad_end = max(0, start_pos + self._img_sz - max_len)
+ pad_end = max(0, end_pos - max_len)

  return pad_start, pad_end

- def _crop_img_with_padding(self, img: np.ndarray, h_start: int, w_start: int):
- _, H, W = img.shape
- h_on_boundary = self.on_boundary(h_start, H)
- w_on_boundary = self.on_boundary(w_start, W)
-
- assert h_start < H
- assert w_start < W
-
- assert h_start + self._img_sz <= H or h_on_boundary
- assert w_start + self._img_sz <= W or w_on_boundary
+ def _crop_img_with_padding(
+ self, img: np.ndarray, patch_start_loc, max_len_vals=None
+ ):
+ if max_len_vals is None:
+ max_len_vals = self.idx_manager.data_shape[1:-1]
+ patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
+ self.idx_manager.patch_shape[1:-1], dtype=int
+ )
+ boundary_crossed = []
+ valid_slice = []
+ padding = [[0, 0]]
+ for start_idx, end_idx, max_len in zip(
+ patch_start_loc, patch_end_loc, max_len_vals
+ ):
+ boundary_crossed.append(end_idx > max_len or start_idx < 0)
+ valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
+ pad = [0, 0]
+ if boundary_crossed[-1]:
+ pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
+ padding.append(pad)
  # max() is needed since h_start could be negative.
- new_img = img[
- ...,
- max(0, h_start) : h_start + self._img_sz,
- max(0, w_start) : w_start + self._img_sz,
- ]
- padding = np.array([[0, 0], [0, 0], [0, 0]])
-
- if h_on_boundary:
- pad = self.get_begin_end_padding(h_start, H)
- padding[1] = pad
- if w_on_boundary:
- pad = self.get_begin_end_padding(w_start, W)
- padding[2] = pad
+ if self._5Ddata:
+ new_img = img[
+ ...,
+ valid_slice[0][0] : valid_slice[0][1],
+ valid_slice[1][0] : valid_slice[1][1],
+ valid_slice[2][0] : valid_slice[2][1],
+ ]
+ else:
+ new_img = img[
+ ...,
+ valid_slice[0][0] : valid_slice[0][1],
+ valid_slice[1][0] : valid_slice[1][1],
+ ]

+ # print(np.array(padding).shape, img.shape, new_img.shape)
+ # print(padding)
  if not np.all(padding == 0):
  new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)

  return new_img

  def _crop_flip_img(
- self, img: np.ndarray, h_start: int, w_start: int, h_flip: bool, w_flip: bool
+ self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
  ):
- new_img = self._crop_img(img, h_start, w_start)
+ new_img = self._crop_img(img, patch_start_loc)
  if h_flip:
  new_img = new_img[..., ::-1, :]
  if w_flip:
@@ -617,9 +576,6 @@ class MultiChDloader:

  return new_img.astype(np.float32)

- def __len__(self):
- return self.N * self._repeat_factor
-
  def _load_img(
  self, index: Union[int, Tuple[int, int]]
  ) -> Tuple[np.ndarray, np.ndarray]:
@@ -631,12 +587,21 @@ class MultiChDloader:
  else:
  idx = index[0]

- imgs = self._data[self.idx_manager.get_t(idx)]
+ patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
+ imgs = self._data[patch_loc_list[0]]
+ # if self._5Ddata:
+ # assert self._noise_data is None, 'Noise is not supported for 5D data'
+ # n_loc, z_loc = patch_loc_list[:2]
+ # z_loc_interval = range(z_loc, z_loc + self._depth3D)
+ # imgs = self._data[n_loc, z_loc_interval]
+ # else:
+ # imgs = self._data[patch_loc_list[0]]
+
  loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
  noise = []
  if self._noise_data is not None and not self._disable_noise:
  noise = [
- self._noise_data[self.idx_manager.get_t(idx)][None, ..., i]
+ self._noise_data[patch_loc_list[0]][None, ..., i]
  for i in range(self._noise_data.shape[-1])
  ]
  return tuple(loaded_imgs), tuple(noise)
@@ -669,27 +634,16 @@ class MultiChDloader:
  def per_side_overlap_pixelcount(self):
  return (self._img_sz - self._grid_sz) // 2

- def on_boundary(self, cur_loc, frame_size):
- return cur_loc + self._img_sz > frame_size or cur_loc < 0
+ # def on_boundary(self, cur_loc, frame_size):
+ # return cur_loc + self._img_sz > frame_size or cur_loc < 0

- def _get_deterministic_hw(self, index: Union[int, Tuple[int, int]]):
+ def _get_deterministic_loc(self, index: int):
  """
  It returns the top-left corner of the patch corresponding to index.
  """
- if isinstance(index, int) or isinstance(index, np.int64):
- idx = index
- grid_size = self._grid_sz
- else:
- idx, grid_size = index
-
- h_start, w_start = self.idx_manager.get_deterministic_hw(
- idx, grid_size=grid_size
- )
- if self._grid_alignment == GridAlignement.LeftTop:
- return h_start, w_start
- elif self._grid_alignment == GridAlignement.Center:
- pad = self.per_side_overlap_pixelcount()
- return h_start - pad, w_start - pad
+ loc_list = self.idx_manager.get_patch_location_from_dataset_idx(index)
+ # last dim is channel. we need to take the third and the second last element.
+ return loc_list[1:-1]

  def compute_individual_mean_std(self):
  # numpy 1.19.2 has issues in computing for large arrays. https://github.com/numpy/numpy/issues/8869
@@ -715,6 +669,10 @@ class MultiChDloader:

  mean = np.array(mean_arr)
  std = np.array(std_arr)
+ if (
+ self._5Ddata
+ ): # NOTE: IDEALLY this should be only when the model expects 3D data.
+ return mean[None, :, None, None, None], std[None, :, None, None, None]

  return mean[None, :, None, None], std[None, :, None, None]

@@ -776,6 +734,10 @@ class MultiChDloader:
  if self._skip_normalization_using_mean:
  mean = np.zeros_like(mean)

+ if self._5Ddata:
+ mean = mean[:, :, None]
+ std = std[:, :, None]
+
  mean_dict = {"input": mean} # , 'target':mean}
  std_dict = {"input": std} # , 'target':std}

@@ -810,8 +772,14 @@ class MultiChDloader:
  return cropped_img_tuples, cropped_noise_tuples

  def replace_with_empty_patch(self, img_tuples):
+ """
+ Replaces the content of one of the channels with background
+ """
  empty_index = self._empty_patch_fetcher.sample()
- empty_img_tuples = self._get_img(empty_index)
+ empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
+ assert (
+ len(empty_img_noise_tuples) == 0
+ ), "Noise is not supported with empty patch replacement"
  final_img_tuples = []
  for tuple_idx in range(len(img_tuples)):
  if tuple_idx == self._empty_patch_replacement_channel_idx:
@@ -834,14 +802,7 @@ class MultiChDloader:
  )
  img_tuples = [img_tuples[i] for i in self._tar_idx_list]

- if self._alpha_weighted_target:
- assert self._input_is_sum is False
- target = []
- for i in range(len(img_tuples)):
- target.append(img_tuples[i] * alpha[i])
- target = np.concatenate(target, axis=0)
- else:
- target = np.concatenate(img_tuples, axis=0)
+ target = np.concatenate(img_tuples, axis=0)
  return target

  def _compute_input_with_alpha(self, img_tuples, alpha_list):
@@ -902,9 +863,6 @@ class MultiChDloader:
  index = self._train_index_switcher.get_invalid_target_index()
  return index

- def _rotate(self, img_tuples, noise_tuples):
- return self._rotate2D(img_tuples, noise_tuples)
-
  def _rotate2D(self, img_tuples, noise_tuples):
  img_kwargs = {}
  for i, img in enumerate(img_tuples):
@@ -921,6 +879,7 @@ class MultiChDloader:
  rot_dic = self._rotation_transform(
  image=img_tuples[0][0], **img_kwargs, **noise_kwargs
  )
+
  rotated_img_tuples = []
  for i, img in enumerate(img_tuples):
  if len(img) == 1:
@@ -946,7 +905,90 @@ class MultiChDloader:

  return rotated_img_tuples, rotated_noise_tuples

+ def _rotate(self, img_tuples, noise_tuples):
+ if self._depth3D > 1:
+ return self._rotate3D(img_tuples, noise_tuples)
+ else:
+ return self._rotate2D(img_tuples, noise_tuples)
+
+ def _rotate3D(self, img_tuples, noise_tuples):
+ img_kwargs = {}
+ for i, img in enumerate(img_tuples):
+ for j in range(self._depth3D):
+ for k in range(len(img)):
+ img_kwargs[f"img{i}_{j}_{k}"] = img[k, j]
+
+ noise_kwargs = {}
+ for i, nimg in enumerate(noise_tuples):
+ for j in range(self._depth3D):
+ for k in range(len(nimg)):
+ noise_kwargs[f"noise{i}_{j}_{k}"] = nimg[k, j]
+
+ keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
+ self._rotation_transform.add_targets({k: "image" for k in keys})
+ rot_dic = self._rotation_transform(
+ image=img_tuples[0][0], **img_kwargs, **noise_kwargs
+ )
+ rotated_img_tuples = []
+ for i, img in enumerate(img_tuples):
+ if len(img) == 1:
+ rotated_img_tuples.append(
+ np.concatenate(
+ [
+ rot_dic[f"img{i}_{j}_0"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ else:
+ temp_arr = []
+ for k in range(len(img)):
+ temp_arr.append(
+ np.concatenate(
+ [
+ rot_dic[f"img{i}_{j}_{k}"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
+
+ rotated_noise_tuples = []
+ for i, nimg in enumerate(noise_tuples):
+ if len(nimg) == 1:
+ rotated_noise_tuples.append(
+ np.concatenate(
+ [
+ rot_dic[f"noise{i}_{j}_0"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ else:
+ temp_arr = []
+ for k in range(len(nimg)):
+ temp_arr.append(
+ np.concatenate(
+ [
+ rot_dic[f"noise{i}_{j}_{k}"][None, None]
+ for j in range(self._depth3D)
+ ],
+ axis=1,
+ )
+ )
+ rotated_noise_tuples.append(np.concatenate(temp_arr, axis=0))
+
+ return rotated_img_tuples, rotated_noise_tuples
+
  def get_uncorrelated_img_tuples(self, index):
+ """
+ Content of channels like actin and nuclei is "correlated" in its
+ respective location, this function allows to pick channels' content
+ from different patches of the image to make it "uncorrelated".
+ """
  img_tuples, noise_tuples = self._get_img(index)
  assert len(noise_tuples) == 0
  img_tuples = [img_tuples[0]]
@@ -959,6 +1001,8 @@ class MultiChDloader:
  def __getitem__(
  self, index: Union[int, Tuple[int, int]]
  ) -> Tuple[np.ndarray, np.ndarray]:
+ # Vera: input can be both real microscopic image and two separate channels that are summed in the code
+
  if self._train_index_switcher is not None:
  index = self._get_index_from_valid_target_logic(index)

@@ -971,22 +1015,29 @@ class MultiChDloader:
  self._empty_patch_replacement_enabled != True
  ), "This is not supported with noise"

+ # Replace the content of one of the channels
+ # with background with given probability
  if self._empty_patch_replacement_enabled:
  if np.random.rand() < self._empty_patch_replacement_probab:
  img_tuples = self.replace_with_empty_patch(img_tuples)

+ # Noise tuples are not needed for the paper
+ # the image tuples are noisy by default
+ # TODO: remove noise tuples completely?
  if self._enable_rotation:
  img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

- # add noise to input
+ # Add noise tuples with image tuples to create the input
  if len(noise_tuples) > 0:
  factor = np.sqrt(2) if self._input_is_sum else 1.0
  input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
  else:
  input_tuples = img_tuples
+
+ # Weight the individual channels, typically alpha is fixed
  inp, alpha = self._compute_input(input_tuples)

- # add noise to target.
+ # Add noise tuples to the image tuples to create the target
  if len(noise_tuples) >= 1:
  img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

@@ -1000,221 +1051,4 @@ class MultiChDloader:
  if self._return_index:
  output.append(index)

- if isinstance(index, int) or isinstance(index, np.int64):
- return tuple(output)
-
- _, grid_size = index
- output.append(grid_size)
- return tuple(output)
-
-
- class LCMultiChDloader(MultiChDloader):
-
- def __init__(
- self,
- data_config,
- fpath: str,
- datasplit_type: DataSplitType = None,
- val_fraction=None,
- test_fraction=None,
- normalized_input=None,
- enable_rotation_aug: bool = False,
- use_one_mu_std=None,
- num_scales: int = None,
- enable_random_cropping=False,
- padding_kwargs: dict = None,
- allow_generation: bool = False,
- lowres_supervision=None,
- max_val=None,
- grid_alignment=GridAlignement.LeftTop,
- overlapping_padding_kwargs=None,
- print_vars=True,
- ):
- """
- Args:
- num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
- highest resolution.
- """
- self._padding_kwargs = (
- padding_kwargs # mode=padding_mode, constant_values=constant_value
- )
- if overlapping_padding_kwargs is not None:
- assert (
- self._padding_kwargs == overlapping_padding_kwargs
- ), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
- It should be so since we just use overlapping_padding_kwargs when it is not None"
-
- else:
- overlapping_padding_kwargs = padding_kwargs
-
- super().__init__(
- data_config,
- fpath,
- datasplit_type=datasplit_type,
- val_fraction=val_fraction,
- test_fraction=test_fraction,
- normalized_input=normalized_input,
- enable_rotation_aug=enable_rotation_aug,
- enable_random_cropping=enable_random_cropping,
- use_one_mu_std=use_one_mu_std,
- allow_generation=allow_generation,
- max_val=max_val,
- grid_alignment=grid_alignment,
- overlapping_padding_kwargs=overlapping_padding_kwargs,
- print_vars=print_vars,
- )
- self.num_scales = num_scales
- assert self.num_scales is not None
- self._scaled_data = [self._data]
- self._scaled_noise_data = [self._noise_data]
-
- assert isinstance(self.num_scales, int) and self.num_scales >= 1
- self._lowres_supervision = lowres_supervision
- assert isinstance(self._padding_kwargs, dict)
- assert "mode" in self._padding_kwargs
-
- for _ in range(1, self.num_scales):
- shape = self._scaled_data[-1].shape
- assert len(shape) == 4
- new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
- ds_data = resize(
- self._scaled_data[-1].astype(np.float32), new_shape
- ).astype(self._scaled_data[-1].dtype)
- # NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
- assert (
- ds_data.max() / self._scaled_data[-1].max() < 5
- ), "Downsampled image should not have very different values"
- assert (
- ds_data.max() / self._scaled_data[-1].max() > 0.2
- ), "Downsampled image should not have very different values"
-
- self._scaled_data.append(ds_data)
- # do the same for noise
- if self._noise_data is not None:
- noise_data = resize(self._scaled_noise_data[-1], new_shape)
- self._scaled_noise_data.append(noise_data)
-
- def _init_msg(self):
- msg = super()._init_msg()
- msg += f" Pad:{self._padding_kwargs}"
- return msg
-
- def _load_scaled_img(
- self, scaled_index, index: Union[int, Tuple[int, int]]
- ) -> Tuple[np.ndarray, np.ndarray]:
- if isinstance(index, int):
- idx = index
- else:
- idx, _ = index
- imgs = self._scaled_data[scaled_index][idx % self.N]
- imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])])
- if self._noise_data is not None:
- noisedata = self._scaled_noise_data[scaled_index][idx % self.N]
- noise = tuple(
- [noisedata[None, :, :, i] for i in range(noisedata.shape[-1])]
- )
- factor = np.sqrt(2) if self._input_is_sum else 1.0
- # since we are using this lowres images for just the input, we need to add the noise of the input.
- assert self._lowres_supervision is None or self._lowres_supervision is False
- imgs = tuple([img + noise[0] * factor for img in imgs])
- return imgs
-
- def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
- """
- Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
- the cropped image will be smaller than self._img_sz * self._img_sz
- """
- return self._crop_img_with_padding(img, h_start, w_start)
-
- def _get_img(self, index: int):
- """
- Returns the primary patch along with low resolution patches centered on the primary patch.
- """
- img_tuples, noise_tuples = self._load_img(index)
- assert self._img_sz is not None
- h, w = img_tuples[0].shape[-2:]
- if self._enable_random_cropping:
- h_start, w_start = self._get_random_hw(h, w)
- else:
- h_start, w_start = self._get_deterministic_hw(index)
-
- cropped_img_tuples = [
- self._crop_flip_img(img, h_start, w_start, False, False)
- for img in img_tuples
- ]
- cropped_noise_tuples = [
- self._crop_flip_img(noise, h_start, w_start, False, False)
- for noise in noise_tuples
- ]
- h_center = h_start + self._img_sz // 2
- w_center = w_start + self._img_sz // 2
- allres_versions = {
- i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
- }
- for scale_idx in range(1, self.num_scales):
- scaled_img_tuples = self._load_scaled_img(scale_idx, index)
-
- h_center = h_center // 2
- w_center = w_center // 2
-
- h_start = h_center - self._img_sz // 2
- w_start = w_center - self._img_sz // 2
-
- scaled_cropped_img_tuples = [
- self._crop_flip_img(img, h_start, w_start, False, False)
- for img in scaled_img_tuples
- ]
- for ch_idx in range(len(img_tuples)):
- allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
-
- output_img_tuples = tuple(
- [
- np.concatenate(allres_versions[ch_idx])
- for ch_idx in range(len(img_tuples))
- ]
- )
- return output_img_tuples, cropped_noise_tuples
-
- def __getitem__(self, index: Union[int, Tuple[int, int]]):
- if self._uncorrelated_channels:
- img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
- else:
- img_tuples, noise_tuples = self._get_img(index)
-
- if self._enable_rotation:
- img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
-
- assert self._lowres_supervision != True
- # add noise to input
- if len(noise_tuples) > 0:
- factor = np.sqrt(2) if self._input_is_sum else 1.0
- input_tuples = []
- for x in img_tuples:
- # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
- x[0] = x[0] + noise_tuples[0] * factor
- input_tuples.append(x)
- else:
- input_tuples = img_tuples
-
- inp, alpha = self._compute_input(input_tuples)
- # assert self._alpha_weighted_target in [False, None]
- target_tuples = [img[:1] for img in img_tuples]
- # add noise to target.
- if len(noise_tuples) >= 1:
- target_tuples = [
- x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
- ]
-
- target = self._compute_target(target_tuples, alpha)
-
- output = [inp, target]
-
- if self._return_alpha:
- output.append(alpha)
-
- if isinstance(index, int):
- return tuple(output)
-
- _, grid_size = index
- output.append(grid_size)
  return tuple(output)