careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,189 +2,65 @@
|
|
|
2
2
|
A place for Datasets and Dataloaders.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Tuple, Union
|
|
7
6
|
|
|
8
|
-
# import albumentations as A
|
|
9
|
-
import ml_collections
|
|
10
7
|
import numpy as np
|
|
11
|
-
from skimage.transform import resize
|
|
12
8
|
|
|
13
9
|
from .data_utils import (
|
|
14
|
-
DataSplitType,
|
|
15
|
-
DataType,
|
|
16
|
-
GridAlignement,
|
|
17
10
|
GridIndexManager,
|
|
18
11
|
IndexSwitcher,
|
|
19
|
-
|
|
20
|
-
get_mrc_data,
|
|
21
|
-
load_tiff,
|
|
12
|
+
get_train_val_data,
|
|
22
13
|
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def get_train_val_data(
|
|
26
|
-
data_config,
|
|
27
|
-
fpath,
|
|
28
|
-
datasplit_type: DataSplitType,
|
|
29
|
-
val_fraction=None,
|
|
30
|
-
test_fraction=None,
|
|
31
|
-
allow_generation=None,
|
|
32
|
-
ignore_specific_datapoints=None,
|
|
33
|
-
):
|
|
34
|
-
"""
|
|
35
|
-
Load the data from the given path and split them in training, validation and test sets.
|
|
36
|
-
|
|
37
|
-
Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
|
|
38
|
-
C is the number of channels.
|
|
39
|
-
"""
|
|
40
|
-
if data_config.data_type == DataType.SeparateTiffData:
|
|
41
|
-
fpath1 = os.path.join(fpath, data_config.ch1_fname)
|
|
42
|
-
fpath2 = os.path.join(fpath, data_config.ch2_fname)
|
|
43
|
-
fpaths = [fpath1, fpath2]
|
|
44
|
-
fpath0 = ""
|
|
45
|
-
if "ch_input_fname" in data_config:
|
|
46
|
-
fpath0 = os.path.join(fpath, data_config.ch_input_fname)
|
|
47
|
-
fpaths = [fpath0] + fpaths
|
|
48
|
-
|
|
49
|
-
print(
|
|
50
|
-
f"Loading from {fpath} Channels: "
|
|
51
|
-
f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
|
|
55
|
-
if data_config.data_type == DataType.PredictedTiffData:
|
|
56
|
-
assert len(data.shape) == 5 and data.shape[-1] == 1
|
|
57
|
-
data = data[..., 0].copy()
|
|
58
|
-
# data = data[::3].copy()
|
|
59
|
-
# NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
|
|
60
|
-
# to the noise present in the channels and so this is not the way we would get the data.
|
|
61
|
-
# We need to add the noise independently to the input and the target.
|
|
62
|
-
|
|
63
|
-
# if data_config.get('poisson_noise_factor', False):
|
|
64
|
-
# data = np.random.poisson(data)
|
|
65
|
-
# if data_config.get('enable_gaussian_noise', False):
|
|
66
|
-
# synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
|
|
67
|
-
# print('Adding Gaussian noise with scale', synthetic_scale)
|
|
68
|
-
# noise = np.random.normal(0, synthetic_scale, data.shape)
|
|
69
|
-
# data = data + noise
|
|
70
|
-
|
|
71
|
-
if datasplit_type == DataSplitType.All:
|
|
72
|
-
return data.astype(np.float32)
|
|
73
|
-
|
|
74
|
-
train_idx, val_idx, test_idx = get_datasplit_tuples(
|
|
75
|
-
val_fraction, test_fraction, len(data), starting_test=True
|
|
76
|
-
)
|
|
77
|
-
if datasplit_type == DataSplitType.Train:
|
|
78
|
-
return data[train_idx].astype(np.float32)
|
|
79
|
-
elif datasplit_type == DataSplitType.Val:
|
|
80
|
-
return data[val_idx].astype(np.float32)
|
|
81
|
-
elif datasplit_type == DataSplitType.Test:
|
|
82
|
-
return data[test_idx].astype(np.float32)
|
|
83
|
-
|
|
84
|
-
elif data_config.data_type == DataType.BioSR_MRC:
|
|
85
|
-
num_channels = data_config.get("num_channels", 2)
|
|
86
|
-
fpaths = []
|
|
87
|
-
data_list = []
|
|
88
|
-
for i in range(num_channels):
|
|
89
|
-
fpath1 = os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
|
|
90
|
-
fpaths.append(fpath1)
|
|
91
|
-
data = get_mrc_data(fpath1)[..., None]
|
|
92
|
-
data_list.append(data)
|
|
93
|
-
|
|
94
|
-
dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
|
|
95
|
-
|
|
96
|
-
msg = ",".join([x[len(dirname) :] for x in fpaths])
|
|
97
|
-
print(
|
|
98
|
-
f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
|
|
99
|
-
)
|
|
100
|
-
N = data_list[0].shape[0]
|
|
101
|
-
for data in data_list:
|
|
102
|
-
N = min(N, data.shape[0])
|
|
103
|
-
|
|
104
|
-
cropped_data = []
|
|
105
|
-
for data in data_list:
|
|
106
|
-
cropped_data.append(data[:N])
|
|
107
|
-
|
|
108
|
-
data = np.concatenate(cropped_data, axis=3)
|
|
109
|
-
|
|
110
|
-
if datasplit_type == DataSplitType.All:
|
|
111
|
-
return data.astype(np.float32)
|
|
112
|
-
|
|
113
|
-
train_idx, val_idx, test_idx = get_datasplit_tuples(
|
|
114
|
-
val_fraction, test_fraction, len(data), starting_test=True
|
|
115
|
-
)
|
|
116
|
-
if datasplit_type == DataSplitType.Train:
|
|
117
|
-
return data[train_idx].astype(np.float32)
|
|
118
|
-
elif datasplit_type == DataSplitType.Val:
|
|
119
|
-
return data[val_idx].astype(np.float32)
|
|
120
|
-
elif datasplit_type == DataSplitType.Test:
|
|
121
|
-
return data[test_idx].astype(np.float32)
|
|
14
|
+
from .vae_data_config import VaeDatasetConfig, DataSplitType, GridAlignement
|
|
122
15
|
|
|
123
16
|
|
|
124
17
|
class MultiChDloader:
|
|
125
|
-
|
|
126
18
|
def __init__(
|
|
127
19
|
self,
|
|
128
|
-
data_config:
|
|
20
|
+
data_config: VaeDatasetConfig,
|
|
129
21
|
fpath: str,
|
|
130
|
-
datasplit_type: DataSplitType = None,
|
|
131
22
|
val_fraction: float = None,
|
|
132
23
|
test_fraction: float = None,
|
|
133
|
-
normalized_input=None,
|
|
134
|
-
enable_rotation_aug: bool = False,
|
|
135
|
-
enable_random_cropping: bool = False,
|
|
136
|
-
use_one_mu_std=None,
|
|
137
|
-
allow_generation: bool = False,
|
|
138
|
-
max_val: float = None,
|
|
139
|
-
grid_alignment=GridAlignement.LeftTop,
|
|
140
|
-
overlapping_padding_kwargs=None,
|
|
141
|
-
print_vars: bool = True,
|
|
142
24
|
):
|
|
143
|
-
"""
|
|
144
|
-
Here, an image is split into grids of size img_sz.
|
|
145
|
-
Args:
|
|
146
|
-
repeat_factor: Since we are doing a random crop, repeat_factor is
|
|
147
|
-
given which can repeatedly sample from the same image. If self.N=12
|
|
148
|
-
and repeat_factor is 5, then index upto 12*5 = 60 is allowed.
|
|
149
|
-
use_one_mu_std: If this is set to true, then one mean and stdev is used
|
|
150
|
-
for both channels. Otherwise, two different meean and stdev are used.
|
|
151
|
-
|
|
152
|
-
"""
|
|
25
|
+
""" """
|
|
153
26
|
self._data_type = data_config.data_type
|
|
154
27
|
self._fpath = fpath
|
|
155
28
|
self._data = self.N = self._noise_data = None
|
|
156
|
-
|
|
29
|
+
self.Z = 1
|
|
30
|
+
self._trim_boundary = data_config.trim_boundary
|
|
157
31
|
# Hardcoded params, not included in the config file.
|
|
158
32
|
|
|
159
33
|
# by default, if the noise is present, add it to the input and target.
|
|
160
34
|
self._disable_noise = False # to add synthetic noise
|
|
35
|
+
self._poisson_noise_factor = None
|
|
161
36
|
self._train_index_switcher = None
|
|
37
|
+
self._depth3D = data_config.depth3D
|
|
162
38
|
# NOTE: Input is the sum of the different channels. It is not the average of the different channels.
|
|
163
|
-
self._input_is_sum = data_config.
|
|
164
|
-
self._num_channels = data_config.
|
|
165
|
-
self._input_idx = data_config.
|
|
166
|
-
self._tar_idx_list = data_config.
|
|
39
|
+
self._input_is_sum = data_config.input_is_sum
|
|
40
|
+
self._num_channels = data_config.num_channels
|
|
41
|
+
self._input_idx = data_config.input_idx
|
|
42
|
+
self._tar_idx_list = data_config.target_idx_list
|
|
167
43
|
|
|
168
|
-
if datasplit_type == DataSplitType.Train:
|
|
44
|
+
if data_config.datasplit_type == DataSplitType.Train:
|
|
169
45
|
self._datausage_fraction = 1.0
|
|
170
46
|
# assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
|
|
171
47
|
self._validtarget_rand_fract = None
|
|
172
48
|
# self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
|
|
173
49
|
# self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
|
|
174
50
|
# self._idx_count = 0
|
|
175
|
-
elif datasplit_type == DataSplitType.Val:
|
|
51
|
+
elif data_config.datasplit_type == DataSplitType.Val:
|
|
176
52
|
self._datausage_fraction = 1.0
|
|
177
53
|
else:
|
|
178
54
|
self._datausage_fraction = 1.0
|
|
179
55
|
|
|
180
56
|
self.load_data(
|
|
181
57
|
data_config,
|
|
182
|
-
datasplit_type,
|
|
58
|
+
data_config.datasplit_type,
|
|
183
59
|
val_fraction=val_fraction,
|
|
184
60
|
test_fraction=test_fraction,
|
|
185
|
-
allow_generation=allow_generation,
|
|
61
|
+
allow_generation=data_config.allow_generation,
|
|
186
62
|
)
|
|
187
|
-
self._normalized_input = normalized_input
|
|
63
|
+
self._normalized_input = data_config.normalized_input
|
|
188
64
|
self._quantile = 1.0
|
|
189
65
|
self._channelwise_quantile = False
|
|
190
66
|
self._background_quantile = 0.0
|
|
@@ -194,8 +70,8 @@ class MultiChDloader:
|
|
|
194
70
|
|
|
195
71
|
self._background_values = None
|
|
196
72
|
|
|
197
|
-
self._grid_alignment = grid_alignment
|
|
198
|
-
self._overlapping_padding_kwargs = overlapping_padding_kwargs
|
|
73
|
+
self._grid_alignment = data_config.grid_alignment
|
|
74
|
+
self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
|
|
199
75
|
if self._grid_alignment == GridAlignement.LeftTop:
|
|
200
76
|
assert (
|
|
201
77
|
self._overlapping_padding_kwargs is None
|
|
@@ -205,20 +81,28 @@ class MultiChDloader:
|
|
|
205
81
|
assert (
|
|
206
82
|
self._overlapping_padding_kwargs is not None
|
|
207
83
|
), "With Center grid alignment, padding is needed."
|
|
84
|
+
if self._trim_boundary:
|
|
85
|
+
if (
|
|
86
|
+
self._overlapping_padding_kwargs is None
|
|
87
|
+
or data_config.multiscale_lowres_count is not None
|
|
88
|
+
):
|
|
89
|
+
# raise warning
|
|
90
|
+
print("Padding is not used with this alignement style")
|
|
91
|
+
else:
|
|
92
|
+
assert (
|
|
93
|
+
self._overlapping_padding_kwargs is not None
|
|
94
|
+
), "When not trimming boudnary, padding is needed."
|
|
208
95
|
|
|
209
|
-
self._is_train = datasplit_type == DataSplitType.Train
|
|
96
|
+
self._is_train = data_config.datasplit_type == DataSplitType.Train
|
|
210
97
|
|
|
211
98
|
# input = alpha * ch1 + (1-alpha)*ch2.
|
|
212
99
|
# alpha is sampled randomly between these two extremes
|
|
213
|
-
self._start_alpha_arr = self._end_alpha_arr = self._return_alpha =
|
|
214
|
-
self._alpha_weighted_target
|
|
215
|
-
) = None
|
|
100
|
+
self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None
|
|
216
101
|
|
|
217
102
|
self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
|
|
218
103
|
if self._is_train:
|
|
219
|
-
self._start_alpha_arr =
|
|
220
|
-
self._end_alpha_arr =
|
|
221
|
-
self._alpha_weighted_target = False
|
|
104
|
+
self._start_alpha_arr = data_config.start_alpha
|
|
105
|
+
self._end_alpha_arr = data_config.end_alpha
|
|
222
106
|
|
|
223
107
|
self.set_img_sz(
|
|
224
108
|
data_config.image_size,
|
|
@@ -229,11 +113,13 @@ class MultiChDloader:
|
|
|
229
113
|
),
|
|
230
114
|
)
|
|
231
115
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
116
|
+
if self._validtarget_rand_fract is not None:
|
|
117
|
+
self._train_index_switcher = IndexSwitcher(
|
|
118
|
+
self.idx_manager, data_config, self._img_sz
|
|
119
|
+
)
|
|
235
120
|
|
|
236
121
|
else:
|
|
122
|
+
|
|
237
123
|
self.set_img_sz(
|
|
238
124
|
data_config.image_size,
|
|
239
125
|
(
|
|
@@ -246,32 +132,42 @@ class MultiChDloader:
|
|
|
246
132
|
self._return_alpha = False
|
|
247
133
|
self._return_index = False
|
|
248
134
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
135
|
+
self._empty_patch_replacement_enabled = (
|
|
136
|
+
data_config.empty_patch_replacement_enabled and self._is_train
|
|
137
|
+
)
|
|
138
|
+
if self._empty_patch_replacement_enabled:
|
|
139
|
+
self._empty_patch_replacement_channel_idx = (
|
|
140
|
+
data_config.empty_patch_replacement_channel_idx
|
|
141
|
+
)
|
|
142
|
+
self._empty_patch_replacement_probab = (
|
|
143
|
+
data_config.empty_patch_replacement_probab
|
|
144
|
+
)
|
|
145
|
+
data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
|
|
146
|
+
# NOTE: This is on the raw data. So, it must be called before removing the background.
|
|
147
|
+
# TODO: missing import, needs fixing asap!
|
|
148
|
+
self._empty_patch_fetcher = EmptyPatchFetcher(
|
|
149
|
+
self.idx_manager,
|
|
150
|
+
self._img_sz,
|
|
151
|
+
data_frames,
|
|
152
|
+
max_val_threshold=data_config.empty_patch_max_val_threshold,
|
|
153
|
+
)
|
|
260
154
|
|
|
261
|
-
self.rm_bkground_set_max_val_and_upperclip_data(
|
|
155
|
+
self.rm_bkground_set_max_val_and_upperclip_data(
|
|
156
|
+
data_config.max_val, data_config.datasplit_type
|
|
157
|
+
)
|
|
262
158
|
|
|
263
159
|
# For overlapping dloader, image_size and repeat_factors are not related. hence a different function.
|
|
264
160
|
|
|
265
161
|
self._mean = None
|
|
266
162
|
self._std = None
|
|
267
|
-
self._use_one_mu_std = use_one_mu_std
|
|
163
|
+
self._use_one_mu_std = data_config.use_one_mu_std
|
|
268
164
|
# Hardcoded
|
|
269
165
|
self._target_separate_normalization = True
|
|
270
166
|
|
|
271
|
-
self._enable_rotation = enable_rotation_aug
|
|
272
|
-
self._enable_random_cropping = enable_random_cropping
|
|
167
|
+
self._enable_rotation = data_config.enable_rotation_aug
|
|
168
|
+
self._enable_random_cropping = data_config.enable_random_cropping
|
|
273
169
|
self._uncorrelated_channels = (
|
|
274
|
-
data_config.
|
|
170
|
+
data_config.uncorrelated_channels and self._is_train
|
|
275
171
|
)
|
|
276
172
|
assert self._is_train or self._uncorrelated_channels is False
|
|
277
173
|
assert (
|
|
@@ -286,9 +182,10 @@ class MultiChDloader:
|
|
|
286
182
|
)
|
|
287
183
|
self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
|
|
288
184
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
185
|
+
# TODO: remove print log messages
|
|
186
|
+
# if print_vars:
|
|
187
|
+
# msg = self._init_msg()
|
|
188
|
+
# print(msg)
|
|
292
189
|
|
|
293
190
|
def disable_noise(self):
|
|
294
191
|
assert (
|
|
@@ -339,7 +236,7 @@ class MultiChDloader:
|
|
|
339
236
|
)
|
|
340
237
|
|
|
341
238
|
msg = ""
|
|
342
|
-
if data_config.
|
|
239
|
+
if data_config.poisson_noise_factor > 0:
|
|
343
240
|
self._poisson_noise_factor = data_config.poisson_noise_factor
|
|
344
241
|
msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
|
|
345
242
|
self._data = (
|
|
@@ -347,20 +244,26 @@ class MultiChDloader:
|
|
|
347
244
|
* self._poisson_noise_factor
|
|
348
245
|
)
|
|
349
246
|
|
|
350
|
-
if data_config.
|
|
351
|
-
synthetic_scale = data_config.
|
|
247
|
+
if data_config.enable_gaussian_noise:
|
|
248
|
+
synthetic_scale = data_config.synthetic_gaussian_scale
|
|
352
249
|
msg += f"Adding Gaussian noise with scale {synthetic_scale}"
|
|
353
250
|
# 0 => noise for input. 1: => noise for all targets.
|
|
354
251
|
shape = self._data.shape
|
|
355
252
|
self._noise_data = np.random.normal(
|
|
356
253
|
0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
|
|
357
254
|
)
|
|
358
|
-
if data_config.
|
|
255
|
+
if data_config.input_has_dependant_noise:
|
|
359
256
|
msg += ". Moreover, input has dependent noise"
|
|
360
257
|
self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
|
|
361
258
|
print(msg)
|
|
362
259
|
|
|
363
|
-
self.
|
|
260
|
+
self._5Ddata = len(self._data.shape) == 5
|
|
261
|
+
if self._5Ddata:
|
|
262
|
+
self.Z = self._data.shape[1]
|
|
263
|
+
|
|
264
|
+
if self._depth3D > 1:
|
|
265
|
+
assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"
|
|
266
|
+
|
|
364
267
|
assert (
|
|
365
268
|
self._data.shape[-1] == self._num_channels
|
|
366
269
|
), "Number of channels in data and config do not match."
|
|
@@ -441,9 +344,13 @@ class MultiChDloader:
|
|
|
441
344
|
def get_img_sz(self):
|
|
442
345
|
return self._img_sz
|
|
443
346
|
|
|
347
|
+
def get_num_frames(self):
|
|
348
|
+
return self._data.shape[0]
|
|
349
|
+
|
|
444
350
|
def reduce_data(
|
|
445
351
|
self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
|
|
446
352
|
):
|
|
353
|
+
assert not self._5Ddata, "This function is not supported for 3D data."
|
|
447
354
|
if t_list is None:
|
|
448
355
|
t_list = list(range(self._data.shape[0]))
|
|
449
356
|
if h_start is None:
|
|
@@ -461,12 +368,22 @@ class MultiChDloader:
|
|
|
461
368
|
t_list, h_start:h_end, w_start:w_end, :
|
|
462
369
|
].copy()
|
|
463
370
|
|
|
464
|
-
self.N = len(t_list)
|
|
465
371
|
self.set_img_sz(self._img_sz, self._grid_sz)
|
|
466
372
|
print(
|
|
467
373
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
468
374
|
)
|
|
469
375
|
|
|
376
|
+
def get_idx_manager_shapes(self, patch_size: int, grid_size: int):
|
|
377
|
+
numC = self._data.shape[-1]
|
|
378
|
+
if self._5Ddata:
|
|
379
|
+
grid_shape = (1, 1, grid_size, grid_size, numC)
|
|
380
|
+
patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
|
|
381
|
+
else:
|
|
382
|
+
grid_shape = (1, grid_size, grid_size, numC)
|
|
383
|
+
patch_shape = (1, patch_size, patch_size, numC)
|
|
384
|
+
|
|
385
|
+
return patch_shape, grid_shape
|
|
386
|
+
|
|
470
387
|
def set_img_sz(self, image_size, grid_size):
|
|
471
388
|
"""
|
|
472
389
|
If one wants to change the image size on the go, then this can be used.
|
|
@@ -474,12 +391,23 @@ class MultiChDloader:
|
|
|
474
391
|
image_size: size of one patch
|
|
475
392
|
grid_size: frame is divided into square grids of this size. A patch centered on a grid having size `image_size` is returned.
|
|
476
393
|
"""
|
|
394
|
+
|
|
477
395
|
self._img_sz = image_size
|
|
478
396
|
self._grid_sz = grid_size
|
|
397
|
+
shape = self._data.shape
|
|
398
|
+
|
|
399
|
+
patch_shape, grid_shape = self.get_idx_manager_shapes(
|
|
400
|
+
self._img_sz, self._grid_sz
|
|
401
|
+
)
|
|
479
402
|
self.idx_manager = GridIndexManager(
|
|
480
|
-
|
|
403
|
+
shape, grid_shape, patch_shape, self._trim_boundary
|
|
481
404
|
)
|
|
482
|
-
self.set_repeat_factor()
|
|
405
|
+
# self.set_repeat_factor()
|
|
406
|
+
|
|
407
|
+
def __len__(self):
|
|
408
|
+
# Vera: N is the number of frames in Z stack
|
|
409
|
+
# Repeat factor is n_rows * n_cols
|
|
410
|
+
return self.idx_manager.total_grid_count()
|
|
483
411
|
|
|
484
412
|
def set_repeat_factor(self):
|
|
485
413
|
if self._grid_sz > 1:
|
|
@@ -497,7 +425,14 @@ class MultiChDloader:
|
|
|
497
425
|
msg = (
|
|
498
426
|
f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
|
|
499
427
|
)
|
|
428
|
+
dim_sizes = [
|
|
429
|
+
self.idx_manager.get_individual_dim_grid_count(dim)
|
|
430
|
+
for dim in range(len(self._data.shape))
|
|
431
|
+
]
|
|
432
|
+
dim_sizes = ",".join([str(x) for x in dim_sizes])
|
|
500
433
|
msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
|
|
434
|
+
msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
|
|
435
|
+
msg += f" TrimB:{self._trim_boundary}"
|
|
501
436
|
# msg += f' NormInp:{self._normalized_input}'
|
|
502
437
|
# msg += f' SingleNorm:{self._use_one_mu_std}'
|
|
503
438
|
msg += f" Rot:{self._enable_rotation}"
|
|
@@ -529,40 +464,52 @@ class MultiChDloader:
|
|
|
529
464
|
)
|
|
530
465
|
|
|
531
466
|
if self._enable_random_cropping:
|
|
532
|
-
|
|
467
|
+
patch_start_loc = self._get_random_hw(h, w)
|
|
468
|
+
if self._5Ddata:
|
|
469
|
+
patch_start_loc = (
|
|
470
|
+
np.random.choice(img_tuples[0].shape[-3] - self._depth3D),
|
|
471
|
+
) + patch_start_loc
|
|
533
472
|
else:
|
|
534
|
-
|
|
473
|
+
patch_start_loc = self._get_deterministic_loc(index)
|
|
535
474
|
|
|
536
475
|
cropped_imgs = []
|
|
537
476
|
for img in img_tuples:
|
|
538
|
-
img = self._crop_flip_img(img,
|
|
477
|
+
img = self._crop_flip_img(img, patch_start_loc, False, False)
|
|
539
478
|
cropped_imgs.append(img)
|
|
540
479
|
|
|
541
480
|
return (
|
|
542
481
|
*tuple(cropped_imgs),
|
|
543
482
|
{
|
|
544
|
-
"h": [h_start, h_start + self._img_sz],
|
|
545
|
-
"w": [w_start, w_start + self._img_sz],
|
|
546
483
|
"hflip": False,
|
|
547
484
|
"wflip": False,
|
|
548
485
|
},
|
|
549
486
|
)
|
|
550
487
|
|
|
551
|
-
def _crop_img(self, img: np.ndarray,
|
|
552
|
-
if self.
|
|
488
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: Tuple):
|
|
489
|
+
if self._trim_boundary:
|
|
553
490
|
# In training, this is used.
|
|
554
491
|
# NOTE: It is my opinion that if I just use self._crop_img_with_padding, it will work perfectly fine.
|
|
555
492
|
# The only benefit this if else loop provides is that it makes it easier to see what happens during training.
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
493
|
+
patch_end_loc = (
|
|
494
|
+
np.array(patch_start_loc, dtype=np.int32)
|
|
495
|
+
+ self.idx_manager.patch_shape[1:-1]
|
|
496
|
+
)
|
|
497
|
+
if self._5Ddata:
|
|
498
|
+
z_start, h_start, w_start = patch_start_loc
|
|
499
|
+
z_end, h_end, w_end = patch_end_loc
|
|
500
|
+
new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
|
|
501
|
+
else:
|
|
502
|
+
h_start, w_start = patch_start_loc
|
|
503
|
+
h_end, w_end = patch_end_loc
|
|
504
|
+
new_img = img[..., h_start:h_end, w_start:w_end]
|
|
505
|
+
|
|
559
506
|
return new_img
|
|
560
|
-
|
|
507
|
+
else:
|
|
561
508
|
# During evaluation, this is used. In this situation, we can have negative h_start, w_start. Or h_start +self._img_sz can be larger than frame
|
|
562
509
|
# In these situations, we need some sort of padding. This is not needed in the LeftTop alignement.
|
|
563
|
-
return self._crop_img_with_padding(img,
|
|
510
|
+
return self._crop_img_with_padding(img, patch_start_loc)
|
|
564
511
|
|
|
565
|
-
def get_begin_end_padding(self, start_pos, max_len):
|
|
512
|
+
def get_begin_end_padding(self, start_pos, end_pos, max_len):
|
|
566
513
|
"""
|
|
567
514
|
The effect is that the image with size self._grid_sz is in the center of the patch with sufficient
|
|
568
515
|
padding on all four sides so that the final patch size is self._img_sz.
|
|
@@ -572,44 +519,56 @@ class MultiChDloader:
|
|
|
572
519
|
if start_pos < 0:
|
|
573
520
|
pad_start = -1 * start_pos
|
|
574
521
|
|
|
575
|
-
pad_end = max(0,
|
|
522
|
+
pad_end = max(0, end_pos - max_len)
|
|
576
523
|
|
|
577
524
|
return pad_start, pad_end
|
|
578
525
|
|
|
579
|
-
def _crop_img_with_padding(
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
526
|
+
def _crop_img_with_padding(
|
|
527
|
+
self, img: np.ndarray, patch_start_loc, max_len_vals=None
|
|
528
|
+
):
|
|
529
|
+
if max_len_vals is None:
|
|
530
|
+
max_len_vals = self.idx_manager.data_shape[1:-1]
|
|
531
|
+
patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
|
|
532
|
+
self.idx_manager.patch_shape[1:-1], dtype=int
|
|
533
|
+
)
|
|
534
|
+
boundary_crossed = []
|
|
535
|
+
valid_slice = []
|
|
536
|
+
padding = [[0, 0]]
|
|
537
|
+
for start_idx, end_idx, max_len in zip(
|
|
538
|
+
patch_start_loc, patch_end_loc, max_len_vals
|
|
539
|
+
):
|
|
540
|
+
boundary_crossed.append(end_idx > max_len or start_idx < 0)
|
|
541
|
+
valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
|
|
542
|
+
pad = [0, 0]
|
|
543
|
+
if boundary_crossed[-1]:
|
|
544
|
+
pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
|
|
545
|
+
padding.append(pad)
|
|
589
546
|
# max() is needed since h_start could be negative.
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
547
|
+
if self._5Ddata:
|
|
548
|
+
new_img = img[
|
|
549
|
+
...,
|
|
550
|
+
valid_slice[0][0] : valid_slice[0][1],
|
|
551
|
+
valid_slice[1][0] : valid_slice[1][1],
|
|
552
|
+
valid_slice[2][0] : valid_slice[2][1],
|
|
553
|
+
]
|
|
554
|
+
else:
|
|
555
|
+
new_img = img[
|
|
556
|
+
...,
|
|
557
|
+
valid_slice[0][0] : valid_slice[0][1],
|
|
558
|
+
valid_slice[1][0] : valid_slice[1][1],
|
|
559
|
+
]
|
|
603
560
|
|
|
561
|
+
# print(np.array(padding).shape, img.shape, new_img.shape)
|
|
562
|
+
# print(padding)
|
|
604
563
|
if not np.all(padding == 0):
|
|
605
564
|
new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)
|
|
606
565
|
|
|
607
566
|
return new_img
|
|
608
567
|
|
|
609
568
|
def _crop_flip_img(
|
|
610
|
-
self, img: np.ndarray,
|
|
569
|
+
self, img: np.ndarray, patch_start_loc: Tuple, h_flip: bool, w_flip: bool
|
|
611
570
|
):
|
|
612
|
-
new_img = self._crop_img(img,
|
|
571
|
+
new_img = self._crop_img(img, patch_start_loc)
|
|
613
572
|
if h_flip:
|
|
614
573
|
new_img = new_img[..., ::-1, :]
|
|
615
574
|
if w_flip:
|
|
@@ -617,9 +576,6 @@ class MultiChDloader:
|
|
|
617
576
|
|
|
618
577
|
return new_img.astype(np.float32)
|
|
619
578
|
|
|
620
|
-
def __len__(self):
|
|
621
|
-
return self.N * self._repeat_factor
|
|
622
|
-
|
|
623
579
|
def _load_img(
|
|
624
580
|
self, index: Union[int, Tuple[int, int]]
|
|
625
581
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
@@ -631,12 +587,21 @@ class MultiChDloader:
|
|
|
631
587
|
else:
|
|
632
588
|
idx = index[0]
|
|
633
589
|
|
|
634
|
-
|
|
590
|
+
patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
|
|
591
|
+
imgs = self._data[patch_loc_list[0]]
|
|
592
|
+
# if self._5Ddata:
|
|
593
|
+
# assert self._noise_data is None, 'Noise is not supported for 5D data'
|
|
594
|
+
# n_loc, z_loc = patch_loc_list[:2]
|
|
595
|
+
# z_loc_interval = range(z_loc, z_loc + self._depth3D)
|
|
596
|
+
# imgs = self._data[n_loc, z_loc_interval]
|
|
597
|
+
# else:
|
|
598
|
+
# imgs = self._data[patch_loc_list[0]]
|
|
599
|
+
|
|
635
600
|
loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
|
|
636
601
|
noise = []
|
|
637
602
|
if self._noise_data is not None and not self._disable_noise:
|
|
638
603
|
noise = [
|
|
639
|
-
self._noise_data[
|
|
604
|
+
self._noise_data[patch_loc_list[0]][None, ..., i]
|
|
640
605
|
for i in range(self._noise_data.shape[-1])
|
|
641
606
|
]
|
|
642
607
|
return tuple(loaded_imgs), tuple(noise)
|
|
@@ -669,27 +634,16 @@ class MultiChDloader:
|
|
|
669
634
|
def per_side_overlap_pixelcount(self):
|
|
670
635
|
return (self._img_sz - self._grid_sz) // 2
|
|
671
636
|
|
|
672
|
-
def on_boundary(self, cur_loc, frame_size):
|
|
673
|
-
|
|
637
|
+
# def on_boundary(self, cur_loc, frame_size):
|
|
638
|
+
# return cur_loc + self._img_sz > frame_size or cur_loc < 0
|
|
674
639
|
|
|
675
|
-
def
|
|
640
|
+
def _get_deterministic_loc(self, index: int):
|
|
676
641
|
"""
|
|
677
642
|
It returns the top-left corner of the patch corresponding to index.
|
|
678
643
|
"""
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
else:
|
|
683
|
-
idx, grid_size = index
|
|
684
|
-
|
|
685
|
-
h_start, w_start = self.idx_manager.get_deterministic_hw(
|
|
686
|
-
idx, grid_size=grid_size
|
|
687
|
-
)
|
|
688
|
-
if self._grid_alignment == GridAlignement.LeftTop:
|
|
689
|
-
return h_start, w_start
|
|
690
|
-
elif self._grid_alignment == GridAlignement.Center:
|
|
691
|
-
pad = self.per_side_overlap_pixelcount()
|
|
692
|
-
return h_start - pad, w_start - pad
|
|
644
|
+
loc_list = self.idx_manager.get_patch_location_from_dataset_idx(index)
|
|
645
|
+
# last dim is channel. we need to take the third and the second last element.
|
|
646
|
+
return loc_list[1:-1]
|
|
693
647
|
|
|
694
648
|
def compute_individual_mean_std(self):
|
|
695
649
|
# numpy 1.19.2 has issues in computing for large arrays. https://github.com/numpy/numpy/issues/8869
|
|
@@ -715,6 +669,10 @@ class MultiChDloader:
|
|
|
715
669
|
|
|
716
670
|
mean = np.array(mean_arr)
|
|
717
671
|
std = np.array(std_arr)
|
|
672
|
+
if (
|
|
673
|
+
self._5Ddata
|
|
674
|
+
): # NOTE: IDEALLY this should be only when the model expects 3D data.
|
|
675
|
+
return mean[None, :, None, None, None], std[None, :, None, None, None]
|
|
718
676
|
|
|
719
677
|
return mean[None, :, None, None], std[None, :, None, None]
|
|
720
678
|
|
|
@@ -776,6 +734,10 @@ class MultiChDloader:
|
|
|
776
734
|
if self._skip_normalization_using_mean:
|
|
777
735
|
mean = np.zeros_like(mean)
|
|
778
736
|
|
|
737
|
+
if self._5Ddata:
|
|
738
|
+
mean = mean[:, :, None]
|
|
739
|
+
std = std[:, :, None]
|
|
740
|
+
|
|
779
741
|
mean_dict = {"input": mean} # , 'target':mean}
|
|
780
742
|
std_dict = {"input": std} # , 'target':std}
|
|
781
743
|
|
|
@@ -810,8 +772,14 @@ class MultiChDloader:
|
|
|
810
772
|
return cropped_img_tuples, cropped_noise_tuples
|
|
811
773
|
|
|
812
774
|
def replace_with_empty_patch(self, img_tuples):
|
|
775
|
+
"""
|
|
776
|
+
Replaces the content of one of the channels with background
|
|
777
|
+
"""
|
|
813
778
|
empty_index = self._empty_patch_fetcher.sample()
|
|
814
|
-
empty_img_tuples = self._get_img(empty_index)
|
|
779
|
+
empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
|
|
780
|
+
assert (
|
|
781
|
+
len(empty_img_noise_tuples) == 0
|
|
782
|
+
), "Noise is not supported with empty patch replacement"
|
|
815
783
|
final_img_tuples = []
|
|
816
784
|
for tuple_idx in range(len(img_tuples)):
|
|
817
785
|
if tuple_idx == self._empty_patch_replacement_channel_idx:
|
|
@@ -834,14 +802,7 @@ class MultiChDloader:
|
|
|
834
802
|
)
|
|
835
803
|
img_tuples = [img_tuples[i] for i in self._tar_idx_list]
|
|
836
804
|
|
|
837
|
-
|
|
838
|
-
assert self._input_is_sum is False
|
|
839
|
-
target = []
|
|
840
|
-
for i in range(len(img_tuples)):
|
|
841
|
-
target.append(img_tuples[i] * alpha[i])
|
|
842
|
-
target = np.concatenate(target, axis=0)
|
|
843
|
-
else:
|
|
844
|
-
target = np.concatenate(img_tuples, axis=0)
|
|
805
|
+
target = np.concatenate(img_tuples, axis=0)
|
|
845
806
|
return target
|
|
846
807
|
|
|
847
808
|
def _compute_input_with_alpha(self, img_tuples, alpha_list):
|
|
@@ -902,9 +863,6 @@ class MultiChDloader:
|
|
|
902
863
|
index = self._train_index_switcher.get_invalid_target_index()
|
|
903
864
|
return index
|
|
904
865
|
|
|
905
|
-
def _rotate(self, img_tuples, noise_tuples):
|
|
906
|
-
return self._rotate2D(img_tuples, noise_tuples)
|
|
907
|
-
|
|
908
866
|
def _rotate2D(self, img_tuples, noise_tuples):
|
|
909
867
|
img_kwargs = {}
|
|
910
868
|
for i, img in enumerate(img_tuples):
|
|
@@ -921,6 +879,7 @@ class MultiChDloader:
|
|
|
921
879
|
rot_dic = self._rotation_transform(
|
|
922
880
|
image=img_tuples[0][0], **img_kwargs, **noise_kwargs
|
|
923
881
|
)
|
|
882
|
+
|
|
924
883
|
rotated_img_tuples = []
|
|
925
884
|
for i, img in enumerate(img_tuples):
|
|
926
885
|
if len(img) == 1:
|
|
@@ -946,7 +905,90 @@ class MultiChDloader:
|
|
|
946
905
|
|
|
947
906
|
return rotated_img_tuples, rotated_noise_tuples
|
|
948
907
|
|
|
908
|
+
def _rotate(self, img_tuples, noise_tuples):
|
|
909
|
+
if self._depth3D > 1:
|
|
910
|
+
return self._rotate3D(img_tuples, noise_tuples)
|
|
911
|
+
else:
|
|
912
|
+
return self._rotate2D(img_tuples, noise_tuples)
|
|
913
|
+
|
|
914
|
+
def _rotate3D(self, img_tuples, noise_tuples):
|
|
915
|
+
img_kwargs = {}
|
|
916
|
+
for i, img in enumerate(img_tuples):
|
|
917
|
+
for j in range(self._depth3D):
|
|
918
|
+
for k in range(len(img)):
|
|
919
|
+
img_kwargs[f"img{i}_{j}_{k}"] = img[k, j]
|
|
920
|
+
|
|
921
|
+
noise_kwargs = {}
|
|
922
|
+
for i, nimg in enumerate(noise_tuples):
|
|
923
|
+
for j in range(self._depth3D):
|
|
924
|
+
for k in range(len(nimg)):
|
|
925
|
+
noise_kwargs[f"noise{i}_{j}_{k}"] = nimg[k, j]
|
|
926
|
+
|
|
927
|
+
keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
|
|
928
|
+
self._rotation_transform.add_targets({k: "image" for k in keys})
|
|
929
|
+
rot_dic = self._rotation_transform(
|
|
930
|
+
image=img_tuples[0][0], **img_kwargs, **noise_kwargs
|
|
931
|
+
)
|
|
932
|
+
rotated_img_tuples = []
|
|
933
|
+
for i, img in enumerate(img_tuples):
|
|
934
|
+
if len(img) == 1:
|
|
935
|
+
rotated_img_tuples.append(
|
|
936
|
+
np.concatenate(
|
|
937
|
+
[
|
|
938
|
+
rot_dic[f"img{i}_{j}_0"][None, None]
|
|
939
|
+
for j in range(self._depth3D)
|
|
940
|
+
],
|
|
941
|
+
axis=1,
|
|
942
|
+
)
|
|
943
|
+
)
|
|
944
|
+
else:
|
|
945
|
+
temp_arr = []
|
|
946
|
+
for k in range(len(img)):
|
|
947
|
+
temp_arr.append(
|
|
948
|
+
np.concatenate(
|
|
949
|
+
[
|
|
950
|
+
rot_dic[f"img{i}_{j}_{k}"][None, None]
|
|
951
|
+
for j in range(self._depth3D)
|
|
952
|
+
],
|
|
953
|
+
axis=1,
|
|
954
|
+
)
|
|
955
|
+
)
|
|
956
|
+
rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
|
|
957
|
+
|
|
958
|
+
rotated_noise_tuples = []
|
|
959
|
+
for i, nimg in enumerate(noise_tuples):
|
|
960
|
+
if len(nimg) == 1:
|
|
961
|
+
rotated_noise_tuples.append(
|
|
962
|
+
np.concatenate(
|
|
963
|
+
[
|
|
964
|
+
rot_dic[f"noise{i}_{j}_0"][None, None]
|
|
965
|
+
for j in range(self._depth3D)
|
|
966
|
+
],
|
|
967
|
+
axis=1,
|
|
968
|
+
)
|
|
969
|
+
)
|
|
970
|
+
else:
|
|
971
|
+
temp_arr = []
|
|
972
|
+
for k in range(len(nimg)):
|
|
973
|
+
temp_arr.append(
|
|
974
|
+
np.concatenate(
|
|
975
|
+
[
|
|
976
|
+
rot_dic[f"noise{i}_{j}_{k}"][None, None]
|
|
977
|
+
for j in range(self._depth3D)
|
|
978
|
+
],
|
|
979
|
+
axis=1,
|
|
980
|
+
)
|
|
981
|
+
)
|
|
982
|
+
rotated_noise_tuples.append(np.concatenate(temp_arr, axis=0))
|
|
983
|
+
|
|
984
|
+
return rotated_img_tuples, rotated_noise_tuples
|
|
985
|
+
|
|
949
986
|
def get_uncorrelated_img_tuples(self, index):
|
|
987
|
+
"""
|
|
988
|
+
Content of channels like actin and nuclei is "correlated" in its
|
|
989
|
+
respective location, this function allows to pick channels' content
|
|
990
|
+
from different patches of the image to make it "uncorrelated".
|
|
991
|
+
"""
|
|
950
992
|
img_tuples, noise_tuples = self._get_img(index)
|
|
951
993
|
assert len(noise_tuples) == 0
|
|
952
994
|
img_tuples = [img_tuples[0]]
|
|
@@ -959,6 +1001,8 @@ class MultiChDloader:
|
|
|
959
1001
|
def __getitem__(
|
|
960
1002
|
self, index: Union[int, Tuple[int, int]]
|
|
961
1003
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1004
|
+
# Vera: input can be both real microscopic image and two separate channels that are summed in the code
|
|
1005
|
+
|
|
962
1006
|
if self._train_index_switcher is not None:
|
|
963
1007
|
index = self._get_index_from_valid_target_logic(index)
|
|
964
1008
|
|
|
@@ -971,22 +1015,29 @@ class MultiChDloader:
|
|
|
971
1015
|
self._empty_patch_replacement_enabled != True
|
|
972
1016
|
), "This is not supported with noise"
|
|
973
1017
|
|
|
1018
|
+
# Replace the content of one of the channels
|
|
1019
|
+
# with background with given probability
|
|
974
1020
|
if self._empty_patch_replacement_enabled:
|
|
975
1021
|
if np.random.rand() < self._empty_patch_replacement_probab:
|
|
976
1022
|
img_tuples = self.replace_with_empty_patch(img_tuples)
|
|
977
1023
|
|
|
1024
|
+
# Noise tuples are not needed for the paper
|
|
1025
|
+
# the image tuples are noisy by default
|
|
1026
|
+
# TODO: remove noise tuples completely?
|
|
978
1027
|
if self._enable_rotation:
|
|
979
1028
|
img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
|
|
980
1029
|
|
|
981
|
-
#
|
|
1030
|
+
# Add noise tuples with image tuples to create the input
|
|
982
1031
|
if len(noise_tuples) > 0:
|
|
983
1032
|
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
984
1033
|
input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
|
|
985
1034
|
else:
|
|
986
1035
|
input_tuples = img_tuples
|
|
1036
|
+
|
|
1037
|
+
# Weight the individual channels, typically alpha is fixed
|
|
987
1038
|
inp, alpha = self._compute_input(input_tuples)
|
|
988
1039
|
|
|
989
|
-
#
|
|
1040
|
+
# Add noise tuples to the image tuples to create the target
|
|
990
1041
|
if len(noise_tuples) >= 1:
|
|
991
1042
|
img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]
|
|
992
1043
|
|
|
@@ -1000,221 +1051,4 @@ class MultiChDloader:
|
|
|
1000
1051
|
if self._return_index:
|
|
1001
1052
|
output.append(index)
|
|
1002
1053
|
|
|
1003
|
-
if isinstance(index, int) or isinstance(index, np.int64):
|
|
1004
|
-
return tuple(output)
|
|
1005
|
-
|
|
1006
|
-
_, grid_size = index
|
|
1007
|
-
output.append(grid_size)
|
|
1008
|
-
return tuple(output)
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
class LCMultiChDloader(MultiChDloader):
|
|
1012
|
-
|
|
1013
|
-
def __init__(
|
|
1014
|
-
self,
|
|
1015
|
-
data_config,
|
|
1016
|
-
fpath: str,
|
|
1017
|
-
datasplit_type: DataSplitType = None,
|
|
1018
|
-
val_fraction=None,
|
|
1019
|
-
test_fraction=None,
|
|
1020
|
-
normalized_input=None,
|
|
1021
|
-
enable_rotation_aug: bool = False,
|
|
1022
|
-
use_one_mu_std=None,
|
|
1023
|
-
num_scales: int = None,
|
|
1024
|
-
enable_random_cropping=False,
|
|
1025
|
-
padding_kwargs: dict = None,
|
|
1026
|
-
allow_generation: bool = False,
|
|
1027
|
-
lowres_supervision=None,
|
|
1028
|
-
max_val=None,
|
|
1029
|
-
grid_alignment=GridAlignement.LeftTop,
|
|
1030
|
-
overlapping_padding_kwargs=None,
|
|
1031
|
-
print_vars=True,
|
|
1032
|
-
):
|
|
1033
|
-
"""
|
|
1034
|
-
Args:
|
|
1035
|
-
num_scales: The number of resolutions at which we want the input. Note that the target is formed at the
|
|
1036
|
-
highest resolution.
|
|
1037
|
-
"""
|
|
1038
|
-
self._padding_kwargs = (
|
|
1039
|
-
padding_kwargs # mode=padding_mode, constant_values=constant_value
|
|
1040
|
-
)
|
|
1041
|
-
if overlapping_padding_kwargs is not None:
|
|
1042
|
-
assert (
|
|
1043
|
-
self._padding_kwargs == overlapping_padding_kwargs
|
|
1044
|
-
), "During evaluation, overlapping_padding_kwargs should be same as padding_args. \
|
|
1045
|
-
It should be so since we just use overlapping_padding_kwargs when it is not None"
|
|
1046
|
-
|
|
1047
|
-
else:
|
|
1048
|
-
overlapping_padding_kwargs = padding_kwargs
|
|
1049
|
-
|
|
1050
|
-
super().__init__(
|
|
1051
|
-
data_config,
|
|
1052
|
-
fpath,
|
|
1053
|
-
datasplit_type=datasplit_type,
|
|
1054
|
-
val_fraction=val_fraction,
|
|
1055
|
-
test_fraction=test_fraction,
|
|
1056
|
-
normalized_input=normalized_input,
|
|
1057
|
-
enable_rotation_aug=enable_rotation_aug,
|
|
1058
|
-
enable_random_cropping=enable_random_cropping,
|
|
1059
|
-
use_one_mu_std=use_one_mu_std,
|
|
1060
|
-
allow_generation=allow_generation,
|
|
1061
|
-
max_val=max_val,
|
|
1062
|
-
grid_alignment=grid_alignment,
|
|
1063
|
-
overlapping_padding_kwargs=overlapping_padding_kwargs,
|
|
1064
|
-
print_vars=print_vars,
|
|
1065
|
-
)
|
|
1066
|
-
self.num_scales = num_scales
|
|
1067
|
-
assert self.num_scales is not None
|
|
1068
|
-
self._scaled_data = [self._data]
|
|
1069
|
-
self._scaled_noise_data = [self._noise_data]
|
|
1070
|
-
|
|
1071
|
-
assert isinstance(self.num_scales, int) and self.num_scales >= 1
|
|
1072
|
-
self._lowres_supervision = lowres_supervision
|
|
1073
|
-
assert isinstance(self._padding_kwargs, dict)
|
|
1074
|
-
assert "mode" in self._padding_kwargs
|
|
1075
|
-
|
|
1076
|
-
for _ in range(1, self.num_scales):
|
|
1077
|
-
shape = self._scaled_data[-1].shape
|
|
1078
|
-
assert len(shape) == 4
|
|
1079
|
-
new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
|
|
1080
|
-
ds_data = resize(
|
|
1081
|
-
self._scaled_data[-1].astype(np.float32), new_shape
|
|
1082
|
-
).astype(self._scaled_data[-1].dtype)
|
|
1083
|
-
# NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
|
|
1084
|
-
assert (
|
|
1085
|
-
ds_data.max() / self._scaled_data[-1].max() < 5
|
|
1086
|
-
), "Downsampled image should not have very different values"
|
|
1087
|
-
assert (
|
|
1088
|
-
ds_data.max() / self._scaled_data[-1].max() > 0.2
|
|
1089
|
-
), "Downsampled image should not have very different values"
|
|
1090
|
-
|
|
1091
|
-
self._scaled_data.append(ds_data)
|
|
1092
|
-
# do the same for noise
|
|
1093
|
-
if self._noise_data is not None:
|
|
1094
|
-
noise_data = resize(self._scaled_noise_data[-1], new_shape)
|
|
1095
|
-
self._scaled_noise_data.append(noise_data)
|
|
1096
|
-
|
|
1097
|
-
def _init_msg(self):
|
|
1098
|
-
msg = super()._init_msg()
|
|
1099
|
-
msg += f" Pad:{self._padding_kwargs}"
|
|
1100
|
-
return msg
|
|
1101
|
-
|
|
1102
|
-
def _load_scaled_img(
|
|
1103
|
-
self, scaled_index, index: Union[int, Tuple[int, int]]
|
|
1104
|
-
) -> Tuple[np.ndarray, np.ndarray]:
|
|
1105
|
-
if isinstance(index, int):
|
|
1106
|
-
idx = index
|
|
1107
|
-
else:
|
|
1108
|
-
idx, _ = index
|
|
1109
|
-
imgs = self._scaled_data[scaled_index][idx % self.N]
|
|
1110
|
-
imgs = tuple([imgs[None, :, :, i] for i in range(imgs.shape[-1])])
|
|
1111
|
-
if self._noise_data is not None:
|
|
1112
|
-
noisedata = self._scaled_noise_data[scaled_index][idx % self.N]
|
|
1113
|
-
noise = tuple(
|
|
1114
|
-
[noisedata[None, :, :, i] for i in range(noisedata.shape[-1])]
|
|
1115
|
-
)
|
|
1116
|
-
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
1117
|
-
# since we are using this lowres images for just the input, we need to add the noise of the input.
|
|
1118
|
-
assert self._lowres_supervision is None or self._lowres_supervision is False
|
|
1119
|
-
imgs = tuple([img + noise[0] * factor for img in imgs])
|
|
1120
|
-
return imgs
|
|
1121
|
-
|
|
1122
|
-
def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
|
|
1123
|
-
"""
|
|
1124
|
-
Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
|
|
1125
|
-
the cropped image will be smaller than self._img_sz * self._img_sz
|
|
1126
|
-
"""
|
|
1127
|
-
return self._crop_img_with_padding(img, h_start, w_start)
|
|
1128
|
-
|
|
1129
|
-
def _get_img(self, index: int):
|
|
1130
|
-
"""
|
|
1131
|
-
Returns the primary patch along with low resolution patches centered on the primary patch.
|
|
1132
|
-
"""
|
|
1133
|
-
img_tuples, noise_tuples = self._load_img(index)
|
|
1134
|
-
assert self._img_sz is not None
|
|
1135
|
-
h, w = img_tuples[0].shape[-2:]
|
|
1136
|
-
if self._enable_random_cropping:
|
|
1137
|
-
h_start, w_start = self._get_random_hw(h, w)
|
|
1138
|
-
else:
|
|
1139
|
-
h_start, w_start = self._get_deterministic_hw(index)
|
|
1140
|
-
|
|
1141
|
-
cropped_img_tuples = [
|
|
1142
|
-
self._crop_flip_img(img, h_start, w_start, False, False)
|
|
1143
|
-
for img in img_tuples
|
|
1144
|
-
]
|
|
1145
|
-
cropped_noise_tuples = [
|
|
1146
|
-
self._crop_flip_img(noise, h_start, w_start, False, False)
|
|
1147
|
-
for noise in noise_tuples
|
|
1148
|
-
]
|
|
1149
|
-
h_center = h_start + self._img_sz // 2
|
|
1150
|
-
w_center = w_start + self._img_sz // 2
|
|
1151
|
-
allres_versions = {
|
|
1152
|
-
i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
|
|
1153
|
-
}
|
|
1154
|
-
for scale_idx in range(1, self.num_scales):
|
|
1155
|
-
scaled_img_tuples = self._load_scaled_img(scale_idx, index)
|
|
1156
|
-
|
|
1157
|
-
h_center = h_center // 2
|
|
1158
|
-
w_center = w_center // 2
|
|
1159
|
-
|
|
1160
|
-
h_start = h_center - self._img_sz // 2
|
|
1161
|
-
w_start = w_center - self._img_sz // 2
|
|
1162
|
-
|
|
1163
|
-
scaled_cropped_img_tuples = [
|
|
1164
|
-
self._crop_flip_img(img, h_start, w_start, False, False)
|
|
1165
|
-
for img in scaled_img_tuples
|
|
1166
|
-
]
|
|
1167
|
-
for ch_idx in range(len(img_tuples)):
|
|
1168
|
-
allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])
|
|
1169
|
-
|
|
1170
|
-
output_img_tuples = tuple(
|
|
1171
|
-
[
|
|
1172
|
-
np.concatenate(allres_versions[ch_idx])
|
|
1173
|
-
for ch_idx in range(len(img_tuples))
|
|
1174
|
-
]
|
|
1175
|
-
)
|
|
1176
|
-
return output_img_tuples, cropped_noise_tuples
|
|
1177
|
-
|
|
1178
|
-
def __getitem__(self, index: Union[int, Tuple[int, int]]):
|
|
1179
|
-
if self._uncorrelated_channels:
|
|
1180
|
-
img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
|
|
1181
|
-
else:
|
|
1182
|
-
img_tuples, noise_tuples = self._get_img(index)
|
|
1183
|
-
|
|
1184
|
-
if self._enable_rotation:
|
|
1185
|
-
img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
|
|
1186
|
-
|
|
1187
|
-
assert self._lowres_supervision != True
|
|
1188
|
-
# add noise to input
|
|
1189
|
-
if len(noise_tuples) > 0:
|
|
1190
|
-
factor = np.sqrt(2) if self._input_is_sum else 1.0
|
|
1191
|
-
input_tuples = []
|
|
1192
|
-
for x in img_tuples:
|
|
1193
|
-
# NOTE: other LC levels already have noise added. So, we just need to add noise to the highest resolution.
|
|
1194
|
-
x[0] = x[0] + noise_tuples[0] * factor
|
|
1195
|
-
input_tuples.append(x)
|
|
1196
|
-
else:
|
|
1197
|
-
input_tuples = img_tuples
|
|
1198
|
-
|
|
1199
|
-
inp, alpha = self._compute_input(input_tuples)
|
|
1200
|
-
# assert self._alpha_weighted_target in [False, None]
|
|
1201
|
-
target_tuples = [img[:1] for img in img_tuples]
|
|
1202
|
-
# add noise to target.
|
|
1203
|
-
if len(noise_tuples) >= 1:
|
|
1204
|
-
target_tuples = [
|
|
1205
|
-
x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
|
|
1206
|
-
]
|
|
1207
|
-
|
|
1208
|
-
target = self._compute_target(target_tuples, alpha)
|
|
1209
|
-
|
|
1210
|
-
output = [inp, target]
|
|
1211
|
-
|
|
1212
|
-
if self._return_alpha:
|
|
1213
|
-
output.append(alpha)
|
|
1214
|
-
|
|
1215
|
-
if isinstance(index, int):
|
|
1216
|
-
return tuple(output)
|
|
1217
|
-
|
|
1218
|
-
_, grid_size = index
|
|
1219
|
-
output.append(grid_size)
|
|
1220
1054
|
return tuple(output)
|