careamics 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/RECORD +50 -35
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1067 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A place for Datasets and Dataloaders.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable, Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from skimage.transform import resize
|
|
12
|
+
|
|
13
|
+
from .config import DatasetConfig
|
|
14
|
+
from .types import DataSplitType, TilingMode
|
|
15
|
+
from .utils.empty_patch_fetcher import EmptyPatchFetcher
|
|
16
|
+
from .utils.index_manager import GridIndexManagerRef
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MultiChDloaderRef:
|
|
20
|
+
def __init__(
    self,
    data_config: DatasetConfig,
    fpath: str,
    load_data_fn: Callable,
    val_fraction: Union[float, None] = None,
    test_fraction: Union[float, None] = None,
):
    """Load the data and configure patch indexing, normalization and augmentation.

    Args:
        data_config: dataset configuration. Every option set below (tiling mode,
            image/grid size, alpha range, augmentation flags, ...) is read from it.
        fpath: location of the data. Stored as ``Path`` and passed to
            ``load_data_fn`` through ``self.load_data``.
        load_data_fn: callable used by ``self.load_data`` to read the data.
        val_fraction: forwarded to ``load_data_fn``.
        test_fraction: forwarded to ``load_data_fn``.
    """
    self._data_type = data_config.data_type
    self._fpath = Path(fpath)
    self._data = None
    # NOTE(review): hard-coded to False, so every 3D code path in this class is
    # disabled regardless of ``data_config.mode_3D``.
    self._3Ddata = False  # TODO wtf it was 5D
    self._tiling_mode = data_config.tiling_mode
    # by default, if the noise is present, add it to the input and target.
    self._depth3D = data_config.depth3D
    self._mode_3D = data_config.mode_3D
    # NOTE: Input is the sum of the different channels. It is not the average of the different channels.
    self._input_is_sum = data_config.input_is_sum
    self._num_channels = data_config.num_channels
    self._input_idx = data_config.input_idx
    self._tar_idx_list = data_config.target_idx_list

    # Populates self._data via load_data_fn.
    self.load_data(
        data_config,
        data_config.datasplit_type,
        load_data_fn=load_data_fn,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=data_config.allow_generation,
    )

    # Shapes must be known before set_img_sz builds the index manager below.
    self._data_shapes = self.get_data_shapes()
    self._normalized_input = data_config.normalized_input
    self._quantile = 1.0
    self._channelwise_quantile = False
    self._background_quantile = 0.0
    self._clip_background_noise_to_zero = False
    self._skip_normalization_using_mean = False
    self._empty_patch_replacement_enabled = False

    self._background_values = None

    self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
    if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
        if (
            self._overlapping_padding_kwargs is None
            or data_config.multiscale_lowres_count is not None
        ):
            # raise warning
            print("Padding is not used with this alignement style")
    else:
        # Other tiling modes crop with padding (see _crop_img), so padding options are required.
        assert (
            self._overlapping_padding_kwargs is not None
        ), "When not trimming boudnary, padding is needed."

    self._is_train = data_config.datasplit_type == DataSplitType.Train

    # input = alpha * ch1 + (1-alpha)*ch2.
    # alpha is sampled randomly between these two extremes
    self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None

    self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None

    # changed set_img_sz because "grid_size" in data_config returns false
    try:
        grid_size = data_config.grid_size
    except AttributeError:
        grid_size = data_config.image_size

    if self._is_train:
        self._start_alpha_arr = data_config.start_alpha  # TODO why only for train?
        self._end_alpha_arr = data_config.end_alpha

    # Sets self._img_sz / self._grid_sz and builds self.idx_manager.
    self.set_img_sz(data_config.image_size, grid_size)

    # Empty-patch replacement is only ever active for the training split.
    self._empty_patch_replacement_enabled = (
        data_config.empty_patch_replacement_enabled and self._is_train
    )
    if self._empty_patch_replacement_enabled:
        self._empty_patch_replacement_channel_idx = (
            data_config.empty_patch_replacement_channel_idx
        )
        self._empty_patch_replacement_probab = (
            data_config.empty_patch_replacement_probab
        )
        # NOTE(review): Ellipsis indexing needs self._data to be an ndarray, but
        # get_data_shapes treats it as a per-channel sequence. Confirm this branch works.
        data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
        # NOTE: This is on the raw data. So, it must be called before removing the background.
        self._empty_patch_fetcher = EmptyPatchFetcher(
            self.idx_manager,
            self._img_sz,
            data_frames,
            max_val_threshold=data_config.empty_patch_max_val_threshold,
        )

    # Computes (or takes) self.max_val and clips the loaded data to it in place.
    self.rm_bkground_set_max_val_and_upperclip_data(
        data_config.max_val, data_config.datasplit_type
    )

    # For overlapping dloader, image_size and repeat_factors are not related. hence a different function.

    self._mean = None
    self._std = None
    self._use_one_mu_std = data_config.use_one_mu_std

    self._target_separate_normalization = data_config.target_separate_normalization

    self._enable_rotation = data_config.enable_rotation_aug
    flipz_3D = data_config.random_flip_z_3D
    # z-flipping is only applied when rotation augmentation is enabled.
    self._flipz_3D = flipz_3D and self._enable_rotation

    self._enable_random_cropping = data_config.enable_random_cropping
    self._uncorrelated_channels = (
        data_config.uncorrelated_channels and self._is_train
    )
    self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
    assert self._is_train or self._uncorrelated_channels is False
    assert (
        self._enable_random_cropping is True or self._uncorrelated_channels is False
    )
    # Randomly rotate [-90,90]

    self._rotation_transform = None
    if self._enable_rotation:
        # TODO: fix this import
        import albumentations as A

        # Random flip plus a random multiple-of-90-degree rotation.
        self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])

    # TODO: remove print log messages
    # if print_vars:
    #     msg = self._init_msg()
    #     print(msg)
|
|
153
|
+
|
|
154
|
+
def get_data_shapes(self):
    """Return the shape of every loaded image, grouped per channel.

    Any image whose rank differs from the expected rank (4 for 3D data, 3 for 2D
    data) gets a leading axis of size 1 prepended to its shape.

    Returns:
        list[list[tuple]]: ``shapes[ch][i]`` is the (possibly extended) shape of
        ``self._data[ch][i]``.
    """
    # TODO we assume images don't have a channel dimension
    # Previously the 3D branch built this list but never returned it.
    expected_ndim = 4 if self._3Ddata else 3
    return [
        [
            im.shape if len(im.shape) == expected_ndim else (1, *im.shape)
            for im in channel_imgs
        ]
        for channel_imgs in self._data
    ]
|
|
171
|
+
|
|
172
|
+
def load_data(
    self,
    data_config,
    datasplit_type,
    load_data_fn: Callable,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
):
    """Read the raw data with ``load_data_fn`` and keep it on ``self._data``.

    ``load_data_fn`` receives the config, ``self._fpath`` and the split type as
    positional arguments. The split fractions and ``allow_generation`` are passed
    as keyword arguments.
    """
    loader_options = {
        "val_fraction": val_fraction,
        "test_fraction": test_fraction,
        "allow_generation": allow_generation,
    }
    loaded = load_data_fn(data_config, self._fpath, datasplit_type, **loader_options)
    self._data = loaded

    # TODO check for 2D/3D data consistency with config
    # TODO check number of channels consistency with config
|
|
192
|
+
|
|
193
|
+
def save_background(self, channel_idx, frame_idx, background_value):
    """Store ``background_value`` at ``[frame_idx, channel_idx]`` of ``self._background_values``."""
    background_table = self._background_values
    background_table[frame_idx, channel_idx] = background_value
|
|
195
|
+
|
|
196
|
+
def get_background(self, channel_idx, frame_idx):
    """Look up the entry ``[frame_idx, channel_idx]`` of ``self._background_values``."""
    background_table = self._background_values
    return background_table[frame_idx, channel_idx]
|
|
198
|
+
|
|
199
|
+
def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
    """Set ``self.max_val`` and then clip the loaded data to it.

    Background removal is currently skipped (``self.remove_background()`` is
    disabled, TODO revisit), so the method only does the two steps below, in order.
    """
    # 1) determine the per-channel clipping values
    self.set_max_val(max_val, datasplit_type)
    # 2) clip the data in place
    self.upperclip_data()
|
|
203
|
+
|
|
204
|
+
def upperclip_data(self):
    """Clip every image in place to its channel's ``self.max_val`` entry.

    Channels whose ``max_val`` entry is None are left untouched.
    """
    for channel_idx, channel_imgs in enumerate(self._data):
        limit = self.max_val[channel_idx]
        if limit is None:
            continue
        for img in channel_imgs:
            img[img > limit] = limit
|
|
209
|
+
|
|
210
|
+
def compute_max_val(self):
    """Return, per channel, the largest ``self._quantile`` quantile over its images."""
    # TODO add channelwise quantile ?
    max_vals = []
    for channel_imgs in self._data:
        per_image_quantiles = [np.quantile(img, self._quantile) for img in channel_imgs]
        max_vals.append(max(per_image_quantiles))
    return max_vals
|
|
215
|
+
|
|
216
|
+
def set_max_val(self, max_val, datasplit_type):
    """Store the per-channel clipping values on ``self.max_val``.

    An explicit ``max_val`` is used as-is. Otherwise the values are computed from
    the data, which is only allowed for the Train or All split.
    """
    if max_val is not None:
        self.max_val = max_val
        return
    assert datasplit_type in [DataSplitType.Train, DataSplitType.All]
    self.max_val = self.compute_max_val()
|
|
223
|
+
|
|
224
|
+
def get_max_val(self):
    """Return the per-channel clipping values (``self.max_val``)."""
    clip_values = self.max_val
    return clip_values
|
|
226
|
+
|
|
227
|
+
def get_img_sz(self):
    """Return the patch size stored by ``set_img_sz`` (``self._img_sz``)."""
    patch_size = self._img_sz
    return patch_size
|
|
229
|
+
|
|
230
|
+
def get_num_frames(self):
    """Return the largest value in ``self.idx_manager.total_grid_count()[0]``.

    That first element appears to hold one count per channel, so this is the
    count of the longest channel (TODO confirm against GridIndexManagerRef).
    """
    per_channel_counts = self.idx_manager.total_grid_count()[0]
    return max(per_channel_counts)
|
|
233
|
+
|
|
234
|
+
def reduce_data(
    self,
    t_list=None,
    z_start=None,
    z_end=None,
    h_start=None,
    h_end=None,
    w_start=None,
    w_end=None,
):
    """Restrict the loaded data to a sub-region; not supported by this dataset.

    Raises:
        NotImplementedError: always, whatever the arguments.
    """
    unsupported_msg = "Not implemented"
    raise NotImplementedError(unsupported_msg)
|
|
245
|
+
|
|
246
|
+
def get_idx_manager_shapes(
    self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
):
    """Build the patch and grid shapes passed to the index manager.

    Args:
        patch_size: in-plane patch extent.
        grid_size: grid extent. Must be an int for 2D data. For 3D data it is an
            int (applied in-plane, with a z grid of 1) or a ``(z, h, w)`` tuple.

    Returns:
        tuple: ``(patch_shape, grid_shape)``, each with a leading 1. 2D shapes are
        ``(1, p, p)``/``(1, g, g)``; 3D shapes are ``(1, depth3D, p, p)`` and
        ``(1, gz, gh, gw)``.

    Raises:
        AssertionError: if ``grid_size`` has the wrong form, or a 3D grid exceeds
        the patch along any axis.
    """
    if not self._3Ddata:
        assert isinstance(grid_size, int)
        return (1, patch_size, patch_size), (1, grid_size, grid_size)

    patch_shape = (1, self._depth3D, patch_size, patch_size)
    if isinstance(grid_size, int):
        grid_shape = (1, 1, grid_size, grid_size)
    else:
        assert len(grid_size) == 3
        # Compare all three spatial axes (z, h, w); slicing with [1:-1] dropped w.
        spatial_patch = patch_shape[1:]
        assert all(
            g <= p for g, p in zip(grid_size, spatial_patch)
        ), f"Grid size {grid_size} must be less than or equal to patch size {spatial_patch}"
        grid_shape = (1, *grid_size)

    return patch_shape, grid_shape
|
|
266
|
+
|
|
267
|
+
def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
    """Set the patch and grid size and rebuild the index manager.

    If one wants to change the image size on the go, then this can be used.

    Args:
        image_size: size of one patch. Either a scalar or a shape sequence, in
            which case only its last entry is used (the newer configuration passes
            a shape here).
        grid_size: frame is divided into square grids of this size. A patch of
            size ``image_size`` centered on a grid is returned.
    """
    # hacky way to deal with image shape from new conf. TODO revisit!
    # A scalar is accepted too; indexing a scalar with [-1] used to fail.
    self._img_sz = image_size if np.ndim(image_size) == 0 else image_size[-1]
    self._grid_sz = grid_size

    patch_shape, grid_shape = self.get_idx_manager_shapes(
        self._img_sz, self._grid_sz
    )
    self.idx_manager = GridIndexManagerRef(
        self._data_shapes, grid_shape, patch_shape, self._tiling_mode
    )
|
|
285
|
+
|
|
286
|
+
def __len__(self):
    """Return the largest per-channel count from ``idx_manager.total_grid_count()[0]``.

    Channels may have different lengths; the longest one defines the dataset length.
    """
    grid_counts = self.idx_manager.total_grid_count()
    return max(grid_counts[0])
|
|
289
|
+
|
|
290
|
+
def _init_msg(
    self,
):
    """Build a one-line summary of the dataset configuration for logging.

    In the visible code it is only referenced from a commented-out block in
    ``__init__``.

    NOTE(review): this reads ``self.N`` and ``self._data.shape``. ``__init__`` does
    not set ``N``, and ``_data`` is indexed as a per-channel sequence elsewhere
    (``get_data_shapes`` uses ``len(self._data)``). This method therefore probably
    raises AttributeError if called. Confirm before re-enabling it.

    Returns:
        str: the summary message.
    """
    msg = (
        f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
    )
    dim_sizes = [
        self.idx_manager.get_individual_dim_grid_count(dim)
        for dim in range(len(self._data.shape))
    ]
    dim_sizes = ",".join([str(x) for x in dim_sizes])
    msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
    # NOTE(review): no leading space, so the grid count runs into NumPatchPerN.
    msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
    msg += f" TrimB:{self._tiling_mode}"
    # msg += f' NormInp:{self._normalized_input}'
    # msg += f' SingleNorm:{self._use_one_mu_std}'
    msg += f" Rot:{self._enable_rotation}"
    if self._flipz_3D:
        msg += f" FlipZ:{self._flipz_3D}"

    msg += f" RandCrop:{self._enable_random_cropping}"
    msg += f" Channel:{self._num_channels}"
    # msg += f' Q:{self._quantile}'
    if self._input_is_sum:
        msg += f" SummedInput:{self._input_is_sum}"

    if self._empty_patch_replacement_enabled:
        msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
    if self._uncorrelated_channels:
        msg += f" Uncorr:{self._uncorrelated_channels}"
    if self._empty_patch_replacement_enabled:
        msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
    if self._background_quantile > 0.0:
        msg += f" BckQ:{self._background_quantile}"

    if self._start_alpha_arr is not None:
        msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
    return msg
|
|
328
|
+
|
|
329
|
+
def _crop_imgs(self, ch_idx: int, patch_idx: int, img: np.ndarray):
    """Cut the patch for ``(ch_idx, patch_idx)`` out of ``img``.

    If ``self._img_sz`` is None, nothing is cropped and ``(img, info)`` is
    returned, where ``info`` spans the full height/width with no flips. Otherwise
    only the cropped float32 patch is returned. Callers must handle both return
    forms.

    With random cropping, the start comes from ``self._get_random_hw``, prefixed
    by a random z start for 3D data. Without it, the start comes from the index
    manager via ``self._get_deterministic_loc``.
    """
    height, width = img.shape[-2:]
    if self._img_sz is None:
        full_extent = {"h": [0, height], "w": [0, width], "hflip": False, "wflip": False}
        return (img, full_extent)

    if not self._enable_random_cropping:
        # Patch coordinates are calculated by the index manager.
        start = self._get_deterministic_loc(ch_idx, patch_idx)
    else:
        # NOTE: this flag toggles between random and deterministic patching.
        start = self._get_random_hw(height, width)
        if self._3Ddata:
            z_start = np.random.choice(1 + img.shape[-3] - self._depth3D)
            start = (z_start,) + start

    return self._crop_flip_img(img, start, False, False)
|
|
350
|
+
|
|
351
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
    """Slice the patch starting at ``patch_start_loc`` from the trailing axes of ``img``.

    In the TrimBoundary/ShiftBoundary tiling modes the patch is sliced directly.
    Every other mode delegates to ``self._crop_img_with_padding``, which handles
    patches that cross the image border.
    """
    if self._tiling_mode not in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
        # During evaluation the start can be negative, or the patch can run past the
        # frame, so padding is needed. This is not needed in the LeftTop alignment.
        return self._crop_img_with_padding(img, patch_start_loc)

    # Used in training. _crop_img_with_padding would also work here; plain slicing
    # just makes the training path easier to follow.
    # NOTE(review): the extent is idx_manager.patch_shape[1:-1]. For a 2D shape
    # (1, p, p) that is (p,), which broadcasts over both axes. Confirm the 3D case.
    patch_extent = self.idx_manager.patch_shape[1:-1]
    patch_end_loc = np.array(patch_start_loc, dtype=np.int32) + patch_extent
    if not self._3Ddata:
        h_start, w_start = patch_start_loc
        h_end, w_end = patch_end_loc
        return img[..., h_start:h_end, w_start:w_end]

    z_start, h_start, w_start = patch_start_loc
    z_end, h_end, w_end = patch_end_loc
    return img[..., z_start:z_end, h_start:h_end, w_start:w_end]
|
|
374
|
+
|
|
375
|
+
def get_begin_end_padding(self, start_pos, end_pos, max_len):
    """Return ``(pad_start, pad_end)`` for the range ``[start_pos, end_pos)`` on an axis of length ``max_len``.

    ``pad_start`` is how far ``start_pos`` lies below 0. ``pad_end`` is how far
    ``end_pos`` lies beyond ``max_len``. Both are 0 when the range is inside the
    axis. Padding by these amounts gives a patch of the full requested size even
    at the image border.
    """
    pad_start = -start_pos if start_pos < 0 else 0
    pad_end = max(0, end_pos - max_len)
    return pad_start, pad_end
|
|
388
|
+
|
|
389
|
+
def _crop_img_with_padding(
    self, img: np.ndarray, patch_start_loc, max_len_vals=None
):
    """Crop a patch that may cross the image border, padding the part outside it.

    The in-bounds part of the patch is sliced from the trailing spatial axes of
    ``img``: two axes, or three when ``self._3Ddata``. On each spatial axis where
    the patch starts below 0 or ends past ``max_len``, the missing extent is added
    with ``np.pad(..., **self._overlapping_padding_kwargs)``. Leading
    (non-spatial) axes are never padded.

    Args:
        img: image to crop.
        patch_start_loc: start index per spatial axis; may be negative.
        max_len_vals: extent of each spatial axis; defaults to
            ``self.idx_manager.data_shape[1:-1]``.

    Returns:
        np.ndarray: the cropped patch, padded where it crossed the border.
    """
    if max_len_vals is None:
        max_len_vals = self.idx_manager.data_shape[1:-1]
    patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
        self.idx_manager.patch_shape[1:-1], dtype=int
    )

    valid_slice = []
    spatial_padding = []
    for start_idx, end_idx, max_len in zip(
        patch_start_loc, patch_end_loc, max_len_vals
    ):
        # Clamp to the image: start_idx can be negative, end_idx can exceed max_len.
        valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
        if end_idx > max_len or start_idx < 0:
            pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
        else:
            pad = (0, 0)
        spatial_padding.append(list(pad))

    if self._3Ddata:
        new_img = img[
            ...,
            valid_slice[0][0] : valid_slice[0][1],
            valid_slice[1][0] : valid_slice[1][1],
            valid_slice[2][0] : valid_slice[2][1],
        ]
    else:
        new_img = img[
            ...,
            valid_slice[0][0] : valid_slice[0][1],
            valid_slice[1][0] : valid_slice[1][1],
        ]

    # Pad only when some axis needs it. ``np.all(list == 0)`` is always False,
    # so the old check padded unconditionally.
    if any(p != 0 for pair in spatial_padding for p in pair):
        # One [0, 0] per leading axis. Assuming exactly one broke np.pad for
        # images without a leading axis.
        leading_padding = [[0, 0]] * (new_img.ndim - len(spatial_padding))
        new_img = np.pad(
            new_img,
            leading_padding + spatial_padding,
            **self._overlapping_padding_kwargs,
        )

    return new_img
|
|
430
|
+
|
|
431
|
+
def _crop_flip_img(
    self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
):
    """Crop with ``self._crop_img``, optionally reverse the last two axes, and cast to float32.

    ``h_flip`` reverses the second-to-last axis and ``w_flip`` reverses the last axis.
    """
    patch = self._crop_img(img, patch_start_loc)
    if h_flip or w_flip:
        reverse = slice(None, None, -1)
        keep = slice(None)
        patch = patch[..., reverse if h_flip else keep, reverse if w_flip else keep]
    return patch.astype(np.float32)
|
|
441
|
+
|
|
442
|
+
def _load_img(self, ch_idx: int, patch_idx: int) -> np.ndarray:
    """Return the image of channel ``ch_idx`` that contains patch ``patch_idx``.

    The image is selected by the first entry of the location returned by
    ``self.idx_manager.get_patch_location_from_patch_idx``. No cropping happens
    here, and only this single array is returned. The previous annotation and
    docstring wrongly promised an (image, noise) tuple.
    """
    patch_loc = self.idx_manager.get_patch_location_from_patch_idx(ch_idx, patch_idx)
    # TODO we should be adding channel dim here probably
    return self._data[ch_idx][patch_loc[0]]
|
|
452
|
+
|
|
453
|
+
def get_mean_std(self):
    """Return ``(self._mean, self._std)``.

    NOTE: a later ``get_mean_std`` in this class (reading ``_data_mean`` /
    ``_data_std``) replaces this definition, so this version is never used.
    """
    stats = (self._mean, self._std)
    return stats
|
|
455
|
+
|
|
456
|
+
def set_mean_std(self, mean_val, std_val):
    """Store ``mean_val`` / ``std_val`` on ``self._mean`` / ``self._std``.

    NOTE: a later ``set_mean_std`` in this class (writing ``_data_mean`` /
    ``_data_std``) replaces this definition, so this version is never used.
    """
    self._mean, self._std = mean_val, std_val
|
|
459
|
+
|
|
460
|
+
def normalize_target(self, target):
    """Standardize ``target`` using the ``"target"`` entries of ``self.get_mean_std()``."""
    means, stds = self.get_mean_std()
    centered = target - means["target"]
    return centered / stds["target"]
|
|
465
|
+
|
|
466
|
+
def get_grid_size(self):
    """Return the grid size stored by ``set_img_sz`` (``self._grid_sz``)."""
    grid = self._grid_sz
    return grid
|
|
468
|
+
|
|
469
|
+
def get_idx_manager(self):
    """Return the index manager built by ``set_img_sz``."""
    manager = self.idx_manager
    return manager
|
|
471
|
+
|
|
472
|
+
def per_side_overlap_pixelcount(self):
    """Return ``(patch size - grid size) // 2``, the overlap on each side of a grid cell."""
    total_overlap = self._img_sz - self._grid_sz
    return total_overlap // 2
|
|
474
|
+
|
|
475
|
+
def _get_deterministic_loc(self, ch_idx: int, patch_idx: int):
    """Return the start coordinates of patch ``patch_idx`` in channel ``ch_idx``.

    This is the location from
    ``self.idx_manager.get_patch_location_from_patch_idx`` with its first two
    entries dropped.
    NOTE(review): an older comment here described the last entry as a channel
    axis. That does not match dropping the first two entries; confirm the
    location layout.
    """
    full_location = self.idx_manager.get_patch_location_from_patch_idx(
        ch_idx, patch_idx
    )
    return full_location[2:]
|
|
482
|
+
|
|
483
|
+
def crop_probablities(self, ch_idx):
    """Return per-image sampling probabilities for channel ``ch_idx``.

    Each image's probability is its element count divided by the channel's total,
    so larger images are drawn more often. The result is memoised per channel on
    the instance.
    """
    # A per-instance dict replaces functools.cache on the method. That cache kept
    # every instance alive for the cache's lifetime and required a hashable self.
    probab_cache = getattr(self, "_crop_probab_cache", None)
    if probab_cache is None:
        probab_cache = {}
        self._crop_probab_cache = probab_cache
    if ch_idx not in probab_cache:
        sizes = np.array([np.prod(x.shape) for x in self._data[ch_idx]])
        probab_cache[ch_idx] = sizes / sizes.sum()
    return probab_cache[ch_idx]
|
|
487
|
+
|
|
488
|
+
def sample_crop(self, ch_idx):
    """Return a random ``self._img_sz`` x ``self._img_sz`` crop from channel ``ch_idx``.

    An image is drawn with the probabilities from ``self.crop_probablities``.
    Images whose last two axes are smaller than ``self._img_sz`` are rejected and
    another image is drawn. If the drawn image has more than two axes and
    ``self._3Ddata`` is False, a random index along its first axis is picked too.

    Raises:
        ValueError: if no valid image is drawn within 101 attempts.
    """
    max_attempts = 101  # give up after this many rejected draws
    crop_sz = self._img_sz
    channel_imgs = self._data[ch_idx]
    probabilities = self.crop_probablities(ch_idx)
    for _ in range(max_attempts):
        idx = np.random.choice(len(channel_imgs), p=probabilities)
        data = channel_imgs[idx]  # TODO no channel and S dim ?
        # changed for ndim; TODO dims were hardcoded
        if not all(d >= crop_sz for d in data.shape[-2:]):
            continue
        # randint's upper bound is exclusive. The +1 makes the last valid start
        # reachable, and stops images exactly crop_sz wide from raising ValueError.
        h = np.random.randint(0, data.shape[-2] - crop_sz + 1)
        w = np.random.randint(0, data.shape[-1] - crop_sz + 1)
        if len(data.shape) > 2 and not self._3Ddata:
            # Every index along the first axis is a valid choice, including the last.
            s = np.random.randint(0, data.shape[0])
            return data[s, h : h + crop_sz, w : w + crop_sz]
        return data[h : h + crop_sz, w : w + crop_sz]

    raise ValueError("Cannot find a valid crop")
|
|
516
|
+
|
|
517
|
+
def _l2(self, x):
    """Return the root mean square of the values in ``x``."""
    values = np.asarray(x)
    mean_square = (values ** 2).mean()
    return np.sqrt(mean_square)
|
|
519
|
+
|
|
520
|
+
def compute_mean_std(self, allow_for_validation_data=False, num_samples: int = 30000):
    """Estimate normalization statistics from random crops.

    Note that we must compute this only for training data.

    Each of ``num_samples`` iterations draws one crop per channel with
    ``self.sample_crop``.
    - Per channel, the target mean is the mean of the crop means, and the target
      std is the root mean square (``self._l2``) of the crop stds.
    - The input statistics are computed the same way from the sum of the channel
      crops.

    Args:
        allow_for_validation_data: unused; kept for interface compatibility.
        num_samples: number of crop draws; previously hard-coded to 30000.

    Returns:
        tuple[dict, dict]: ``(mean_dict, std_dict)``. Key ``"target"`` has shape
        ``(C, 1, 1)`` and key ``"input"`` has shape ``(1, 1, 1)``.

    Raises:
        NotImplementedError: for 3D data or when the input is not the channel sum.
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if self._3Ddata:
        raise NotImplementedError("Not implemented for 3D data")
    if not self._input_is_sum:
        raise NotImplementedError("Not implemented for non-summed input")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    num_channels = len(self._data)
    mean_tar = [[] for _ in range(num_channels)]
    std_tar = [[] for _ in range(num_channels)]
    mean_inp = []
    std_inp = []
    for _ in range(num_samples):
        crops = []
        for ch_idx in range(num_channels):
            crop = self.sample_crop(ch_idx)
            mean_tar[ch_idx].append(np.mean(crop))
            std_tar[ch_idx].append(np.std(crop))
            crops.append(crop)
        # The network input is the plain sum of the channel crops.
        inp = sum(crops)
        mean_inp.append(np.mean(inp))
        std_inp.append(np.std(inp))

    output_mean = {
        "target": np.array([np.mean(m) for m in mean_tar]).reshape(num_channels, 1, 1),
        "input": np.array([np.mean(mean_inp)]).reshape(1, 1, 1),
    }
    output_std = {
        "target": np.array([self._l2(s) for s in std_tar]).reshape(num_channels, 1, 1),
        "input": np.array([self._l2(std_inp)]).reshape(1, 1, 1),
    }
    return output_mean, output_std
|
|
564
|
+
|
|
565
|
+
def set_mean_std(self, mean_dict, std_dict):
    """Store normalization statistics on ``self._data_mean`` / ``self._data_std``.

    ``compute_mean_std`` returns dicts of this form, keyed by ``"target"`` and
    ``"input"``.
    """
    self._data_mean, self._data_std = mean_dict, std_dict
|
|
568
|
+
|
|
569
|
+
def get_mean_std(self):
    """Return ``(self._data_mean, self._data_std)`` as stored by ``set_mean_std``."""
    stats = (self._data_mean, self._data_std)
    return stats
|
|
571
|
+
|
|
572
|
+
def _get_random_hw(self, h: int, w: int):
    """Return a uniformly random ``(h_start, w_start)`` for a ``self._img_sz`` crop.

    Every start in ``[0, h - self._img_sz]`` x ``[0, w - self._img_sz]`` can be
    returned. Each axis is handled independently, so an axis that exactly matches
    the patch size gets start 0.

    Raises:
        ValueError: if ``h`` or ``w`` is smaller than ``self._img_sz``.
    """
    # randint excludes ``high``, so +1 makes the last valid start reachable.
    # The previous np.random.choice(h - sz) missed it, and it failed when only
    # the width equalled the patch size.
    h_start = np.random.randint(0, h - self._img_sz + 1)
    w_start = np.random.randint(0, w - self._img_sz + 1)
    return h_start, w_start
|
|
583
|
+
|
|
584
|
+
def replace_with_empty_patch(self, img_tuples):
    """Replace one channel of ``img_tuples`` with the same channel of a background patch.

    An index is sampled from ``self._empty_patch_fetcher``, and the images at that
    index are loaded with ``self._get_img``. Entry
    ``self._empty_patch_replacement_channel_idx`` is taken from those images; all
    other entries come from ``img_tuples``.

    NOTE(review): ``_get_img`` is called here with one argument and its result is
    unpacked into two values (images, noise). The ``_get_img`` defined in this
    class takes ``(ch_idx, patch_idx)``. Confirm this path still works before
    enabling empty-patch replacement.

    Returns:
        tuple: the per-channel images with the replacement applied.
    """
    empty_index = self._empty_patch_fetcher.sample()
    empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
    assert (
        len(empty_img_noise_tuples) == 0
    ), "Noise is not supported with empty patch replacement"
    final_img_tuples = []
    for tuple_idx in range(len(img_tuples)):
        if tuple_idx == self._empty_patch_replacement_channel_idx:
            final_img_tuples.append(empty_img_tuples[tuple_idx])
        else:
            final_img_tuples.append(img_tuples[tuple_idx])
    return tuple(final_img_tuples)
|
|
600
|
+
|
|
601
|
+
def get_mean_std_for_input(self):
    """Return the ``"input"`` mean and std from ``self.get_mean_std()``."""
    mean_dict, std_dict = self.get_mean_std()
    input_mean = mean_dict["input"]
    input_std = std_dict["input"]
    return input_mean, input_std
|
|
604
|
+
|
|
605
|
+
def _compute_target(self, img_tuples, alpha):
    """Build the target from the per-channel images. ``alpha`` is not used.

    The target depends on ``self._tar_idx_list``:
    - int: that single image is returned unchanged;
    - list/tuple: the selected images are stacked on a new first axis;
    - None: all images are stacked on a new first axis.
    """
    tar_idx = self._tar_idx_list
    if isinstance(tar_idx, int):
        return img_tuples[tar_idx]

    selected = img_tuples
    if tar_idx is not None:
        assert isinstance(tar_idx, (list, tuple))
        selected = [img_tuples[i] for i in tar_idx]
    return np.stack(selected, axis=0)
|
|
617
|
+
|
|
618
|
+
def _compute_input_with_alpha(self, img_tuples, alpha_list):
    """Build the float32 network input from the channel images.

    If ``self._input_idx`` is set, that image is used directly. Otherwise the
    input is the ``alpha_list``-weighted sum of the images. When
    ``self._normalized_input`` is enabled, the result is standardized with the
    input mean/std, which must hold a single (possibly repeated) value.
    """
    # assert self._normalized_input is True, "normalization should happen here"
    if self._input_idx is None:
        inp = 0
        for weight, channel_img in zip(alpha_list, img_tuples):
            inp += channel_img * weight
    else:
        inp = img_tuples[self._input_idx]

    if self._normalized_input is False:
        return inp.astype(np.float32)

    mean, std = self.get_mean_std_for_input()
    mean = mean.squeeze()
    std = std.squeeze()
    if mean.size == 1:
        mean = mean.reshape(1)
        std = std.reshape(1)

    # Only a single normalization value is supported; every entry must agree.
    for idx in range(len(mean)):
        assert mean[idx] == mean[0] and std[idx] == std[0]

    normalized = (inp - mean[0]) / std[0]
    return normalized.astype(np.float32)
|
|
647
|
+
|
|
648
|
+
def _sample_alpha(self):
    """Draw one mixing weight per channel.

    Each weight is interpolated between the channel's ``start_alpha`` and
    ``end_alpha`` using an independent ``np.random.rand()`` draw.
    """
    sampled = []
    for ch in range(self._num_channels):
        u = np.random.rand()
        lo = self._start_alpha_arr[ch]
        hi = self._end_alpha_arr[ch]
        sampled.append(lo + u * (hi - lo))
    return sampled
|
|
657
|
+
|
|
658
|
+
def _compute_input(self, img_tuples):
    """Mix the channel images into the network input.

    The mixing weights are uniform (``1/N`` each), unless an alpha range is
    configured, in which case they are sampled with ``self._sample_alpha``. With
    ``self._input_is_sum`` the mixture is scaled by ``N``. A leading channel axis
    is added to 2D inputs (and to 3-axis inputs when ``self._3Ddata``).

    Returns:
        tuple: ``(inp, alpha)``, the input array and the weights used.
    """
    n_imgs = len(img_tuples)
    if self._start_alpha_arr is None:
        alpha = [1 / n_imgs for _ in range(n_imgs)]
    else:
        alpha = self._sample_alpha()

    inp = self._compute_input_with_alpha(img_tuples, alpha)
    if self._input_is_sum:
        inp = n_imgs * inp

    # TODO instead we add channel here
    needs_channel_axis = len(inp.shape) == 2 or (len(inp.shape) == 3 and self._3Ddata)
    if needs_channel_axis:
        inp = inp[None, ...]

    return inp, alpha
|
|
672
|
+
|
|
673
|
+
def _get_index_from_valid_target_logic(self, index):
    """Optionally swap ``index`` for a valid- or invalid-target training index.

    If ``_validtarget_rand_fract`` is set, then with that probability the index is
    replaced by ``self._train_index_switcher.get_valid_target_index()``, and
    otherwise by ``get_invalid_target_index()``. If it is unset or None, ``index``
    is returned unchanged.
    """
    # __init__ does not assign _validtarget_rand_fract, so default to None.
    # The index then passes through instead of raising AttributeError.
    rand_fract = getattr(self, "_validtarget_rand_fract", None)
    if rand_fract is not None:
        if np.random.rand() < rand_fract:
            index = self._train_index_switcher.get_valid_target_index()
        else:
            index = self._train_index_switcher.get_invalid_target_index()
    return index
|
|
680
|
+
|
|
681
|
+
def _rotate2D(self, img_tuples):
    """Apply one random flip / 90-degree rotation to every 2D slice of ``img_tuples``.

    Every slice ``img[k]`` (indexed along the first axis of each element) is
    passed to ``self._rotation_transform`` as an extra "image" target. This way
    albumentations applies the same sampled transform to all of them. The
    transformed slices are then stacked back along a new first axis.

    Returns:
        list: one array per element of ``img_tuples``, with ``len(img)`` slices each.
    """
    img_kwargs = {}
    # Key each slice as "img{element}_{slice}" so it can be reassembled afterwards.
    for i, img in enumerate(img_tuples):
        for k in range(len(img)):
            img_kwargs[f"img{i}_{k}"] = img[k]

    keys = list(img_kwargs.keys())
    # Register every slice as an additional "image" target of the Compose.
    self._rotation_transform.add_targets({k: "image" for k in keys})
    rot_dic = self._rotation_transform(image=img_tuples[0][0], **img_kwargs)

    rotated_img_tuples = []
    for i, img in enumerate(img_tuples):
        if len(img) == 1:
            rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
        else:
            rotated_img_tuples.append(
                np.concatenate(
                    [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
                )
            )

    return rotated_img_tuples
|
|
703
|
+
|
|
704
|
+
def _rotate3D(self, img_tuples):
    """Rotate every z-plane of every 3D crop with one shared random transform.

    When ``self._flipz_3D`` is truthy, a coin toss decides whether the stack
    is also mirrored along z. The planes ``img[k, j]`` for
    ``j < self._depth3D`` are registered as ``"image"`` targets of
    ``self._rotation_transform``, so a single call rotates them all
    identically.

    Returns:
        A list with one array per input crop. Planes are stacked back along
        axis 1 (z) and channels along axis 0.
    """
    depth = self._depth3D
    # The random draw only happens when z-flipping is enabled.
    flip_z = self._flipz_3D and np.random.rand() < 0.5

    planes = {}
    for i, img in enumerate(img_tuples):
        for j in range(depth):
            for k in range(len(img)):
                z = depth - 1 - j if flip_z else j
                planes[f"img{i}_{z}_{k}"] = img[k, j]

    transform = self._rotation_transform
    transform.add_targets(dict.fromkeys(planes, "image"))
    rotated = transform(image=img_tuples[0][0][0], **planes)

    result = []
    for i, img in enumerate(img_tuples):
        per_channel = [
            np.concatenate(
                [rotated[f"img{i}_{z}_{k}"][None, None] for z in range(depth)],
                axis=1,
            )
            for k in range(len(img))
        ]
        result.append(np.concatenate(per_channel, axis=0))
    return result
|
|
747
|
+
|
|
748
|
+
def _rotate(self, img_tuples, noise_tuples=None):
    """Apply one random rotation to the image crops and, optionally, the noise crops.

    ``_rotate2D`` / ``_rotate3D`` take a single list of crops and rotate all
    of them with the same random transform. Noise crops are therefore
    appended to the image crops and rotated in the same call, so they stay
    aligned with the images, then split off again.

    Args:
        img_tuples: Sequence of image crops.
        noise_tuples: Optional sequence of noise crops. When ``None`` (e.g.
            ``MultiChDloaderRef.__getitem__``) only the images are rotated.

    Returns:
        The list of rotated image crops when ``noise_tuples`` is ``None``.
        Otherwise a ``(rotated_img_tuples, rotated_noise_tuples)`` pair, as
        expected by ``LCMultiChDloaderRef.__getitem__``.
    """
    rotate_fn = self._rotate3D if self._3Ddata else self._rotate2D
    if noise_tuples is None:
        return rotate_fn(img_tuples)

    n_imgs = len(img_tuples)
    rotated = rotate_fn(list(img_tuples) + list(noise_tuples))
    return rotated[:n_imgs], rotated[n_imgs:]
|
|
754
|
+
|
|
755
|
+
def _get_img(self, ch_idx: int, patch_idx: int):
    """
    Load the image of channel ``ch_idx`` for patch ``patch_idx`` and crop it
    (via ``_crop_imgs``) so that the cropped image has content.
    """
    loaded = self._load_img(ch_idx, patch_idx)
    return self._crop_imgs(ch_idx, patch_idx, loaded)
|
|
763
|
+
|
|
764
|
+
def get_uncorrelated_img_tuples(self, index):
    """
    Content of channels like actin and nuclei is "correlated" in its
    respective location. This picks each channel's content from a different
    patch so that the channels become "uncorrelated".

    Channel 0 is loaded at ``index`` (dataset index and sample index coincide
    because all channels have the same length). Every other channel is loaded
    at a random index drawn from that channel's grid count.
    """
    img_tuples = []
    for ch_idx in range(len(self._data)):
        if ch_idx == 0:
            sample_index = index
        else:
            n_patches = self.idx_manager.total_grid_count()[0][ch_idx]
            sample_index = np.random.randint(n_patches)
        img_tuples.append(self._get_img(ch_idx, sample_index))
    return img_tuples
|
|
783
|
+
|
|
784
|
+
def __getitem__(
    self, index: Union[int, tuple[int, int]]
) -> tuple[np.ndarray, np.ndarray]:
    """Return the ``(input, normalized_target)`` pair for ``index``.

    With probability ``self._uncorrelated_channel_probab`` (only when
    ``self._uncorrelated_channels`` is set), the crops are taken from
    different spatial locations via ``get_uncorrelated_img_tuples``.
    Otherwise a single crop is loaded at ``index`` for channel 0, as all
    channels share the same location.
    """
    use_uncorrelated = (
        self._uncorrelated_channels
        and np.random.rand() < self._uncorrelated_channel_probab
    )
    if use_uncorrelated:
        input_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        # Wrapped in a tuple for compatibility with _compute_input. #TODO check
        input_tuples = (self._get_img(0, index),)

    if self._enable_rotation:
        input_tuples = self._rotate(input_tuples)

    # Weight the individual channels; alpha is typically fixed.
    inp, alpha = self._compute_input(input_tuples)
    target = self._compute_target(input_tuples, alpha)
    return inp, self.normalize_target(target)
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
class LCMultiChDloaderRef(MultiChDloaderRef):
|
|
814
|
+
def __init__(
    self,
    data_config: DatasetConfig,
    fpath: str,
    load_data_fn: Callable,
    val_fraction=None,
    test_fraction=None,
):
    """Create the lateral-context (LC) variant of the multi-channel dataset.

    On top of the parent initialisation, this builds ``self._scaled_data`` and
    ``self._scaled_noise_data``. Entry ``i`` of each list is the data (resp.
    noise) downsampled ``i`` times by 2 along both spatial axes, for ``i`` in
    ``0 .. multiscale_lowres_count - 1``.

    Args:
        data_config: Dataset configuration. ``padding_kwargs``,
            ``overlapping_padding_kwargs``, ``uncorrelated_channel_probab``
            and ``multiscale_lowres_count`` are read here; the config is
            also forwarded to the parent.
        fpath: Data path, forwarded to the parent.
        load_data_fn: Data loading callable, forwarded to the parent.
        val_fraction: Forwarded to the parent.
        test_fraction: Forwarded to the parent.
    """
    # Padding for crops extending past the border, e.g.
    # mode=padding_mode, constant_values=constant_value. Set before the parent
    # __init__ runs (presumably because the parent may use it; TODO confirm).
    self._padding_kwargs = data_config.padding_kwargs
    self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab

    super().__init__(
        data_config,
        fpath,
        load_data_fn=load_data_fn,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
    )

    if data_config.overlapping_padding_kwargs is not None:
        assert self._padding_kwargs == data_config.overlapping_padding_kwargs, (
            "During evaluation, overlapping_padding_kwargs should be same as "
            "padding_args. It should be so since we just use "
            "overlapping_padding_kwargs when it is not None"
        )
    else:
        self._overlapping_padding_kwargs = data_config.padding_kwargs

    self.multiscale_lowres_count = data_config.multiscale_lowres_count
    assert self.multiscale_lowres_count is not None
    self._scaled_data = [self._data]
    self._scaled_noise_data = [self._noise_data]

    assert (
        isinstance(self.multiscale_lowres_count, int)
        and self.multiscale_lowres_count >= 1
    )
    assert isinstance(self._padding_kwargs, dict)
    assert "mode" in self._padding_kwargs

    for _ in range(1, self.multiscale_lowres_count):
        shape = self._scaled_data[-1].shape
        assert len(shape) == 4
        new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
        # NOTE: resize expects np.float32, otherwise one gets weird results;
        # the result is cast back to the original dtype.
        ds_data = resize(
            self._scaled_data[-1].astype(np.float32), new_shape
        ).astype(self._scaled_data[-1].dtype)
        # Sanity checks that downsampling did not distort the intensity range.
        assert (
            ds_data.max() / self._scaled_data[-1].max() < 5
        ), "Downsampled image should not have very different values"
        assert (
            ds_data.max() / self._scaled_data[-1].max() > 0.2
        ), "Downsampled image should not have very different values"

        self._scaled_data.append(ds_data)

        # Do the same for the noise. Its target shape must come from the
        # noise array itself. Its channel axis (last) is read independently
        # of the data's in `_load_scaled_img`, and `__getitem__` uses noise
        # channel 0 for the input and channels 1.. for the targets, so it
        # need not match the data's channel count.
        if self._noise_data is not None:
            noise_shape = self._scaled_noise_data[-1].shape
            new_noise_shape = (
                noise_shape[0],
                noise_shape[1] // 2,
                noise_shape[2] // 2,
                noise_shape[3],
            )
            noise_data = resize(self._scaled_noise_data[-1], new_noise_shape)
            self._scaled_noise_data.append(noise_data)
|
|
876
|
+
|
|
877
|
+
def reduce_data(
    self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
    """Keep only the frames listed in ``t_list``, at every resolution level.

    Spatial cropping is not supported here: ``h_start``, ``h_end``,
    ``w_start`` and ``w_end`` must all be ``None``. The noise arrays, when
    present, are reduced the same way. Afterwards ``self.N`` is updated and
    the image/grid size is re-applied.
    """
    assert t_list is not None
    for bound in (h_start, h_end, w_start, w_end):
        assert bound is None

    self._data = self._data[t_list].copy()
    self._scaled_data = [level[t_list].copy() for level in self._scaled_data]

    if self._noise_data is not None:
        self._noise_data = self._noise_data[t_list].copy()
        self._scaled_noise_data = [
            level[t_list].copy() for level in self._scaled_noise_data
        ]

    self.N = len(t_list)
    # NOTE(review): `_img_sz` / `_grid_sz` are not assigned in this class;
    # presumably they are set by the parent class — confirm.
    self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
    print(
        f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
    )
|
|
904
|
+
|
|
905
|
+
def _init_msg(self):
    """Return the parent's init message extended with LC-specific settings."""
    pieces = [super()._init_msg(), f" Pad:{self._padding_kwargs}"]
    if self._uncorrelated_channels:
        pieces.append(f" UncorrChProbab:{self._uncorrelated_channel_probab}")
    return "".join(pieces)
|
|
911
|
+
|
|
912
|
+
def _load_scaled_img(
    self, scaled_index, index: Union[int, tuple[int, int]]
) -> tuple[np.ndarray, ...]:
    """Load the full frame for ``index`` at resolution level ``scaled_index``.

    Args:
        scaled_index: Position in ``self._scaled_data``. Level 0 is the
            original data; each further level is downsampled 2x (see
            ``__init__``).
        index: Dataset index, or an ``(index, grid_size)`` pair whose second
            element is ignored here.

    Returns:
        A tuple with one ``(1, ...)`` array per channel, split along the last
        axis of the stored data. When synthetic noise is present, noise
        channel 0 is added to every channel, scaled by ``sqrt(2)`` when
        ``self._input_is_sum``.
    """
    if isinstance(index, int):
        idx = index
    else:
        idx, _ = index

    patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
    # NOTE(review): the first entry is used as the frame index — presumably
    # the sample/time index; confirm against the index manager.
    nidx = patch_loc_list[0]

    frame = self._scaled_data[scaled_index][nidx]
    imgs = tuple(frame[None, ..., ch] for ch in range(frame.shape[-1]))
    if self._noise_data is not None:
        noise_frame = self._scaled_noise_data[scaled_index][nidx]
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        # Only the first noise channel is used at the lower-resolution levels.
        input_noise = noise_frame[None, ..., 0] * factor
        imgs = tuple(img + input_noise for img in imgs)
    return imgs
|
|
932
|
+
|
|
933
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
    """
    Here, h_start, w_start could be negative. That simply means we need to pick the content from 0. So,
    the cropped image will be smaller than self._img_sz * self._img_sz

    The crop limits are the dataset's ``idx_manager.data_shape`` without its
    first and last entries, with the final two entries replaced by
    ``img``'s own last two dimensions. The actual cropping is delegated to
    ``_crop_img_with_padding``.
    """
    limits = list(self.idx_manager.data_shape[1:-1])
    # Lower-resolution images can be smaller than the dataset shape, so the
    # spatial limits come from the image actually being cropped.
    limits[-2:] = img.shape[-2:]
    return self._crop_img_with_padding(img, patch_start_loc, max_len_vals=limits)
|
|
943
|
+
|
|
944
|
+
def _get_img(self, index: int):
    """Return the primary patch along with low-resolution patches centred on it.

    The highest-resolution crop is taken at a random location when
    ``self._enable_random_cropping`` is set, otherwise at
    ``self._get_deterministic_loc(index)``. For every lateral-context level
    ``1 .. multiscale_lowres_count - 1``, a crop of the same size is taken
    from the level's image (``_load_scaled_img``). Each such crop is centred
    on the primary crop's centre, whose coordinates are halved per level.

    Args:
        index: Dataset index. It is forwarded unchanged to ``_load_img`` and
            ``_load_scaled_img``; the latter also accepts an
            ``(index, grid_size)`` pair.

    Returns:
        ``(output_img_tuples, cropped_noise_tuples)``:

        - ``output_img_tuples``: one array per channel, with all levels
          concatenated along the first axis.
        - ``cropped_noise_tuples``: the list of highest-resolution noise crops.
    """
    # noise_tuples is populated when there is synthetic noise in training; it
    # should resemble the noise model's noise. (From microsplit on, the plan
    # is to drop it and use noise as an augmentation where necessary.)
    img_tuples, noise_tuples = self._load_img(index)
    assert self._img_sz is not None
    h, w = img_tuples[0].shape[-2:]
    if self._enable_random_cropping:
        patch_start_loc = self._get_random_hw(h, w)
        if self._3Ddata:
            # Valid z starts are 0 .. (Z - depth3D) inclusive, i.e.
            # Z - depth3D + 1 choices. This keeps the last start reachable and
            # supports stacks that are exactly `_depth3D` slices deep.
            n_z_starts = img_tuples[0].shape[-3] - self._depth3D + 1
            patch_start_loc = (np.random.choice(n_z_starts),) + patch_start_loc
    else:
        patch_start_loc = self._get_deterministic_loc(index)

    # LC logic: first crop the highest-resolution images (flip flags False).
    cropped_img_tuples = [
        self._crop_flip_img(img, patch_start_loc, False, False) for img in img_tuples
    ]
    cropped_noise_tuples = [
        self._crop_flip_img(noise, patch_start_loc, False, False)
        for noise in noise_tuples
    ]

    patch_start_loc = list(patch_start_loc)
    h_start, w_start = patch_start_loc[-2], patch_start_loc[-1]
    h_center = h_start + self._img_sz // 2
    w_center = w_start + self._img_sz // 2
    allres_versions = {
        ch_idx: [crop] for ch_idx, crop in enumerate(cropped_img_tuples)
    }
    for scale_idx in range(1, self.multiscale_lowres_count):
        # Full frame at the lower resolution.
        scaled_img_tuples = self._load_scaled_img(scale_idx, index)

        h_center = h_center // 2
        w_center = w_center // 2

        # The start may be negative near the border (see `_crop_img`).
        h_start = h_center - self._img_sz // 2
        w_start = w_center - self._img_sz // 2
        patch_start_loc[-2:] = [h_start, w_start]
        scaled_cropped_img_tuples = [
            self._crop_flip_img(img, patch_start_loc, False, False)
            for img in scaled_img_tuples
        ]
        for ch_idx in range(len(img_tuples)):
            allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])

    output_img_tuples = tuple(
        np.concatenate(allres_versions[ch_idx]) for ch_idx in range(len(img_tuples))
    )
    return output_img_tuples, cropped_noise_tuples
|
|
1003
|
+
|
|
1004
|
+
def __getitem__(self, index: Union[int, tuple[int, int]]):
    """Build one lateral-context training sample.

    Returns:
        A tuple ``(input, normalized_target)``. ``alpha`` is appended when
        ``self._return_alpha`` is set, and the grid size is appended when
        ``index`` is an ``(index, grid_size)`` pair rather than an ``int``.
    """
    img_tuples, noise_tuples = self._get_img(index)

    if self._uncorrelated_channels:
        assert (
            self._input_idx is None
        ), "Uncorrelated channels is not implemented when there is a separate input channel."
        if np.random.rand() < self._uncorrelated_channel_probab:
            # Keep channel 0 of this sample; take each other channel from a
            # randomly chosen sample.
            mixed = [img_tuples[0]]
            for ch_idx in range(1, len(img_tuples)):
                other_imgs, _ = self._get_img(np.random.randint(len(self)))
                mixed.append(other_imgs[ch_idx])
            img_tuples = mixed

    if self._is_train:
        if (
            self._empty_patch_replacement_enabled
            and np.random.rand() < self._empty_patch_replacement_probab
        ):
            img_tuples = self.replace_with_empty_patch(img_tuples)

        if self._enable_rotation:
            img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    # Add the noise (if any) to the input. The factor keeps a summed input
    # from getting too much noise.
    if len(noise_tuples) > 0:
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = []
        for x in img_tuples:
            # Copy so the original stays clean: it is used for the target below.
            noisy = x.copy()
            # NOTE: other LC levels already have noise added, so only the
            # highest resolution (index 0) gets it here.
            noisy[0] = noisy[0] + noise_tuples[0] * factor
            input_tuples.append(noisy)
    else:
        input_tuples = img_tuples

    # Combine (sum / average) the channels. Alpha is the weight applied to
    # each channel; how to sample it is still under research.
    inp, alpha = self._compute_input(input_tuples)

    # Targets use only the highest-resolution level of each channel, plus the
    # per-channel noise when present.
    target_tuples = [img[:1] for img in img_tuples]
    if len(noise_tuples) >= 1:
        target_tuples = [t + n for t, n in zip(target_tuples, noise_tuples[1:])]

    target = self._compute_target(target_tuples, alpha)
    norm_target = self.normalize_target(target)

    output = [inp, norm_target]
    if self._return_alpha:
        output.append(alpha)

    if not isinstance(index, int):
        _, grid_size = index
        output.append(grid_size)
    return tuple(output)
|