careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,189 +2,66 @@
|
|
|
2
2
|
A place for Datasets and Dataloaders.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
from typing import Tuple, Union
|
|
5
|
+
from typing import Tuple, Union, Callable
|
|
7
6
|
|
|
8
|
-
# import albumentations as A
|
|
9
|
-
import ml_collections
|
|
10
7
|
import numpy as np
|
|
11
|
-
from skimage.transform import resize
|
|
12
|
-
|
|
13
|
-
from .data_utils import (
|
|
14
|
-
DataSplitType,
|
|
15
|
-
DataType,
|
|
16
|
-
GridAlignement,
|
|
17
|
-
GridIndexManager,
|
|
18
|
-
IndexSwitcher,
|
|
19
|
-
get_datasplit_tuples,
|
|
20
|
-
get_mrc_data,
|
|
21
|
-
load_tiff,
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def get_train_val_data(
|
|
26
|
-
data_config,
|
|
27
|
-
fpath,
|
|
28
|
-
datasplit_type: DataSplitType,
|
|
29
|
-
val_fraction=None,
|
|
30
|
-
test_fraction=None,
|
|
31
|
-
allow_generation=None,
|
|
32
|
-
ignore_specific_datapoints=None,
|
|
33
|
-
):
|
|
34
|
-
"""
|
|
35
|
-
Load the data from the given path and split them in training, validation and test sets.
|
|
36
|
-
|
|
37
|
-
Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
|
|
38
|
-
C is the number of channels.
|
|
39
|
-
"""
|
|
40
|
-
if data_config.data_type == DataType.SeparateTiffData:
|
|
41
|
-
fpath1 = os.path.join(fpath, data_config.ch1_fname)
|
|
42
|
-
fpath2 = os.path.join(fpath, data_config.ch2_fname)
|
|
43
|
-
fpaths = [fpath1, fpath2]
|
|
44
|
-
fpath0 = ""
|
|
45
|
-
if "ch_input_fname" in data_config:
|
|
46
|
-
fpath0 = os.path.join(fpath, data_config.ch_input_fname)
|
|
47
|
-
fpaths = [fpath0] + fpaths
|
|
48
8
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
|
|
55
|
-
if data_config.data_type == DataType.PredictedTiffData:
|
|
56
|
-
assert len(data.shape) == 5 and data.shape[-1] == 1
|
|
57
|
-
data = data[..., 0].copy()
|
|
58
|
-
# data = data[::3].copy()
|
|
59
|
-
# NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
|
|
60
|
-
# to the noise present in the channels and so this is not the way we would get the data.
|
|
61
|
-
# We need to add the noise independently to the input and the target.
|
|
62
|
-
|
|
63
|
-
# if data_config.get('poisson_noise_factor', False):
|
|
64
|
-
# data = np.random.poisson(data)
|
|
65
|
-
# if data_config.get('enable_gaussian_noise', False):
|
|
66
|
-
# synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
|
|
67
|
-
# print('Adding Gaussian noise with scale', synthetic_scale)
|
|
68
|
-
# noise = np.random.normal(0, synthetic_scale, data.shape)
|
|
69
|
-
# data = data + noise
|
|
70
|
-
|
|
71
|
-
if datasplit_type == DataSplitType.All:
|
|
72
|
-
return data.astype(np.float32)
|
|
73
|
-
|
|
74
|
-
train_idx, val_idx, test_idx = get_datasplit_tuples(
|
|
75
|
-
val_fraction, test_fraction, len(data), starting_test=True
|
|
76
|
-
)
|
|
77
|
-
if datasplit_type == DataSplitType.Train:
|
|
78
|
-
return data[train_idx].astype(np.float32)
|
|
79
|
-
elif datasplit_type == DataSplitType.Val:
|
|
80
|
-
return data[val_idx].astype(np.float32)
|
|
81
|
-
elif datasplit_type == DataSplitType.Test:
|
|
82
|
-
return data[test_idx].astype(np.float32)
|
|
83
|
-
|
|
84
|
-
elif data_config.data_type == DataType.BioSR_MRC:
|
|
85
|
-
num_channels = data_config.get("num_channels", 2)
|
|
86
|
-
fpaths = []
|
|
87
|
-
data_list = []
|
|
88
|
-
for i in range(num_channels):
|
|
89
|
-
fpath1 = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
|
|
90
|
-
fpaths.append(fpath1)
|
|
91
|
-
data = get_mrc_data(fpath1)[..., None]
|
|
92
|
-
data_list.append(data)
|
|
93
|
-
|
|
94
|
-
dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
|
|
95
|
-
|
|
96
|
-
msg = ",".join([x[len(dirname) :] for x in fpaths])
|
|
97
|
-
print(
|
|
98
|
-
f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
|
|
99
|
-
)
|
|
100
|
-
N = data_list[0].shape[0]
|
|
101
|
-
for data in data_list:
|
|
102
|
-
N = min(N, data.shape[0])
|
|
103
|
-
|
|
104
|
-
cropped_data = []
|
|
105
|
-
for data in data_list:
|
|
106
|
-
cropped_data.append(data[:N])
|
|
107
|
-
|
|
108
|
-
data = np.concatenate(cropped_data, axis=3)
|
|
109
|
-
|
|
110
|
-
if datasplit_type == DataSplitType.All:
|
|
111
|
-
return data.astype(np.float32)
|
|
112
|
-
|
|
113
|
-
train_idx, val_idx, test_idx = get_datasplit_tuples(
|
|
114
|
-
val_fraction, test_fraction, len(data), starting_test=True
|
|
115
|
-
)
|
|
116
|
-
if datasplit_type == DataSplitType.Train:
|
|
117
|
-
return data[train_idx].astype(np.float32)
|
|
118
|
-
elif datasplit_type == DataSplitType.Val:
|
|
119
|
-
return data[val_idx].astype(np.float32)
|
|
120
|
-
elif datasplit_type == DataSplitType.Test:
|
|
121
|
-
return data[test_idx].astype(np.float32)
|
|
9
|
+
from .utils.empty_patch_fetcher import EmptyPatchFetcher
|
|
10
|
+
from .utils.index_manager import GridIndexManager
|
|
11
|
+
from .utils.index_switcher import IndexSwitcher
|
|
12
|
+
from .config import DatasetConfig
|
|
13
|
+
from .types import DataSplitType, TilingMode
|
|
122
14
|
|
|
123
15
|
|
|
124
16
|
class MultiChDloader:
|
|
125
|
-
|
|
126
17
|
def __init__(
|
|
127
18
|
self,
|
|
128
|
-
data_config:
|
|
19
|
+
data_config: DatasetConfig,
|
|
129
20
|
fpath: str,
|
|
130
|
-
|
|
21
|
+
load_data_fn: Callable,
|
|
131
22
|
val_fraction: float = None,
|
|
132
23
|
test_fraction: float = None,
|
|
133
|
-
normalized_input=None,
|
|
134
|
-
enable_rotation_aug: bool = False,
|
|
135
|
-
enable_random_cropping: bool = False,
|
|
136
|
-
use_one_mu_std=None,
|
|
137
|
-
allow_generation: bool = False,
|
|
138
|
-
max_val: float = None,
|
|
139
|
-
grid_alignment=GridAlignement.LeftTop,
|
|
140
|
-
overlapping_padding_kwargs=None,
|
|
141
|
-
print_vars: bool = True,
|
|
142
24
|
):
|
|
143
|
-
"""
|
|
144
|
-
Here, an image is split into grids of size img_sz.
|
|
145
|
-
Args:
|
|
146
|
-
repeat_factor: Since we are doing a random crop, repeat_factor is
|
|
147
|
-
given which can repeatedly sample from the same image. If self.N=12
|
|
148
|
-
and repeat_factor is 5, then index upto 12*5 = 60 is allowed.
|
|
149
|
-
use_one_mu_std: If this is set to true, then one mean and stdev is used
|
|
150
|
-
for both channels. Otherwise, two different meean and stdev are used.
|
|
151
|
-
|
|
152
|
-
"""
|
|
25
|
+
""" """
|
|
153
26
|
self._data_type = data_config.data_type
|
|
154
27
|
self._fpath = fpath
|
|
155
|
-
self._data = self.
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
28
|
+
self._data = self._noise_data = None
|
|
29
|
+
self.Z = 1
|
|
30
|
+
self._5Ddata = False
|
|
31
|
+
self._tiling_mode = data_config.tiling_mode
|
|
159
32
|
# by default, if the noise is present, add it to the input and target.
|
|
160
33
|
self._disable_noise = False # to add synthetic noise
|
|
34
|
+
self._poisson_noise_factor = None
|
|
161
35
|
self._train_index_switcher = None
|
|
36
|
+
self._depth3D = data_config.depth3D
|
|
37
|
+
self._mode_3D = data_config.mode_3D
|
|
162
38
|
# NOTE: Input is the sum of the different channels. It is not the average of the different channels.
|
|
163
|
-
self._input_is_sum = data_config.
|
|
164
|
-
self._num_channels = data_config.
|
|
165
|
-
self._input_idx = data_config.
|
|
166
|
-
self._tar_idx_list = data_config.
|
|
39
|
+
self._input_is_sum = data_config.input_is_sum
|
|
40
|
+
self._num_channels = data_config.num_channels
|
|
41
|
+
self._input_idx = data_config.input_idx
|
|
42
|
+
self._tar_idx_list = data_config.target_idx_list
|
|
167
43
|
|
|
168
|
-
if datasplit_type == DataSplitType.Train:
|
|
169
|
-
self._datausage_fraction =
|
|
44
|
+
if data_config.datasplit_type == DataSplitType.Train:
|
|
45
|
+
self._datausage_fraction = data_config.trainig_datausage_fraction
|
|
170
46
|
# assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
|
|
171
|
-
self._validtarget_rand_fract =
|
|
47
|
+
self._validtarget_rand_fract = data_config.validtarget_random_fraction
|
|
172
48
|
# self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
|
|
173
49
|
# self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
|
|
174
50
|
# self._idx_count = 0
|
|
175
|
-
elif datasplit_type == DataSplitType.Val:
|
|
176
|
-
self._datausage_fraction =
|
|
51
|
+
elif data_config.datasplit_type == DataSplitType.Val:
|
|
52
|
+
self._datausage_fraction = data_config.validation_datausage_fraction
|
|
177
53
|
else:
|
|
178
54
|
self._datausage_fraction = 1.0
|
|
179
55
|
|
|
180
56
|
self.load_data(
|
|
181
57
|
data_config,
|
|
182
|
-
datasplit_type,
|
|
58
|
+
data_config.datasplit_type,
|
|
59
|
+
load_data_fn=load_data_fn,
|
|
183
60
|
val_fraction=val_fraction,
|
|
184
61
|
test_fraction=test_fraction,
|
|
185
|
-
allow_generation=allow_generation,
|
|
62
|
+
allow_generation=data_config.allow_generation,
|
|
186
63
|
)
|
|
187
|
-
self._normalized_input = normalized_input
|
|
64
|
+
self._normalized_input = data_config.normalized_input
|
|
188
65
|
self._quantile = 1.0
|
|
189
66
|
self._channelwise_quantile = False
|
|
190
67
|
self._background_quantile = 0.0
|
|
@@ -194,31 +71,29 @@ class MultiChDloader:
|
|
|
194
71
|
|
|
195
72
|
self._background_values = None
|
|
196
73
|
|
|
197
|
-
self.
|
|
198
|
-
self.
|
|
199
|
-
|
|
200
|
-
assert (
|
|
74
|
+
self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
|
|
75
|
+
if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
|
|
76
|
+
if (
|
|
201
77
|
self._overlapping_padding_kwargs is None
|
|
202
78
|
or data_config.multiscale_lowres_count is not None
|
|
203
|
-
)
|
|
204
|
-
|
|
79
|
+
):
|
|
80
|
+
# raise warning
|
|
81
|
+
print("Padding is not used with this alignement style")
|
|
82
|
+
else:
|
|
205
83
|
assert (
|
|
206
84
|
self._overlapping_padding_kwargs is not None
|
|
207
|
-
), "
|
|
85
|
+
), "When not trimming boudnary, padding is needed."
|
|
208
86
|
|
|
209
|
-
self._is_train = datasplit_type == DataSplitType.Train
|
|
87
|
+
self._is_train = data_config.datasplit_type == DataSplitType.Train
|
|
210
88
|
|
|
211
89
|
# input = alpha * ch1 + (1-alpha)*ch2.
|
|
212
90
|
# alpha is sampled randomly between these two extremes
|
|
213
|
-
self._start_alpha_arr = self._end_alpha_arr = self._return_alpha =
|
|
214
|
-
self._alpha_weighted_target
|
|
215
|
-
) = None
|
|
91
|
+
self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None
|
|
216
92
|
|
|
217
93
|
self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
|
|
218
94
|
if self._is_train:
|
|
219
|
-
self._start_alpha_arr =
|
|
220
|
-
self._end_alpha_arr =
|
|
221
|
-
self._alpha_weighted_target = False
|
|
95
|
+
self._start_alpha_arr = data_config.start_alpha
|
|
96
|
+
self._end_alpha_arr = data_config.end_alpha
|
|
222
97
|
|
|
223
98
|
self.set_img_sz(
|
|
224
99
|
data_config.image_size,
|
|
@@ -229,11 +104,13 @@ class MultiChDloader:
|
|
|
229
104
|
),
|
|
230
105
|
)
|
|
231
106
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
107
|
+
if self._validtarget_rand_fract is not None:
|
|
108
|
+
self._train_index_switcher = IndexSwitcher(
|
|
109
|
+
self.idx_manager, data_config, self._img_sz
|
|
110
|
+
)
|
|
235
111
|
|
|
236
112
|
else:
|
|
113
|
+
|
|
237
114
|
self.set_img_sz(
|
|
238
115
|
data_config.image_size,
|
|
239
116
|
(
|
|
@@ -246,33 +123,46 @@ class MultiChDloader:
|
|
|
246
123
|
self._return_alpha = False
|
|
247
124
|
self._return_index = False
|
|
248
125
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
126
|
+
self._empty_patch_replacement_enabled = (
|
|
127
|
+
data_config.empty_patch_replacement_enabled and self._is_train
|
|
128
|
+
)
|
|
129
|
+
if self._empty_patch_replacement_enabled:
|
|
130
|
+
self._empty_patch_replacement_channel_idx = (
|
|
131
|
+
data_config.empty_patch_replacement_channel_idx
|
|
132
|
+
)
|
|
133
|
+
self._empty_patch_replacement_probab = (
|
|
134
|
+
data_config.empty_patch_replacement_probab
|
|
135
|
+
)
|
|
136
|
+
data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
|
|
137
|
+
# NOTE: This is on the raw data. So, it must be called before removing the background.
|
|
138
|
+
self._empty_patch_fetcher = EmptyPatchFetcher(
|
|
139
|
+
self.idx_manager,
|
|
140
|
+
self._img_sz,
|
|
141
|
+
data_frames,
|
|
142
|
+
max_val_threshold=data_config.empty_patch_max_val_threshold,
|
|
143
|
+
)
|
|
260
144
|
|
|
261
|
-
self.rm_bkground_set_max_val_and_upperclip_data(
|
|
145
|
+
self.rm_bkground_set_max_val_and_upperclip_data(
|
|
146
|
+
data_config.max_val, data_config.datasplit_type
|
|
147
|
+
)
|
|
262
148
|
|
|
263
149
|
# For overlapping dloader, image_size and repeat_factors are not related. hence a different function.
|
|
264
150
|
|
|
265
151
|
self._mean = None
|
|
266
152
|
self._std = None
|
|
267
|
-
self._use_one_mu_std = use_one_mu_std
|
|
268
|
-
|
|
269
|
-
self._target_separate_normalization =
|
|
153
|
+
self._use_one_mu_std = data_config.use_one_mu_std
|
|
154
|
+
|
|
155
|
+
self._target_separate_normalization = data_config.target_separate_normalization
|
|
156
|
+
|
|
157
|
+
self._enable_rotation = data_config.enable_rotation_aug
|
|
158
|
+
flipz_3D = data_config.random_flip_z_3D
|
|
159
|
+
self._flipz_3D = flipz_3D and self._enable_rotation
|
|
270
160
|
|
|
271
|
-
self.
|
|
272
|
-
self._enable_random_cropping = enable_random_cropping
|
|
161
|
+
self._enable_random_cropping = data_config.enable_random_cropping
|
|
273
162
|
self._uncorrelated_channels = (
|
|
274
|
-
data_config.
|
|
163
|
+
data_config.uncorrelated_channels and self._is_train
|
|
275
164
|
)
|
|
165
|
+
self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
|
|
276
166
|
assert self._is_train or self._uncorrelated_channels is False
|
|
277
167
|
assert (
|
|
278
168
|
self._enable_random_cropping is True or self._uncorrelated_channels is False
|
|
@@ -281,14 +171,15 @@ class MultiChDloader:
|
|
|
281
171
|
|
|
282
172
|
self._rotation_transform = None
|
|
283
173
|
if self._enable_rotation:
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
174
|
+
# TODO: fix this import
|
|
175
|
+
import albumentations as A
|
|
176
|
+
|
|
287
177
|
self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
|
|
288
178
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
179
|
+
# TODO: remove print log messages
|
|
180
|
+
# if print_vars:
|
|
181
|
+
# msg = self._init_msg()
|
|
182
|
+
# print(msg)
|
|
292
183
|
|
|
293
184
|
def disable_noise(self):
|
|
294
185
|
assert (
|
|
@@ -306,11 +197,12 @@ class MultiChDloader:
|
|
|
306
197
|
self,
|
|
307
198
|
data_config,
|
|
308
199
|
datasplit_type,
|
|
200
|
+
load_data_fn: Callable,
|
|
309
201
|
val_fraction=None,
|
|
310
202
|
test_fraction=None,
|
|
311
203
|
allow_generation=None,
|
|
312
204
|
):
|
|
313
|
-
self._data =
|
|
205
|
+
self._data = load_data_fn(
|
|
314
206
|
data_config,
|
|
315
207
|
self._fpath,
|
|
316
208
|
datasplit_type,
|
|
@@ -318,7 +210,9 @@ class MultiChDloader:
|
|
|
318
210
|
test_fraction=test_fraction,
|
|
319
211
|
allow_generation=allow_generation,
|
|
320
212
|
)
|
|
213
|
+
self._loaded_data_preprocessing(data_config)
|
|
321
214
|
|
|
215
|
+
def _loaded_data_preprocessing(self, data_config):
|
|
322
216
|
old_shape = self._data.shape
|
|
323
217
|
if self._datausage_fraction < 1.0:
|
|
324
218
|
framepixelcount = np.prod(self._data.shape[1:3])
|
|
@@ -339,28 +233,37 @@ class MultiChDloader:
|
|
|
339
233
|
)
|
|
340
234
|
|
|
341
235
|
msg = ""
|
|
342
|
-
if data_config.
|
|
236
|
+
if data_config.poisson_noise_factor > 0:
|
|
343
237
|
self._poisson_noise_factor = data_config.poisson_noise_factor
|
|
344
238
|
msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
|
|
345
|
-
self._data = (
|
|
346
|
-
np.random.poisson(self._data / self._poisson_noise_factor)
|
|
347
|
-
* self._poisson_noise_factor
|
|
348
|
-
)
|
|
239
|
+
self._data = np.random.poisson(self._data / self._poisson_noise_factor)
|
|
349
240
|
|
|
350
|
-
if data_config.
|
|
351
|
-
synthetic_scale = data_config.
|
|
241
|
+
if data_config.enable_gaussian_noise:
|
|
242
|
+
synthetic_scale = data_config.synthetic_gaussian_scale
|
|
352
243
|
msg += f"Adding Gaussian noise with scale {synthetic_scale}"
|
|
353
244
|
# 0 => noise for input. 1: => noise for all targets.
|
|
354
245
|
shape = self._data.shape
|
|
355
246
|
self._noise_data = np.random.normal(
|
|
356
247
|
0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
|
|
357
248
|
)
|
|
358
|
-
if data_config.
|
|
249
|
+
if data_config.input_has_dependant_noise:
|
|
359
250
|
msg += ". Moreover, input has dependent noise"
|
|
360
251
|
self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
|
|
361
252
|
print(msg)
|
|
362
253
|
|
|
363
|
-
|
|
254
|
+
if len(self._data.shape) == 5:
|
|
255
|
+
if self._mode_3D:
|
|
256
|
+
self._5Ddata = True
|
|
257
|
+
else:
|
|
258
|
+
assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
|
|
259
|
+
self._data = self._data.reshape(-1, *self._data.shape[2:])
|
|
260
|
+
|
|
261
|
+
if self._5Ddata:
|
|
262
|
+
self.Z = self._data.shape[1]
|
|
263
|
+
|
|
264
|
+
if self._depth3D > 1:
|
|
265
|
+
assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"
|
|
266
|
+
|
|
364
267
|
assert (
|
|
365
268
|
self._data.shape[-1] == self._num_channels
|
|
366
269
|
), "Number of channels in data and config do not match."
|
|
@@ -441,9 +344,13 @@ class MultiChDloader:
|
|
|
441
344
|
def get_img_sz(self):
|
|
442
345
|
return self._img_sz
|
|
443
346
|
|
|
347
|
+
def get_num_frames(self):
|
|
348
|
+
return self._data.shape[0]
|
|
349
|
+
|
|
444
350
|
def reduce_data(
|
|
445
351
|
self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
|
|
446
352
|
):
|
|
353
|
+
assert not self._5Ddata, "This function is not supported for 3D data."
|
|
447
354
|
if t_list is None:
|
|
448
355
|
t_list = list(range(self._data.shape[0]))
|
|
449
356
|
if h_start is None:
|
|
@@ -461,25 +368,56 @@ class MultiChDloader:
|
|
|
461
368
|
t_list, h_start:h_end, w_start:w_end, :
|
|
462
369
|
].copy()
|
|
463
370
|
|
|
464
|
-
self.N = len(t_list)
|
|
465
371
|
self.set_img_sz(self._img_sz, self._grid_sz)
|
|
466
372
|
print(
|
|
467
373
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
468
374
|
)
|
|
469
375
|
|
|
470
|
-
def
|
|
376
|
+
def get_idx_manager_shapes(
|
|
377
|
+
self, patch_size: int, grid_size: Union[int, Tuple[int, int, int]]
|
|
378
|
+
):
|
|
379
|
+
numC = self._data.shape[-1]
|
|
380
|
+
if self._5Ddata:
|
|
381
|
+
patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
|
|
382
|
+
if isinstance(grid_size, int):
|
|
383
|
+
grid_shape = (1, 1, grid_size, grid_size, numC)
|
|
384
|
+
else:
|
|
385
|
+
assert len(grid_size) == 3
|
|
386
|
+
assert all(
|
|
387
|
+
[g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
|
|
388
|
+
), f"Grid size {grid_size} must be less than patch size {patch_shape[1:-1]}"
|
|
389
|
+
grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], numC)
|
|
390
|
+
else:
|
|
391
|
+
assert isinstance(grid_size, int)
|
|
392
|
+
grid_shape = (1, grid_size, grid_size, numC)
|
|
393
|
+
patch_shape = (1, patch_size, patch_size, numC)
|
|
394
|
+
|
|
395
|
+
return patch_shape, grid_shape
|
|
396
|
+
|
|
397
|
+
def set_img_sz(self, image_size, grid_size: Union[int, Tuple[int, int, int]]):
|
|
471
398
|
"""
|
|
472
399
|
If one wants to change the image size on the go, then this can be used.
|
|
473
400
|
Args:
|
|
474
401
|
image_size: size of one patch
|
|
475
402
|
grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
|
|
476
403
|
"""
|
|
404
|
+
|
|
477
405
|
self._img_sz = image_size
|
|
478
406
|
self._grid_sz = grid_size
|
|
407
|
+
shape = self._data.shape
|
|
408
|
+
|
|
409
|
+
patch_shape, grid_shape = self.get_idx_manager_shapes(
|
|
410
|
+
self._img_sz, self._grid_sz
|
|
411
|
+
)
|
|
479
412
|
self.idx_manager = GridIndexManager(
|
|
480
|
-
|
|
413
|
+
shape, grid_shape, patch_shape, self._tiling_mode
|
|
481
414
|
)
|
|
482
|
-
self.set_repeat_factor()
|
|
415
|
+
# self.set_repeat_factor()
|
|
416
|
+
|
|
417
|
+
def __len__(self):
|
|
418
|
+
# Vera: N is the number of frames in Z stack
|
|
419
|
+
# Repeat factor is n_rows * n_cols
|
|
420
|
+
return self.idx_manager.total_grid_count()
|
|
483
421
|
|
|
484
422
|
def set_repeat_factor(self):
|
|
485
423
|
if self._grid_sz > 1:
|
|
@@ -497,10 +435,20 @@ class MultiChDloader:
|
|
|
497
435
|
msg = (
|
|
498
436
|
f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
|
|
499
437
|
)
|
|
438
|
+
dim_sizes = [
|
|
439
|
+
self.idx_manager.get_individual_dim_grid_count(dim)
|
|
440
|
+
for dim in range(len(self._data.shape))
|
|
441
|
+
]
|
|
442
|
+
dim_sizes = ",".join([str(x) for x in dim_sizes])
|
|
500
443
|
msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
|
|
444
|
+
msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
|
|
445
|
+
msg += f" TrimB:{self._tiling_mode}"
|
|
501
446
|
# msg += f' NormInp:{self._normalized_input}'
|
|
502
447
|
# msg += f' SingleNorm:{self._use_one_mu_std}'
|
|
503
448
|
msg += f" Rot:{self._enable_rotation}"
|
|
449
|
+
if self._flipz_3D:
|
|
450
|
+
msg += f" FlipZ:{self._flipz_3D}"
|
|
451
|
+
|
|
504
452
|
msg += f" RandCrop:{self._enable_random_cropping}"
|
|
505
453
|
msg += f" Channel:{self._num_channels}"
|
|
506
454
|
# msg += f' Q:{self._quantile}'
|
|
@@ -529,40 +477,52 @@ class MultiChDloader:
|
|
|
529
477
|
)
|
|
530
478
|
|
|
531
479
|
if self._enable_random_cropping:
|
|
532
|
-
|
|
480
|
+
patch_start_loc = self._get_random_hw(h, w)
|
|
481
|
+
if self._5Ddata:
|
|
482
|
+
patch_start_loc = (
|
|
483
|
+
np.random.choice(1 + img_tuples[0].shape[-3] - self._depth3D),
|
|
484
|
+
) + patch_start_loc
|
|
533
485
|
else:
|
|
534
|
-
|
|
486
|
+
patch_start_loc = self._get_deterministic_loc(index)
|
|
535
487
|
|
|
536
488
|
cropped_imgs = []
|
|
537
489
|
for img in img_tuples:
|
|
538
|
-
img = self._crop_flip_img(img,
|
|
490
|
+
img = self._crop_flip_img(img, patch_start_loc, False, False)
|
|
539
491
|
cropped_imgs.append(img)
|
|
540
492
|
|
|
541
493
|
return (
|
|
542
494
|
*tuple(cropped_imgs),
|
|
543
495
|
{
|
|
544
|
-
"h": [h_start, h_start + self._img_sz],
|
|
545
|
-
"w": [w_start, w_start + self._img_sz],
|
|
546
496
|
"hflip": False,
|
|
547
497
|
"wflip": False,
|
|
548
498
|
},
|
|
549
499
|
)
|
|
550
500
|
|
|
551
|
-
def _crop_img(self, img: np.ndarray,
|
|
552
|
-
if self.
|
|
501
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
|
|
502
|
+
if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
|
|
553
503
|
# In training, this is used.
|
|
554
504
|
# NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
|
|
555
505
|
# The only benefit this if else loop provides is that it makes it easier to see what happens during training.
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
506
|
+
patch_end_loc = (
|
|
507
|
+
np.array(patch_start_loc, dtype=np.int32)
|
|
508
|
+
+ self.idx_manager.patch_shape[1:-1]
|
|
509
|
+
)
|
|
510
|
+
if self._5Ddata:
|
|
511
|
+
z_start, h_start, w_start = patch_start_loc
|
|
512
|
+
z_end, h_end, w_end = patch_end_loc
|
|
513
|
+
new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
|
|
514
|
+
else:
|
|
515
|
+
h_start, w_start = patch_start_loc
|
|
516
|
+
h_end, w_end = patch_end_loc
|
|
517
|
+
new_img = img[..., h_start:h_end, w_start:w_end]
|
|
518
|
+
|
|
559
519
|
return new_img
|
|
560
|
-
|
|
520
|
+
else:
|
|
561
521
|
# During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame
|
|
562
522
|
# In these situations, we need some sort of padding. This is not needed in the LeftTop alignement.
|
|
563
|
-
return self._crop_img_with_padding(img,
|
|
523
|
+
return self._crop_img_with_padding(img, patch_start_loc)
|
|
564
524
|
|
|
565
|
-
def get_begin_end_padding(self, start_pos, max_len):
|
|
525
|
+
def get_begin_end_padding(self, start_pos, end_pos, max_len):
|
|
566
526
|
"""
|
|
567
527
|
The effect is that the image with size self._grid_sz is in the center of the patch with sufficient
|
|
568
528
|
padding on all four sides so that the final patch size is self._img_sz.
|
|
@@ -572,44 +532,56 @@ class MultiChDloader:
|
|
|
572
532
|
if start_pos < 0:
|
|
573
533
|
pad_start = -1 * start_pos
|
|
574
534
|
|
|
575
|
-
pad_end = max(0,
|
|
535
|
+
pad_end = max(0, end_pos - max_len)
|
|
576
536
|
|
|
577
537
|
return pad_start, pad_end
|
|
578
538
|
|
|
579
|
-
def _crop_img_with_padding(
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
539
|
+
def _crop_img_with_padding(
|
|
540
|
+
self, img: np.ndarray, patch_start_loc, max_len_vals=None
|
|
541
|
+
):
|
|
542
|
+
if max_len_vals is None:
|
|
543
|
+
max_len_vals = self.idx_manager.data_shape[1:-1]
|
|
544
|
+
patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
|
|
545
|
+
self.idx_manager.patch_shape[1:-1], dtype=int
|
|
546
|
+
)
|
|
547
|
+
boundary_crossed = []
|
|
548
|
+
valid_slice = []
|
|
549
|
+
padding = [[0, 0]]
|
|
550
|
+
for start_idx, end_idx, max_len in zip(
|
|
551
|
+
patch_start_loc, patch_end_loc, max_len_vals
|
|
552
|
+
):
|
|
553
|
+
boundary_crossed.append(end_idx > max_len or start_idx < 0)
|
|
554
|
+
valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
|
|
555
|
+
pad = [0, 0]
|
|
556
|
+
if boundary_crossed[-1]:
|
|
557
|
+
pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
|
|
558
|
+
padding.append(pad)
|
|
589
559
|
# max() is needed since h_start could be negative.
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
560
|
+
if self._5Ddata:
|
|
561
|
+
new_img = img[
|
|
562
|
+
...,
|
|
563
|
+
valid_slice[0][0] : valid_slice[0][1],
|
|
564
|
+
valid_slice[1][0] : valid_slice[1][1],
|
|
565
|
+
valid_slice[2][0] : valid_slice[2][1],
|
|
566
|
+
]
|
|
567
|
+
else:
|
|
568
|
+
new_img = img[
|
|
569
|
+
...,
|
|
570
|
+
valid_slice[0][0] : valid_slice[0][1],
|
|
571
|
+
valid_slice[1][0] : valid_slice[1][1],
|
|
572
|
+
]
|
|
603
573
|
|
|
574
|
+
# print(np.array(padding).shape, img.shape, new_img.shape)
|
|
575
|
+
# print(padding)
|
|
604
576
|
if not np.all(padding == 0):
|
|
605
577
|
new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)
|
|
606
578
|
|
|
607
579
|
return new_img
|
|
608
580
|
|
|
609
581
|
def _crop_flip_img(
|
|
610
|
-
self, img: np.ndarray,
|
|
582
|
+
self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
|
|
611
583
|
):
|
|
612
|
-
new_img = self._crop_img(img,
|
|
584
|
+
new_img = self._crop_img(img, patch_start_loc)
|
|
613
585
|
if h_flip:
|
|
614
586
|
new_img = new_img[..., ::-1, :]
|
|
615
587
|
if w_flip:
|
|
@@ -617,9 +589,6 @@ class MultiChDloader:
|
|
|
617
589
|
|
|
618
590
|
return new_img.astype(np.float32)
|
|
619
591
|
|
|
620
|
-
def __len__(self):
|
|
621
|
-
return self.N * self._repeat_factor
|
|
622
|
-
|
|
623
592
|
def _load_img(
|
|
624
593
|
self, index: Union[int, Tuple[int, int]]
|
|
625
594
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
@@ -631,12 +600,21 @@ class MultiChDloader:
|
|
|
631
600
|
else:
|
|
632
601
|
idx = index[0]
|
|
633
602
|
|
|
634
|
-
|
|
603
|
+
patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
|
|
604
|
+
imgs = self._data[patch_loc_list[0]]
|
|
605
|
+
# if self._5Ddata:
|
|
606
|
+
# assert self._noise_data is None, 'Noise is not supported for 5D data'
|
|
607
|
+
# n_loc, z_loc = patch_loc_list[:2]
|
|
608
|
+
# z_loc_interval = range(z_loc, z_loc + self._depth3D)
|
|
609
|
+
# imgs = self._data[n_loc, z_loc_interval]
|
|
610
|
+
# else:
|
|
611
|
+
# imgs = self._data[patch_loc_list[0]]
|
|
612
|
+
|
|
635
613
|
loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
|
|
636
614
|
noise = []
|
|
637
615
|
if self._noise_data is not None and not self._disable_noise:
|
|
638
616
|
noise = [
|
|
639
|
-
self._noise_data[
|
|
617
|
+
self._noise_data[patch_loc_list[0]][None, ..., i]
|
|
640
618
|
for i in range(self._noise_data.shape[-1])
|
|
641
619
|
]
|
|
642
620
|
return tuple(loaded_imgs), tuple(noise)
|
|
@@ -660,6 +638,18 @@ class MultiChDloader:
|
|
|
660
638
|
normalized_imgs.append(img)
|
|
661
639
|
return tuple(normalized_imgs)
|
|
662
640
|
|
|
641
|
+
def normalize_input(self, x):
|
|
642
|
+
mean_dict, std_dict = self.get_mean_std()
|
|
643
|
+
mean_ = mean_dict["input"].mean()
|
|
644
|
+
std_ = std_dict["input"].mean()
|
|
645
|
+
return (x - mean_) / std_
|
|
646
|
+
|
|
647
|
+
def normalize_target(self, target):
|
|
648
|
+
mean_dict, std_dict = self.get_mean_std()
|
|
649
|
+
mean_ = mean_dict["target"].squeeze(0)
|
|
650
|
+
std_ = std_dict["target"].squeeze(0)
|
|
651
|
+
return (target - mean_) / std_
|
|
652
|
+
|
|
663
653
|
def get_grid_size(self):
|
|
664
654
|
return self._grid_sz
|
|
665
655
|
|
|
@@ -669,27 +659,16 @@ class MultiChDloader:
|
|
|
669
659
|
def per_side_overlap_pixelcount(self):
|
|
670
660
|
return (self._img_sz - self._grid_sz) // 2
|
|
671
661
|
|
|
672
|
-
def on_boundary(self, cur_loc, frame_size):
|
|
673
|
-
|
|
662
|
+
# def on_boundary(self, cur_loc, frame_size):
|
|
663
|
+
# return cur_loc + self._img_sz > frame_size or cur_loc < 0
|
|
674
664
|
|
|
675
|
-
def
|
|
665
|
+
def _get_deterministic_loc(self, index: int):
|
|
676
666
|
"""
|
|
677
667
|
It returns the top-left corner of the patch corresponding to index.
|
|
678
668
|
"""
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
else:
|
|
683
|
-
idx, grid_size = index
|
|
684
|
-
|
|
685
|
-
h_start, w_start = self.idx_manager.get_deterministic_hw(
|
|
686
|
-
idx, grid_size=grid_size
|
|
687
|
-
)
|
|
688
|
-
if self._grid_alignment == GridAlignement.LeftTop:
|
|
689
|
-
return h_start, w_start
|
|
690
|
-
elif self._grid_alignment == GridAlignement.Center:
|
|
691
|
-
pad = self.per_side_overlap_pixelcount()
|
|
692
|
-
return h_start - pad, w_start - pad
|
|
669
|
+
loc_list = self.idx_manager.get_patch_location_from_dataset_idx(index)
|
|
670
|
+
# last dim is channel. we need to take the third and the second last element.
|
|
671
|
+
return loc_list[1:-1]
|
|
693
672
|
|
|
694
673
|
def compute_individual_mean_std(self):
|
|
695
674
|
# numpy 1.19.2 has issues in computing for large arrays. https://github.com/numpy/numpy/issues/8869
|
|
@@ -715,6 +694,10 @@ class MultiChDloader:
|
|
|
715
694
|
|
|
716
695
|
mean = np.array(mean_arr)
|
|
717
696
|
std = np.array(std_arr)
|
|
697
|
+
if (
|
|
698
|
+
self._5Ddata
|
|
699
|
+
): # NOTE: IDEALLY this should be only when the model expects 3D data.
|
|
700
|
+
return mean[None, :, None, None, None], std[None, :, None, None, None]
|
|
718
701
|
|
|
719
702
|
return mean[None, :, None, None], std[None, :, None, None]
|
|
720
703
|
|
|
@@ -776,6 +759,10 @@ class MultiChDloader:
|
|
|
776
759
|
if self._skip_normalization_using_mean:
|
|
777
760
|
mean = np.zeros_like(mean)
|
|
778
761
|
|
|
762
|
+
if self._5Ddata:
|
|
763
|
+
mean = mean[:, :, None]
|
|
764
|
+
std = std[:, :, None]
|
|
765
|
+
|
|
779
766
|
mean_dict = {"input": mean} # , 'target':mean}
|
|
780
767
|
std_dict = {"input": std} # , 'target':std}
|
|
781
768
|
|
|
@@ -810,8 +797,14 @@ class MultiChDloader:
|
|
|
810
797
|
return cropped_img_tuples, cropped_noise_tuples
|
|
811
798
|
|
|
812
799
|
def replace_with_empty_patch(self, img_tuples):
|
|
800
|
+
"""
|
|
801
|
+
Replaces the content of one of the channels with background
|
|
802
|
+
"""
|
|
813
803
|
empty_index = self._empty_patch_fetcher.sample()
|
|
814
|
-
empty_img_tuples = self._get_img(empty_index)
|
|
804
|
+
empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
|
|
805
|
+
assert (
|
|
806
|
+
len(empty_img_noise_tuples) == 0
|
|
807
|
+
), "Noise is not supported with empty patch replacement"
|
|
815
808
|
final_img_tuples = []
|
|
816
809
|
for tuple_idx in range(len(img_tuples)):
|
|
817
810
|
if tuple_idx == self._empty_patch_replacement_channel_idx:
|
|
@@ -834,14 +827,7 @@ class MultiChDloader:
|
|
|
834
827
|
)
|
|
835
828
|
img_tuples = [img_tuples[i] for i in self._tar_idx_list]
|
|
836
829
|
|
|
837
|
-
|
|
838
|
-
assert self._input_is_sum is False
|
|
839
|
-
target = []
|
|
840
|
-
for i in range(len(img_tuples)):
|
|
841
|
-
target.append(img_tuples[i] * alpha[i])
|
|
842
|
-
target = np.concatenate(target, axis=0)
|
|
843
|
-
else:
|
|
844
|
-
target = np.concatenate(img_tuples, axis=0)
|
|
830
|
+
target = np.concatenate(img_tuples, axis=0)
|
|
845
831
|
return target
|
|
846
832
|
|
|
847
833
|
def _compute_input_with_alpha(self, img_tuples, alpha_list):
|
|
@@ -902,9 +888,6 @@ class MultiChDloader:
|
|
|
902
888
|
index = self._train_index_switcher.get_invalid_target_index()
|
|
903
889
|
return index
|
|
904
890
|
|
|
905
|
-
def _rotate(self, img_tuples, noise_tuples):
|
|
906
|
-
return self._rotate2D(img_tuples, noise_tuples)
|
|
907
|
-
|
|
908
891
|
def _rotate2D(self, img_tuples, noise_tuples):
|
|
909
892
|
img_kwargs = {}
|
|
910
893
|
for i, img in enumerate(img_tuples):
|
|
@@ -921,6 +904,7 @@ class MultiChDloader:
|
|
|
921
904
|
rot_dic = self._rotation_transform(
|
|
922
905
|
image=img_tuples[0][0], **img_kwargs, **noise_kwargs
|
|
923
906
|
)
|
|
907
|
+
|
|
924
908
|
rotated_img_tuples = []
|
|
925
909
|
for i, img in enumerate(img_tuples):
|
|
926
910
|
if len(img) == 1:
|
|
@@ -946,7 +930,101 @@ class MultiChDloader:
|
|
|
946
930
|
|
|
947
931
|
return rotated_img_tuples, rotated_noise_tuples
|
|
948
932
|
|
|
933
|
+
def _rotate(self, img_tuples, noise_tuples):
|
|
934
|
+
|
|
935
|
+
if self._5Ddata:
|
|
936
|
+
return self._rotate3D(img_tuples, noise_tuples)
|
|
937
|
+
else:
|
|
938
|
+
return self._rotate2D(img_tuples, noise_tuples)
|
|
939
|
+
|
|
940
|
+
def _rotate3D(self, img_tuples, noise_tuples):
|
|
941
|
+
img_kwargs = {}
|
|
942
|
+
# random flip in z direction
|
|
943
|
+
flip_z = self._flipz_3D and np.random.rand() < 0.5
|
|
944
|
+
for i, img in enumerate(img_tuples):
|
|
945
|
+
for j in range(self._depth3D):
|
|
946
|
+
for k in range(len(img)):
|
|
947
|
+
if flip_z:
|
|
948
|
+
z_idx = self._depth3D - 1 - j
|
|
949
|
+
else:
|
|
950
|
+
z_idx = j
|
|
951
|
+
img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]
|
|
952
|
+
|
|
953
|
+
noise_kwargs = {}
|
|
954
|
+
for i, nimg in enumerate(noise_tuples):
|
|
955
|
+
for j in range(self._depth3D):
|
|
956
|
+
for k in range(len(nimg)):
|
|
957
|
+
if flip_z:
|
|
958
|
+
z_idx = self._depth3D - 1 - j
|
|
959
|
+
else:
|
|
960
|
+
z_idx = j
|
|
961
|
+
noise_kwargs[f"noise{i}_{z_idx}_{k}"] = nimg[k, j]
|
|
962
|
+
|
|
963
|
+
keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
|
|
964
|
+
self._rotation_transform.add_targets({k: "image" for k in keys})
|
|
965
|
+
rot_dic = self._rotation_transform(
|
|
966
|
+
image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
|
|
967
|
+
)
|
|
968
|
+
rotated_img_tuples = []
|
|
969
|
+
for i, img in enumerate(img_tuples):
|
|
970
|
+
if len(img) == 1:
|
|
971
|
+
rotated_img_tuples.append(
|
|
972
|
+
np.concatenate(
|
|
973
|
+
[
|
|
974
|
+
rot_dic[f"img{i}_{j}_0"][None, None]
|
|
975
|
+
for j in range(self._depth3D)
|
|
976
|
+
],
|
|
977
|
+
axis=1,
|
|
978
|
+
)
|
|
979
|
+
)
|
|
980
|
+
else:
|
|
981
|
+
temp_arr = []
|
|
982
|
+
for k in range(len(img)):
|
|
983
|
+
temp_arr.append(
|
|
984
|
+
np.concatenate(
|
|
985
|
+
[
|
|
986
|
+
rot_dic[f"img{i}_{j}_{k}"][None, None]
|
|
987
|
+
for j in range(self._depth3D)
|
|
988
|
+
],
|
|
989
|
+
axis=1,
|
|
990
|
+
)
|
|
991
|
+
)
|
|
992
|
+
rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
|
|
993
|
+
|
|
994
|
+
rotated_noise_tuples = []
|
|
995
|
+
for i, nimg in enumerate(noise_tuples):
|
|
996
|
+
if len(nimg) == 1:
|
|
997
|
+
rotated_noise_tuples.append(
|
|
998
|
+
np.concatenate(
|
|
999
|
+
[
|
|
1000
|
+
rot_dic[f"noise{i}_{j}_0"][None, None]
|
|
1001
|
+
for j in range(self._depth3D)
|
|
1002
|
+
],
|
|
1003
|
+
axis=1,
|
|
1004
|
+
)
|
|
1005
|
+
)
|
|
1006
|
+
else:
|
|
1007
|
+
temp_arr = []
|
|
1008
|
+
for k in range(len(nimg)):
|
|
1009
|
+
temp_arr.append(
|
|
1010
|
+
np.concatenate(
|
|
1011
|
+
[
|
|
1012
|
+
rot_dic[f"noise{i}_{j}_{k}"][None, None]
|
|
1013
|
+
for j in range(self._depth3D)
|
|
1014
|
+
],
|
|
1015
|
+
axis=1,
|
|
1016
|
+
)
|
|
1017
|
+
)
|
|
1018
|
+
rotated_noise_tuples.append(np.concatenate(temp_arr, axis=0))
|
|
1019
|
+
|
|
1020
|
+
return rotated_img_tuples, rotated_noise_tuples
|
|
1021
|
+
|
|
949
1022
|
def get_uncorrelated_img_tuples(self, index):
|
|
1023
|
+
"""
|
|
1024
|
+
Content of channels like actin and nuclei is "correlated" in its
|
|
1025
|
+
respective location, this function allows to pick channels' content
|
|
1026
|
+
from different patches of the image to make it "uncorrelated".
|
|
1027
|
+
"""
|
|
950
1028
|
img_tuples, noise_tuples = self._get_img(index)
|
|
951
1029
|
assert len(noise_tuples) == 0
|
|
952
1030
|
img_tuples = [img_tuples[0]]
|
|
@@ -959,10 +1037,15 @@ class MultiChDloader:
|
|
|
959
1037
|
def __getitem__(
|
|
960
1038
|
self, index: Union[int, Tuple[int, int]]
|
|
961
1039
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1040
|
+
# Vera: input can be both real microscopic image and two separate channels that are summed in the code
|
|
1041
|
+
|
|
962
1042
|
if self._train_index_switcher is not None:
|
|
963
1043
|
index = self._get_index_from_valid_target_logic(index)
|
|
964
1044
|
|
|
965
|
-
if
|
|
1045
|
+
if (
|
|
1046
|
+
self._uncorrelated_channels
|
|
1047
|
+
and np.random.rand() < self._uncorrelated_channel_probab
|
|
1048
|
+
):
|
|
966
1049
|
img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
|
|
967
1050
|
else:
|
|
968
1051
|
img_tuples, noise_tuples = self._get_img(index)
|
|
@@ -971,28 +1054,36 @@ class MultiChDloader:
|
|
|
971
1054
|
self._empty_patch_replacement_enabled != True
|
|
972
1055
|
), "This is not supported with noise"
|
|
973
1056
|
|
|
1057
|
+
# Replace the content of one of the channels
|
|
1058
|
+
# with background with given probability
|
|
974
1059
|
if self._empty_patch_replacement_enabled:
|
|
975
1060
|
if np.random.rand() < self._empty_patch_replacement_probab:
|
|
976
1061
|
img_tuples = self.replace_with_empty_patch(img_tuples)
|
|
977
1062
|
|
|
1063
|
+
# Noise tuples are not needed for the paper
|
|
1064
|
+
# the image tuples are noisy by default
|
|
1065
|
+
# TODO: remove noise tuples completely?
|
|
978
1066
|
if self._enable_rotation:
|
|
979
1067
|
img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
|
|
980
1068
|
|
|
981
|
-
#
|
|
1069
|
+
# Add noise tuples with image tuples to create the input
|
|
982
1070
|
if len(noise_tuples) > 0:
|
|
983
1071
|
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
984
1072
|
input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
|
|
985
1073
|
else:
|
|
986
1074
|
input_tuples = img_tuples
|
|
1075
|
+
|
|
1076
|
+
# Weight the individual channels, typically alpha is fixed
|
|
987
1077
|
inp, alpha = self._compute_input(input_tuples)
|
|
988
1078
|
|
|
989
|
-
#
|
|
1079
|
+
# Add noise tuples to the image tuples to create the target
|
|
990
1080
|
if len(noise_tuples) >= 1:
|
|
991
1081
|
img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]
|
|
992
1082
|
|
|
993
1083
|
target = self._compute_target(img_tuples, alpha)
|
|
1084
|
+
norm_target = self.normalize_target(target)
|
|
994
1085
|
|
|
995
|
-
output = [inp,
|
|
1086
|
+
output = [inp, norm_target]
|
|
996
1087
|
|
|
997
1088
|
if self._return_alpha:
|
|
998
1089
|
output.append(alpha)
|
|
@@ -1000,221 +1091,4 @@ class MultiChDloader:
|
|
|
1000
1091
|
if self._return_index:
|
|
1001
1092
|
output.append(index)
|
|
1002
1093
|
|
|
1003
|
-
if isinstance(index, int) or isinstance(index, np.int64):
|
|
1004
|
-
return tuple(output)
|
|
1005
|
-
|
|
1006
|
-
_, grid_size = index
|
|
1007
|
-
output.append(grid_size)
|
|
1008
|
-
return tuple(output)
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
class LCMultiChDloader(MultiChDloader):
|
|
1012
|
-
|
|
1013
|
-
def __init__(
|
|
1014
|
-
self,
|
|
1015
|
-
data_config,
|
|
1016
|
-
fpath: str,
|
|
1017
|
-
datasplit_type: DataSplitType = None,
|
|
1018
|
-
val_fraction=None,
|
|
1019
|
-
test_fraction=None,
|
|
1020
|
-
normalized_input=None,
|
|
1021
|
-
enable_rotation_aug: bool = False,
|
|
1022
|
-
use_one_mu_std=None,
|
|
1023
|
-
num_scales: int = None,
|
|
1024
|
-
enable_random_cropping=False,
|
|
1025
|
-
padding_kwargs: dict = None,
|
|
1026
|
-
allow_generation: bool = False,
|
|
1027
|
-
lowres_supervision=None,
|
|
1028
|
-
max_val=None,
|
|
1029
|
-
grid_alignment=GridAlignement.LeftTop,
|
|
1030
|
-
overlapping_padding_kwargs=None,
|
|
1031
|
-
print_vars=True,
|
|
1032
|
-
):
|
|
1033
|
-
"""
|
|
1034
|
-
Args:
|
|
1035
|
-
num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
|
|
1036
|
-
highest resolution.
|
|
1037
|
-
"""
|
|
1038
|
-
self._padding_kwargs = (
|
|
1039
|
-
padding_kwargs # mode=padding_mode, constant_values=constant_value
|
|
1040
|
-
)
|
|
1041
|
-
if overlapping_padding_kwargs is not None:
|
|
1042
|
-
assert (
|
|
1043
|
-
self._padding_kwargs == overlapping_padding_kwargs
|
|
1044
|
-
), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
|
|
1045
|
-
It should be so since we just use overlapping_padding_kwargs when it is not None"
|
|
1046
|
-
|
|
1047
|
-
else:
|
|
1048
|
-
overlapping_padding_kwargs = padding_kwargs
|
|
1049
|
-
|
|
1050
|
-
super().__init__(
|
|
1051
|
-
data_config,
|
|
1052
|
-
fpath,
|
|
1053
|
-
datasplit_type=datasplit_type,
|
|
1054
|
-
val_fraction=val_fraction,
|
|
1055
|
-
test_fraction=test_fraction,
|
|
1056
|
-
normalized_input=normalized_input,
|
|
1057
|
-
enable_rotation_aug=enable_rotation_aug,
|
|
1058
|
-
enable_random_cropping=enable_random_cropping,
|
|
1059
|
-
use_one_mu_std=use_one_mu_std,
|
|
1060
|
-
allow_generation=allow_generation,
|
|
1061
|
-
max_val=max_val,
|
|
1062
|
-
grid_alignment=grid_alignment,
|
|
1063
|
-
overlapping_padding_kwargs=overlapping_padding_kwargs,
|
|
1064
|
-
print_vars=print_vars,
|
|
1065
|
-
)
|
|
1066
|
-
self.num_scales = num_scales
|
|
1067
|
-
assert self.num_scales is not None
|
|
1068
|
-
self._scaled_data = [self._data]
|
|
1069
|
-
self._scaled_noise_data = [self._noise_data]
|
|
1070
|
-
|
|
1071
|
-
assert isinstance(self.num_scales, int) and self.num_scales >= 1
|
|
1072
|
-
self._lowres_supervision = lowres_supervision
|
|
1073
|
-
assert isinstance(self._padding_kwargs, dict)
|
|
1074
|
-
assert "mode" in self._padding_kwargs
|
|
1075
|
-
|
|
1076
|
-
for _ in range(1, self.num_scales):
|
|
1077
|
-
shape = self._scaled_data[-1].shape
|
|
1078
|
-
assert len(shape) == 4
|
|
1079
|
-
new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
|
|
1080
|
-
ds_data = resize(
|
|
1081
|
-
self._scaled_data[-1].astype(np.float32), new_shape
|
|
1082
|
-
).astype(self._scaled_data[-1].dtype)
|
|
1083
|
-
# NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
|
|
1084
|
-
assert (
|
|
1085
|
-
ds_data.max() / self._scaled_data[-1].max() < 5
|
|
1086
|
-
), "Downsampled image should not have very different values"
|
|
1087
|
-
assert (
|
|
1088
|
-
ds_data.max() / self._scaled_data[-1].max() > 0.2
|
|
1089
|
-
), "Downsampled image should not have very different values"
|
|
1090
|
-
|
|
1091
|
-
self._scaled_data.append(ds_data)
|
|
1092
|
-
# do the same for noise
|
|
1093
|
-
if self._noise_data is not None:
|
|
1094
|
-
noise_data = resize(self._scaled_noise_data[-1], new_shape)
|
|
1095
|
-
self._scaled_noise_data.append(noise_data)
|
|
1096
|
-
|
|
1097
|
-
def _init_msg(self):
|
|
1098
|
-
msg = super()._init_msg()
|
|
1099
|
-
msg += f" Pad:{self._padding_kwargs}"
|
|
1100
|
-
return msg
|
|
1101
|
-
|
|
1102
|
-
def _load_scaled_img(
|
|
1103
|
-
self, scaled_index, index: Union[int, Tuple[int, int]]
|
|
1104
|
-
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1105
|
-
if isinstance(index, int):
|
|
1106
|
-
idx = index
|
|
1107
|
-
else:
|
|
1108
|
-
idx, _ = index
|
|
1109
|
-
imgs = self._scaled_data[scaled_index][idx % self.N]
|
|
1110
|
-
imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])])
|
|
1111
|
-
if self._noise_data is not None:
|
|
1112
|
-
noisedata = self._scaled_noise_data[scaled_index][idx % self.N]
|
|
1113
|
-
noise = tuple(
|
|
1114
|
-
[noisedata[None, :, :, i] for i in range(noisedata.shape[-1])]
|
|
1115
|
-
)
|
|
1116
|
-
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
1117
|
-
# since we are using this lowres images for just the input, we need to add the noise of the input.
|
|
1118
|
-
assert self._lowres_supervision is None or self._lowres_supervision is False
|
|
1119
|
-
imgs = tuple([img + noise[0] * factor for img in imgs])
|
|
1120
|
-
return imgs
|
|
1121
|
-
|
|
1122
|
-
def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
|
|
1123
|
-
"""
|
|
1124
|
-
Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
|
|
1125
|
-
the cropped image will be smaller than self._img_sz * self._img_sz
|
|
1126
|
-
"""
|
|
1127
|
-
return self._crop_img_with_padding(img, h_start, w_start)
|
|
1128
|
-
|
|
1129
|
-
def _get_img(self, index: int):
|
|
1130
|
-
"""
|
|
1131
|
-
Returns the primary patch along with low resolution patches centered on the primary patch.
|
|
1132
|
-
"""
|
|
1133
|
-
img_tuples, noise_tuples = self._load_img(index)
|
|
1134
|
-
assert self._img_sz is not None
|
|
1135
|
-
h, w = img_tuples[0].shape[-2:]
|
|
1136
|
-
if self._enable_random_cropping:
|
|
1137
|
-
h_start, w_start = self._get_random_hw(h, w)
|
|
1138
|
-
else:
|
|
1139
|
-
h_start, w_start = self._get_deterministic_hw(index)
|
|
1140
|
-
|
|
1141
|
-
cropped_img_tuples = [
|
|
1142
|
-
self._crop_flip_img(img, h_start, w_start, False, False)
|
|
1143
|
-
for img in img_tuples
|
|
1144
|
-
]
|
|
1145
|
-
cropped_noise_tuples = [
|
|
1146
|
-
self._crop_flip_img(noise, h_start, w_start, False, False)
|
|
1147
|
-
for noise in noise_tuples
|
|
1148
|
-
]
|
|
1149
|
-
h_center = h_start + self._img_sz // 2
|
|
1150
|
-
w_center = w_start + self._img_sz // 2
|
|
1151
|
-
allres_versions = {
|
|
1152
|
-
i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
|
|
1153
|
-
}
|
|
1154
|
-
for scale_idx in range(1, self.num_scales):
|
|
1155
|
-
scaled_img_tuples = self._load_scaled_img(scale_idx, index)
|
|
1156
|
-
|
|
1157
|
-
h_center = h_center // 2
|
|
1158
|
-
w_center = w_center // 2
|
|
1159
|
-
|
|
1160
|
-
h_start = h_center - self._img_sz // 2
|
|
1161
|
-
w_start = w_center - self._img_sz // 2
|
|
1162
|
-
|
|
1163
|
-
scaled_cropped_img_tuples = [
|
|
1164
|
-
self._crop_flip_img(img, h_start, w_start, False, False)
|
|
1165
|
-
for img in scaled_img_tuples
|
|
1166
|
-
]
|
|
1167
|
-
for ch_idx in range(len(img_tuples)):
|
|
1168
|
-
allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
|
|
1169
|
-
|
|
1170
|
-
output_img_tuples = tuple(
|
|
1171
|
-
[
|
|
1172
|
-
np.concatenate(allres_versions[ch_idx])
|
|
1173
|
-
for ch_idx in range(len(img_tuples))
|
|
1174
|
-
]
|
|
1175
|
-
)
|
|
1176
|
-
return output_img_tuples, cropped_noise_tuples
|
|
1177
|
-
|
|
1178
|
-
def __getitem__(self, index: Union[int, Tuple[int, int]]):
|
|
1179
|
-
if self._uncorrelated_channels:
|
|
1180
|
-
img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
|
|
1181
|
-
else:
|
|
1182
|
-
img_tuples, noise_tuples = self._get_img(index)
|
|
1183
|
-
|
|
1184
|
-
if self._enable_rotation:
|
|
1185
|
-
img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
|
|
1186
|
-
|
|
1187
|
-
assert self._lowres_supervision != True
|
|
1188
|
-
# add noise to input
|
|
1189
|
-
if len(noise_tuples) > 0:
|
|
1190
|
-
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
1191
|
-
input_tuples = []
|
|
1192
|
-
for x in img_tuples:
|
|
1193
|
-
# NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
|
|
1194
|
-
x[0] = x[0] + noise_tuples[0] * factor
|
|
1195
|
-
input_tuples.append(x)
|
|
1196
|
-
else:
|
|
1197
|
-
input_tuples = img_tuples
|
|
1198
|
-
|
|
1199
|
-
inp, alpha = self._compute_input(input_tuples)
|
|
1200
|
-
# assert self._alpha_weighted_target in [False, None]
|
|
1201
|
-
target_tuples = [img[:1] for img in img_tuples]
|
|
1202
|
-
# add noise to target.
|
|
1203
|
-
if len(noise_tuples) >= 1:
|
|
1204
|
-
target_tuples = [
|
|
1205
|
-
x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
|
|
1206
|
-
]
|
|
1207
|
-
|
|
1208
|
-
target = self._compute_target(target_tuples, alpha)
|
|
1209
|
-
|
|
1210
|
-
output = [inp, target]
|
|
1211
|
-
|
|
1212
|
-
if self._return_alpha:
|
|
1213
|
-
output.append(alpha)
|
|
1214
|
-
|
|
1215
|
-
if isinstance(index, int):
|
|
1216
|
-
return tuple(output)
|
|
1217
|
-
|
|
1218
|
-
_, grid_size = index
|
|
1219
|
-
output.append(grid_size)
|
|
1220
1094
|
return tuple(output)
|