careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (56)
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bioimage/bioimage_utils.py +4 -2
  39. careamics/model_io/bmz_io.py +6 -5
  40. careamics/models/lvae/likelihoods.py +18 -9
  41. careamics/models/lvae/lvae.py +12 -16
  42. careamics/models/lvae/noise_models.py +1 -1
  43. careamics/transforms/compose.py +90 -15
  44. careamics/transforms/n2v_manipulate.py +6 -2
  45. careamics/transforms/normalize.py +14 -3
  46. careamics/transforms/xy_flip.py +16 -6
  47. careamics/transforms/xy_random_rotate90.py +16 -7
  48. careamics/utils/metrics.py +204 -24
  49. careamics/utils/serializers.py +60 -0
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
  51. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
  52. careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
  53. careamics/lvae_training/dataset/data_utils.py +0 -701
  54. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/dataset/{vae_data_config.py → config.py}
@@ -1,63 +1,13 @@
  from typing import Any, Optional
- from enum import Enum
-
- from pydantic import BaseModel, ConfigDict, computed_field
-
-
- # TODO: get rid of unnecessary enums
- class DataType(Enum):
-     MNIST = 0
-     Places365 = 1
-     NotMNIST = 2
-     OptiMEM100_014 = 3
-     CustomSinosoid = 4
-     Prevedel_EMBL = 5
-     AllenCellMito = 6
-     SeparateTiffData = 7
-     CustomSinosoidThreeCurve = 8
-     SemiSupBloodVesselsEMBL = 9
-     Pavia2 = 10
-     Pavia2VanillaSplitting = 11
-     ExpansionMicroscopyMitoTub = 12
-     ShroffMitoEr = 13
-     HTIba1Ki67 = 14
-     BSD68 = 15
-     BioSR_MRC = 16
-     TavernaSox2Golgi = 17
-     Dao3Channel = 18
-     ExpMicroscopyV2 = 19
-     Dao3ChannelWithInput = 20
-     TavernaSox2GolgiV2 = 21
-     TwoDset = 22
-     PredictedTiffData = 23
-     Pavia3SeqData = 24
-     # Here, we have 16 splitting tasks.
-     NicolaData = 25
-
-
- class DataSplitType(Enum):
-     All = 0
-     Train = 1
-     Val = 2
-     Test = 3
-
-
- class GridAlignement(Enum):
-     """
-     A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to done equally on all 4 sides.
-     On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid.
-     In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame.
-     In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame.
-     """
-
-     LeftTop = 0
-     Center = 1
-
-
- # TODO: for all bool params check if they are taking different values in Disentangle repo
+
+ from pydantic import BaseModel, ConfigDict
+
+ from .types import DataType, DataSplitType, TilingMode
+
+
  # TODO: check if any bool logic can be removed
- class VaeDatasetConfig(BaseModel):
-     model_config = ConfigDict(validate_assignment=True)
+ class DatasetConfig(BaseModel):
+     model_config = ConfigDict(validate_assignment=True, extra="forbid")

      data_type: Optional[DataType]
      """Type of the dataset, should be one of DataType"""
@@ -132,15 +82,10 @@ class VaeDatasetConfig(BaseModel):
      # TODO: why is this not used?
      enable_rotation_aug: Optional[bool] = False

-     grid_alignment: GridAlignement = GridAlignement.LeftTop
-
      max_val: Optional[float] = None
      """Maximum data in the dataset. Is calculated for train split, and should be
      externally set for val and test splits."""

-     trim_boundary: Optional[bool] = True
-     """Whether to trim boundary of the image"""
-
      overlapping_padding_kwargs: Any = None
      """Parameters for np.pad method"""

@@ -157,23 +102,22 @@ class VaeDatasetConfig(BaseModel):
      train_aug_rotate: Optional[bool] = False
      enable_random_cropping: Optional[bool] = True

-     # TODO: not used?
      multiscale_lowres_count: Optional[int] = None
+     """Number of LC scales"""
+
+     tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary
+
+     target_separate_normalization: Optional[bool] = True
+
+     mode_3D: Optional[bool] = False
+     """If training in 3D mode or not"""
+
+     trainig_datausage_fraction: Optional[float] = 1.0
+
+     validtarget_random_fraction: Optional[float] = None
+
+     validation_datausage_fraction: Optional[float] = 1.0
+
+     random_flip_z_3D: Optional[bool] = False

-     @computed_field
-     @property
-     def padding_kwargs(self) -> dict:
-         kwargs_dict = {}
-         padding_kwargs = {}
-         if (
-             self.multiscale_lowres_count is not None
-             and self.multiscale_lowres_count is not None
-         ):
-             # Get padding attributes
-             if "padding_kwargs" not in kwargs_dict:
-                 padding_kwargs = {}
-                 padding_kwargs["mode"] = "constant"
-                 padding_kwargs["constant_values"] = 0
-             else:
-                 padding_kwargs = kwargs_dict.pop("padding_kwargs")
-         return padding_kwargs
+     padding_kwargs: Optional[dict] = None
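Since padding_kwargs is now a plain optional field rather than a computed property, callers that use multiscale (LC) inputs presumably have to supply the padding parameters themselves. A hedged sketch of reproducing the old implicit default explicitly (whether any other fields are required here is not shown in this diff):

# The removed computed_field returned {"mode": "constant", "constant_values": 0}
# whenever multiscale_lowres_count was set; that dict must now be passed by hand.
padding_kwargs = {"mode": "constant", "constant_values": 0}  # forwarded to np.pad

config_kwargs = dict(
    multiscale_lowres_count=3,      # number of LC scales
    padding_kwargs=padding_kwargs,  # previously derived automatically
)
# e.g. DatasetConfig(..., **config_kwargs) together with the remaining required fields.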
careamics/lvae_training/dataset/lc_dataset.py
@@ -2,34 +2,37 @@
  A place for Datasets and Dataloaders.
  """

- from typing import Tuple, Union
+ from typing import Tuple, Union, Callable

  import numpy as np
  from skimage.transform import resize

- from .lc_dataset_config import LCVaeDatasetConfig
- from .vae_dataset import MultiChDloader
+ from .config import DatasetConfig
+ from .multich_dataset import MultiChDloader


  class LCMultiChDloader(MultiChDloader):
-
      def __init__(
          self,
-         data_config: LCVaeDatasetConfig,
+         data_config: DatasetConfig,
          fpath: str,
+         load_data_fn: Callable,
          val_fraction=None,
          test_fraction=None,
      ):
-         """
-         Args:
-             num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
-                 highest resolution.
-         """
          self._padding_kwargs = (
              data_config.padding_kwargs  # mode=padding_mode, constant_values=constant_value
          )
          self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab

+         super().__init__(
+             data_config,
+             fpath,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+
          if data_config.overlapping_padding_kwargs is not None:
              assert (
                  self._padding_kwargs == data_config.overlapping_padding_kwargs
@@ -37,21 +40,21 @@ class LCMultiChDloader(MultiChDloader):
              It should be so since we just use overlapping_padding_kwargs when it is not None"

          else:
-             overlapping_padding_kwargs = data_config.padding_kwargs
+             self._overlapping_padding_kwargs = data_config.padding_kwargs

-         super().__init__(
-             data_config, fpath, val_fraction=val_fraction, test_fraction=test_fraction
-         )
-         self.num_scales = data_config.num_scales
-         assert self.num_scales is not None
+         self.multiscale_lowres_count = data_config.multiscale_lowres_count
+         assert self.multiscale_lowres_count is not None
          self._scaled_data = [self._data]
          self._scaled_noise_data = [self._noise_data]

-         assert isinstance(self.num_scales, int) and self.num_scales >= 1
+         assert (
+             isinstance(self.multiscale_lowres_count, int)
+             and self.multiscale_lowres_count >= 1
+         )
          assert isinstance(self._padding_kwargs, dict)
          assert "mode" in self._padding_kwargs

-         for _ in range(1, self.num_scales):
+         for _ in range(1, self.multiscale_lowres_count):
              shape = self._scaled_data[-1].shape
              assert len(shape) == 4
              new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
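The loop above keeps halving H and W to build the lowres (LC) pyramid. A rough standalone sketch of that idea using skimage.transform.resize, which this module imports; this is an illustration of the shape progression, not the dataset's exact downscaling code:

import numpy as np
from skimage.transform import resize

def build_lowres_pyramid(data: np.ndarray, num_levels: int) -> list:
    """data has shape (N, H, W, C); each extra level halves H and W."""
    pyramid = [data]
    for _ in range(1, num_levels):
        n, h, w, c = pyramid[-1].shape
        new_shape = (n, h // 2, w // 2, c)
        pyramid.append(resize(pyramid[-1], new_shape, anti_aliasing=True))
    return pyramid

levels = build_lowres_pyramid(np.random.rand(4, 64, 64, 2), num_levels=3)
print([lvl.shape for lvl in levels])  # [(4, 64, 64, 2), (4, 32, 32, 2), (4, 16, 16, 2)]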
@@ -173,7 +176,7 @@ class LCMultiChDloader(MultiChDloader):
          allres_versions = {
              i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
          }
-         for scale_idx in range(1, self.num_scales):
+         for scale_idx in range(1, self.multiscale_lowres_count):
              # Returning the image of the lower resolution
              scaled_img_tuples = self._load_scaled_img(scale_idx, index)

@@ -227,6 +230,9 @@ class LCMultiChDloader(MultiChDloader):
          factor = np.sqrt(2) if self._input_is_sum else 1.0
          input_tuples = []
          for x in img_tuples:
+             x = (
+                 x.copy()
+             )  # to avoid changing the original image since it is later used for target
              # NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
              x[0] = x[0] + noise_tuples[0] * factor
              input_tuples.append(x)
@@ -246,7 +252,9 @@ class LCMultiChDloader(MultiChDloader):

          target = self._compute_target(target_tuples, alpha)

-         output = [inp, target]
+         norm_target = self.normalize_target(target)
+
+         output = [inp, norm_target]

          if self._return_alpha:
              output.append(alpha)
careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py}
@@ -2,39 +2,39 @@
  A place for Datasets and Dataloaders.
  """

- from typing import Tuple, Union
+ from typing import Tuple, Union, Callable

  import numpy as np

- from .data_utils import (
-     GridIndexManager,
-     IndexSwitcher,
-     get_train_val_data,
- )
- from .vae_data_config import VaeDatasetConfig, DataSplitType, GridAlignement
+ from .utils.empty_patch_fetcher import EmptyPatchFetcher
+ from .utils.index_manager import GridIndexManager
+ from .utils.index_switcher import IndexSwitcher
+ from .config import DatasetConfig
+ from .types import DataSplitType, TilingMode


  class MultiChDloader:
      def __init__(
          self,
-         data_config: VaeDatasetConfig,
+         data_config: DatasetConfig,
          fpath: str,
+         load_data_fn: Callable,
          val_fraction: float = None,
          test_fraction: float = None,
      ):
          """ """
          self._data_type = data_config.data_type
          self._fpath = fpath
-         self._data = self.N = self._noise_data = None
+         self._data = self._noise_data = None
          self.Z = 1
-         self._trim_boundary = data_config.trim_boundary
-         # Hardcoded params, not included in the config file.
-
+         self._5Ddata = False
+         self._tiling_mode = data_config.tiling_mode
          # by default, if the noise is present, add it to the input and target.
          self._disable_noise = False  # to add synthetic noise
          self._poisson_noise_factor = None
          self._train_index_switcher = None
          self._depth3D = data_config.depth3D
+         self._mode_3D = data_config.mode_3D
          # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
          self._input_is_sum = data_config.input_is_sum
          self._num_channels = data_config.num_channels
@@ -42,20 +42,21 @@ class MultiChDloader:
          self._tar_idx_list = data_config.target_idx_list

          if data_config.datasplit_type == DataSplitType.Train:
-             self._datausage_fraction = 1.0
+             self._datausage_fraction = data_config.trainig_datausage_fraction
              # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
-             self._validtarget_rand_fract = None
+             self._validtarget_rand_fract = data_config.validtarget_random_fraction
              # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
              # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
              # self._idx_count = 0
          elif data_config.datasplit_type == DataSplitType.Val:
-             self._datausage_fraction = 1.0
+             self._datausage_fraction = data_config.validation_datausage_fraction
          else:
              self._datausage_fraction = 1.0

          self.load_data(
              data_config,
              data_config.datasplit_type,
+             load_data_fn=load_data_fn,
              val_fraction=val_fraction,
              test_fraction=test_fraction,
              allow_generation=data_config.allow_generation,
@@ -70,18 +71,8 @@ class MultiChDloader:

          self._background_values = None

-         self._grid_alignment = data_config.grid_alignment
          self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
-         if self._grid_alignment == GridAlignement.LeftTop:
-             assert (
-                 self._overlapping_padding_kwargs is None
-                 or data_config.multiscale_lowres_count is not None
-             ), "Padding is not used with this alignement style"
-         elif self._grid_alignment == GridAlignement.Center:
-             assert (
-                 self._overlapping_padding_kwargs is not None
-             ), "With Center grid alignment, padding is needed."
-         if self._trim_boundary:
+         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
              if (
                  self._overlapping_padding_kwargs is None
                  or data_config.multiscale_lowres_count is not None
@@ -144,7 +135,6 @@ class MultiChDloader:
              )
              data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
              # NOTE: This is on the raw data. So, it must be called before removing the background.
-             # TODO: missing import, needs fixing asap!
              self._empty_patch_fetcher = EmptyPatchFetcher(
                  self.idx_manager,
                  self._img_sz,
@@ -161,14 +151,18 @@ class MultiChDloader:
          self._mean = None
          self._std = None
          self._use_one_mu_std = data_config.use_one_mu_std
-         # Hardcoded
-         self._target_separate_normalization = True
+
+         self._target_separate_normalization = data_config.target_separate_normalization

          self._enable_rotation = data_config.enable_rotation_aug
+         flipz_3D = data_config.random_flip_z_3D
+         self._flipz_3D = flipz_3D and self._enable_rotation
+
          self._enable_random_cropping = data_config.enable_random_cropping
          self._uncorrelated_channels = (
              data_config.uncorrelated_channels and self._is_train
          )
+         self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
          assert self._is_train or self._uncorrelated_channels is False
          assert (
              self._enable_random_cropping is True or self._uncorrelated_channels is False
@@ -177,9 +171,9 @@ class MultiChDloader:

          self._rotation_transform = None
          if self._enable_rotation:
-             raise NotImplementedError(
-                 "Augmentation by means of rotation is not supported yet."
-             )
+             # TODO: fix this import
+             import albumentations as A
+
              self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])

          # TODO: remove print log messages
@@ -203,11 +197,12 @@ class MultiChDloader:
          self,
          data_config,
          datasplit_type,
+         load_data_fn: Callable,
          val_fraction=None,
          test_fraction=None,
          allow_generation=None,
      ):
-         self._data = get_train_val_data(
+         self._data = load_data_fn(
              data_config,
              self._fpath,
              datasplit_type,
@@ -215,7 +210,9 @@ class MultiChDloader:
              test_fraction=test_fraction,
              allow_generation=allow_generation,
          )
+         self._loaded_data_preprocessing(data_config)

+     def _loaded_data_preprocessing(self, data_config):
          old_shape = self._data.shape
          if self._datausage_fraction < 1.0:
              framepixelcount = np.prod(self._data.shape[1:3])
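Data loading is now injected through load_data_fn instead of the old hard-wired get_train_val_data. Based on the call above, a hedged sketch of a compatible loader; my_load_fn is hypothetical, tifffile is just one possible reader, and the (N, H, W, C) return layout is an assumption:

import numpy as np
import tifffile  # assumption: any reader that yields a numpy array works

def my_load_fn(data_config, fpath, datasplit_type, **kwargs):
    """Accepts the positional args shown above plus the forwarded keywords
    (val_fraction, test_fraction, allow_generation) via **kwargs."""
    stack = tifffile.imread(fpath)  # e.g. (N, H, W, C)
    return stack.astype(np.float32)

# dataset = MultiChDloader(data_config, "/path/to/data.tif",
#                          load_data_fn=my_load_fn,
#                          val_fraction=0.1, test_fraction=0.1)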
@@ -239,10 +236,7 @@ class MultiChDloader:
          if data_config.poisson_noise_factor > 0:
              self._poisson_noise_factor = data_config.poisson_noise_factor
              msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
-             self._data = (
-                 np.random.poisson(self._data / self._poisson_noise_factor)
-                 * self._poisson_noise_factor
-             )
+             self._data = np.random.poisson(self._data / self._poisson_noise_factor)

          if data_config.enable_gaussian_noise:
              synthetic_scale = data_config.synthetic_gaussian_scale
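The rescaling step (* self._poisson_noise_factor) has been dropped, so the Poisson-corrupted data now stays on the divided intensity scale instead of being mapped back to the original range. A small numpy illustration of the two variants (the factor value is arbitrary):

import numpy as np

clean = np.full((2, 2), 1000.0)  # toy clean intensities
factor = 100.0

old = np.random.poisson(clean / factor) * factor  # roughly back on the ~1000 scale
new = np.random.poisson(clean / factor)           # stays on the ~10 scale

print(old.mean(), new.mean())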
@@ -257,7 +251,13 @@ class MultiChDloader:
              self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
          print(msg)

-         self._5Ddata = len(self._data.shape) == 5
+         if len(self._data.shape) == 5:
+             if self._mode_3D:
+                 self._5Ddata = True
+             else:
+                 assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
+                 self._data = self._data.reshape(-1, *self._data.shape[2:])
+
          if self._5Ddata:
              self.Z = self._data.shape[1]

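When a 5D stack is loaded but mode_3D is off, the Z dimension is folded into the batch dimension. A quick shape check of what reshape(-1, *shape[2:]) does (the shapes are illustrative):

import numpy as np

data_5d = np.zeros((4, 16, 128, 128, 2))          # (N, Z, H, W, C)
data_2d = data_5d.reshape(-1, *data_5d.shape[2:])
print(data_2d.shape)  # (64, 128, 128, 2): every Z slice becomes its own 2D sample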
@@ -373,18 +373,28 @@ class MultiChDloader:
              f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
          )

-     def get_idx_manager_shapes(self, patch_size: int, grid_size: int):
+     def get_idx_manager_shapes(
+         self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
+     ):
          numC = self._data.shape[-1]
          if self._5Ddata:
-             grid_shape = (1, 1, grid_size, grid_size, numC)
              patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
+             if isinstance(grid_size, int):
+                 grid_shape = (1, 1, grid_size, grid_size, numC)
+             else:
+                 assert len(grid_size) == 3
+                 assert all(
+                     [g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
+                 ), f"Grid size {grid_size} must be less than patch size {patch_shape[1:-1]}"
+                 grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], numC)
          else:
+             assert isinstance(grid_size, int)
              grid_shape = (1, grid_size, grid_size, numC)
              patch_shape = (1, patch_size, patch_size, numC)

          return patch_shape, grid_shape

-     def set_img_sz(self, image_size, grid_size):
+     def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
          """
          If one wants to change the image size on the go, then this can be used.
          Args:
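grid_size can now be a single int or a per-axis (z, y, x) triple for 5D data. A hedged usage sketch of the resulting shapes, assuming a loader instance dset with _depth3D = 8 and 2-channel 5D data (both names and values are assumptions):

# int grid size: the grid stays a single z-plane
# patch_shape == (1, 8, 64, 64, 2), grid_shape == (1, 1, 32, 32, 2)
patch_shape, grid_shape = dset.get_idx_manager_shapes(patch_size=64, grid_size=32)

# tuple grid size: one entry per (z, y, x), each at most the patch extent
# patch_shape == (1, 8, 64, 64, 2), grid_shape == (1, 4, 32, 32, 2)
patch_shape, grid_shape = dset.get_idx_manager_shapes(patch_size=64, grid_size=(4, 32, 32))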
@@ -400,7 +410,7 @@ class MultiChDloader:
              self._img_sz, self._grid_sz
          )
          self.idx_manager = GridIndexManager(
-             shape, grid_shape, patch_shape, self._trim_boundary
+             shape, grid_shape, patch_shape, self._tiling_mode
          )
          # self.set_repeat_factor()

@@ -432,10 +442,13 @@ class MultiChDloader:
          dim_sizes = ",".join([str(x) for x in dim_sizes])
          msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
          msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
-         msg += f" TrimB:{self._trim_boundary}"
+         msg += f" TrimB:{self._tiling_mode}"
          # msg += f' NormInp:{self._normalized_input}'
          # msg += f' SingleNorm:{self._use_one_mu_std}'
          msg += f" Rot:{self._enable_rotation}"
+         if self._flipz_3D:
+             msg += f" FlipZ:{self._flipz_3D}"
+
          msg += f" RandCrop:{self._enable_random_cropping}"
          msg += f" Channel:{self._num_channels}"
          # msg += f' Q:{self._quantile}'
@@ -467,7 +480,7 @@ class MultiChDloader:
              patch_start_loc = self._get_random_hw(h, w)
              if self._5Ddata:
                  patch_start_loc = (
-                     np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
+                     np.random.choice(1 + img_tuples[0].shape[-3] - self._depth3D),
                  ) + patch_start_loc
          else:
              patch_start_loc = self._get_deterministic_loc(index)
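The added 1 + fixes the count of valid z start positions: a depth-D window inside a Z-slice stack has Z - D + 1 valid starting indices, so without it the last position could never be sampled and Z == D would make np.random.choice fail. Minimal illustration:

import numpy as np

Z, depth3D = 16, 8
starts_old = Z - depth3D      # 8 -> start index 8 (the last valid one) is never drawn
starts_new = 1 + Z - depth3D  # 9 -> all starts 0..8 are reachable
z0 = np.random.choice(starts_new)  # also works when Z == depth3D (single start, 0)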
@@ -486,7 +499,7 @@ class MultiChDloader:
          )

      def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
-         if self._trim_boundary:
+         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
              # In training, this is used.
              # NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
              # The only benefit this if else loop provides is that it makes it easier to see what happens during training.
@@ -625,6 +638,18 @@ class MultiChDloader:
              normalized_imgs.append(img)
          return tuple(normalized_imgs)

+     def normalize_input(self, x):
+         mean_dict, std_dict = self.get_mean_std()
+         mean_ = mean_dict["input"].mean()
+         std_ = std_dict["input"].mean()
+         return (x - mean_) / std_
+
+     def normalize_target(self, target):
+         mean_dict, std_dict = self.get_mean_std()
+         mean_ = mean_dict["target"].squeeze(0)
+         std_ = std_dict["target"].squeeze(0)
+         return (target - mean_) / std_
+
      def get_grid_size(self):
          return self._grid_sz

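Both helpers are plain z-score normalizations driven by get_mean_std(): the target variant keeps per-channel statistics while the input variant collapses them to scalars. A standalone numpy sketch of the same arithmetic (the array shapes here are assumptions):

import numpy as np

target = np.random.rand(2, 64, 64)               # (C, H, W) patch with 2 target channels
mean_t = np.array([0.4, 0.6]).reshape(2, 1, 1)   # per-channel mean, broadcastable
std_t = np.array([0.1, 0.2]).reshape(2, 1, 1)

norm_target = (target - mean_t) / std_t          # what normalize_target computes

x = target.sum(axis=0)                           # toy "input" image
norm_input = (x - mean_t.mean()) / std_t.mean()  # normalize_input uses scalar statistics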
@@ -906,28 +931,39 @@ class MultiChDloader:
          return rotated_img_tuples, rotated_noise_tuples

      def _rotate(self, img_tuples, noise_tuples):
-         if self._depth3D > 1:
+
+         if self._5Ddata:
              return self._rotate3D(img_tuples, noise_tuples)
          else:
              return self._rotate2D(img_tuples, noise_tuples)

      def _rotate3D(self, img_tuples, noise_tuples):
          img_kwargs = {}
+         # random flip in z direction
+         flip_z = self._flipz_3D and np.random.rand() < 0.5
          for i, img in enumerate(img_tuples):
              for j in range(self._depth3D):
                  for k in range(len(img)):
-                     img_kwargs[f"img{i}_{j}_{k}"] = img[k, j]
+                     if flip_z:
+                         z_idx = self._depth3D - 1 - j
+                     else:
+                         z_idx = j
+                     img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]

          noise_kwargs = {}
          for i, nimg in enumerate(noise_tuples):
              for j in range(self._depth3D):
                  for k in range(len(nimg)):
-                     noise_kwargs[f"noise{i}_{j}_{k}"] = nimg[k, j]
+                     if flip_z:
+                         z_idx = self._depth3D - 1 - j
+                     else:
+                         z_idx = j
+                     noise_kwargs[f"noise{i}_{z_idx}_{k}"] = nimg[k, j]

          keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
          self._rotation_transform.add_targets({k: "image" for k in keys})
          rot_dic = self._rotation_transform(
-             image=img_tuples[0][0], **img_kwargs, **noise_kwargs
+             image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
          )
          rotated_img_tuples = []
          for i, img in enumerate(img_tuples):
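The 3D rotation path relies on albumentations' additional-targets mechanism so every z-slice of every channel receives the identical flip/rotation. A stripped-down sketch of that pattern, assuming an albumentations version that still exposes A.Flip (as used above); the key names are illustrative:

import numpy as np
import albumentations as A

transform = A.Compose([A.Flip(), A.RandomRotate90()])

# Register extra image targets so they share the same sampled transform parameters.
slices = {f"img0_{z}_0": np.random.rand(32, 32) for z in range(4)}
transform.add_targets({name: "image" for name in slices})

out = transform(image=np.random.rand(32, 32), **slices)
rotated = [out[name] for name in slices]  # every slice transformed consistently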
@@ -1006,7 +1042,10 @@ class MultiChDloader:
          if self._train_index_switcher is not None:
              index = self._get_index_from_valid_target_logic(index)

-         if self._uncorrelated_channels:
+         if (
+             self._uncorrelated_channels
+             and np.random.rand() < self._uncorrelated_channel_probab
+         ):
              img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
          else:
              img_tuples, noise_tuples = self._get_img(index)
@@ -1042,8 +1081,9 @@ class MultiChDloader:
          img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

          target = self._compute_target(img_tuples, alpha)
+         norm_target = self.normalize_target(target)

-         output = [inp, target]
+         output = [inp, norm_target]

          if self._return_alpha:
              output.append(alpha)