careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1220 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A place for Datasets and Dataloaders.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Tuple, Union
|
|
7
|
+
|
|
8
|
+
# import albumentations as A
|
|
9
|
+
import ml_collections
|
|
10
|
+
import numpy as np
|
|
11
|
+
from skimage.transform import resize
|
|
12
|
+
|
|
13
|
+
from .data_utils import (
|
|
14
|
+
DataSplitType,
|
|
15
|
+
DataType,
|
|
16
|
+
GridAlignement,
|
|
17
|
+
GridIndexManager,
|
|
18
|
+
IndexSwitcher,
|
|
19
|
+
get_datasplit_tuples,
|
|
20
|
+
get_mrc_data,
|
|
21
|
+
load_tiff,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_train_val_data(
    data_config,
    fpath,
    datasplit_type: DataSplitType,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
    ignore_specific_datapoints=None,
):
    """
    Load the data from the given path and return the requested split.

    The loaded array has shape N*H*W*C: N is the number of frames, H,W are the
    image dimensions and C is the number of channels.

    Args:
        data_config: Configuration; ``data_type`` selects the loader.
            ``DataType.SeparateTiffData`` reads ``ch1_fname`` / ``ch2_fname``
            (and, if present, ``ch_input_fname`` as the first channel);
            ``DataType.BioSR_MRC`` reads ``ch{i}_fname`` for
            ``i = 1..num_channels`` (``num_channels`` defaults to 2).
        fpath: Directory containing the data files.
        datasplit_type: Split to return; ``DataSplitType.All`` returns every frame.
        val_fraction: Forwarded to ``get_datasplit_tuples``.
        test_fraction: Forwarded to ``get_datasplit_tuples``.
        allow_generation: Unused; kept for interface compatibility.
        ignore_specific_datapoints: Unused; kept for interface compatibility.

    Returns:
        The frames of the requested split as a float32 array.

    Raises:
        ValueError: If ``data_config.data_type`` or ``datasplit_type`` is not
            supported (previously this silently returned None).
    """
    if data_config.data_type == DataType.SeparateTiffData:
        data = _load_separate_tiff_data(data_config, fpath, datasplit_type)
    elif data_config.data_type == DataType.BioSR_MRC:
        data = _load_biosr_mrc_data(data_config, fpath, datasplit_type)
    else:
        raise ValueError(f"Unsupported data type: {data_config.data_type}")
    return _select_split(data, datasplit_type, val_fraction, test_fraction)


def _load_separate_tiff_data(data_config, fpath, datasplit_type):
    """Load one TIFF file per channel and stack them along a new last axis."""
    fpath1 = os.path.join(fpath, data_config.ch1_fname)
    fpath2 = os.path.join(fpath, data_config.ch2_fname)
    fpaths = [fpath1, fpath2]
    fpath0 = ""
    if "ch_input_fname" in data_config:
        # Optional explicit input channel, placed before the target channels.
        fpath0 = os.path.join(fpath, data_config.ch_input_fname)
        fpaths = [fpath0] + fpaths

    print(
        f"Loading from {fpath} Channels: "
        f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
    )
    # NOTE: synthetic noise is deliberately not added here: the noise in the
    # input must be independent of the noise in the targets.
    return np.concatenate([load_tiff(p)[..., None] for p in fpaths], axis=3)


def _load_biosr_mrc_data(data_config, fpath, datasplit_type):
    """Load one MRC file per channel and stack them, cropped to the shortest stack."""
    num_channels = data_config.get("num_channels", 2)
    fpaths = [
        os.path.join(fpath, data_config.get(f"ch{i + 1}_fname"))
        for i in range(num_channels)
    ]
    data_list = [get_mrc_data(p)[..., None] for p in fpaths]

    dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"
    msg = ",".join([x[len(dirname) :] for x in fpaths])
    print(
        f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{DataSplitType.name(datasplit_type)}"
    )
    # Channels may contain different numbers of frames; keep the common prefix.
    n_frames = min(d.shape[0] for d in data_list)
    return np.concatenate([d[:n_frames] for d in data_list], axis=3)


def _select_split(data, datasplit_type, val_fraction, test_fraction):
    """Return the frames of ``data`` belonging to ``datasplit_type`` as float32."""
    if datasplit_type == DataSplitType.All:
        return data.astype(np.float32)

    train_idx, val_idx, test_idx = get_datasplit_tuples(
        val_fraction, test_fraction, len(data), starting_test=True
    )
    if datasplit_type == DataSplitType.Train:
        idx = train_idx
    elif datasplit_type == DataSplitType.Val:
        idx = val_idx
    elif datasplit_type == DataSplitType.Test:
        idx = test_idx
    else:
        raise ValueError(f"Unsupported datasplit_type: {datasplit_type}")
    return data[idx].astype(np.float32)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class MultiChDloader:
|
|
125
|
+
|
|
126
|
+
def __init__(
    self,
    data_config: ml_collections.ConfigDict,
    fpath: str,
    datasplit_type: DataSplitType = None,
    val_fraction: float = None,
    test_fraction: float = None,
    normalized_input=None,
    enable_rotation_aug: bool = False,
    enable_random_cropping: bool = False,
    use_one_mu_std=None,
    allow_generation: bool = False,
    max_val: float = None,
    grid_alignment=GridAlignement.LeftTop,
    overlapping_padding_kwargs=None,
    print_vars: bool = True,
):
    """
    Load a multi-channel dataset and prepare it for grid-based patch sampling.

    Each frame is split into grid cells of size ``grid_size`` (defaults to
    ``image_size``); one patch of size ``image_size`` is associated with each cell.

    Args:
        data_config: Configuration. Must provide ``data_type`` and ``image_size``;
            optional keys read here are ``grid_size``, ``input_is_sum``,
            ``num_channels``, ``input_idx``, ``target_idx_list``,
            ``uncorrelated_channels`` and ``multiscale_lowres_count``.
        fpath: Directory containing the data files.
        datasplit_type: Which split to load.
        val_fraction: Fraction of frames used for validation.
        test_fraction: Fraction of frames used for testing.
        normalized_input: Stored on the instance as-is.
        enable_rotation_aug: Rotation augmentation; not supported yet, so True
            raises ``NotImplementedError``.
        enable_random_cropping: Crop patches at random positions instead of on
            the deterministic grid.
        use_one_mu_std: Stored on the instance; intended to select one shared
            mean/stdev for all channels instead of per-channel statistics.
        allow_generation: Forwarded to ``load_data``.
        max_val: Upper clipping value(s). When None it is computed from the data,
            which is only allowed for the training split.
        grid_alignment: How patches are aligned to the grid cells.
        overlapping_padding_kwargs: Keyword arguments for ``np.pad``; required for
            ``GridAlignement.Center``.
        print_vars: Print a one-line summary after initialisation.
    """
    self._data_type = data_config.data_type
    self._fpath = fpath
    self._data = self.N = self._noise_data = None
    # Assigned a real value by ``load_data`` only when Poisson noise is enabled;
    # initialised here so ``disable_noise`` can always read it.
    self._poisson_noise_factor = None

    # Hardcoded params, not included in the config file.
    # By default, if synthetic noise is present, it is added to input and target.
    self._disable_noise = False
    self._train_index_switcher = None
    # NOTE: Input is the sum of the different channels, not their average.
    self._input_is_sum = data_config.get("input_is_sum", False)
    self._num_channels = data_config.get("num_channels", 2)
    self._input_idx = data_config.get("input_idx", None)
    self._tar_idx_list = data_config.get("target_idx_list", None)

    # Every split currently uses all of its data.
    self._datausage_fraction = 1.0
    if datasplit_type == DataSplitType.Train:
        self._validtarget_rand_fract = None

    self.load_data(
        data_config,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )
    self._normalized_input = normalized_input
    self._quantile = 1.0
    self._channelwise_quantile = False
    self._background_quantile = 0.0
    self._clip_background_noise_to_zero = False
    self._skip_normalization_using_mean = False
    self._empty_patch_replacement_enabled = False

    self._background_values = None

    self._grid_alignment = grid_alignment
    self._overlapping_padding_kwargs = overlapping_padding_kwargs
    if self._grid_alignment == GridAlignement.LeftTop:
        # .get(): the key is optional, and attribute access on a missing key raises.
        assert (
            self._overlapping_padding_kwargs is None
            or data_config.get("multiscale_lowres_count", None) is not None
        ), "Padding is not used with this alignement style"
    elif self._grid_alignment == GridAlignement.Center:
        assert (
            self._overlapping_padding_kwargs is not None
        ), "With Center grid alignment, padding is needed."

    self._is_train = datasplit_type == DataSplitType.Train

    # input = alpha * ch1 + (1-alpha)*ch2, alpha sampled between two extremes.
    # Alpha mixing is currently disabled.
    self._start_alpha_arr = None
    self._end_alpha_arr = None
    self._alpha_weighted_target = False if self._is_train else None

    self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
    self.set_img_sz(
        data_config.image_size,
        (
            data_config.grid_size
            if "grid_size" in data_config
            else data_config.image_size
        ),
    )

    self._return_alpha = False
    self._return_index = False

    self.rm_bkground_set_max_val_and_upperclip_data(max_val, datasplit_type)

    self._mean = None
    self._std = None
    self._use_one_mu_std = use_one_mu_std
    # Hardcoded
    self._target_separate_normalization = True

    self._enable_rotation = enable_rotation_aug
    self._enable_random_cropping = enable_random_cropping
    self._uncorrelated_channels = (
        data_config.get("uncorrelated_channels", False) and self._is_train
    )
    # Uncorrelated channels only make sense for training with random crops.
    assert self._is_train or not self._uncorrelated_channels
    assert self._enable_random_cropping is True or not self._uncorrelated_channels

    self._rotation_transform = None
    if self._enable_rotation:
        raise NotImplementedError(
            "Augmentation by means of rotation is not supported yet."
        )

    if print_vars:
        print(self._init_msg())
|
|
292
|
+
|
|
293
|
+
def disable_noise(self):
    """Stop returning the pre-sampled Gaussian noise from ``_load_img``.

    Poisson noise is applied to ``self._data`` itself at load time and cannot be
    removed, so disabling is refused when Poisson noise was used.

    Raises:
        AssertionError: If Poisson noise was applied to the data.
    """
    # getattr: the attribute may never have been assigned on this instance.
    assert (
        getattr(self, "_poisson_noise_factor", None) is None
    ), "This is not supported. Poisson noise is added to the data itself and so the noise cannot be disabled."
    self._disable_noise = True
|
|
298
|
+
|
|
299
|
+
def enable_noise(self):
    """Resume returning the synthetic noise channels from ``_load_img``."""
    setattr(self, "_disable_noise", False)
|
|
301
|
+
|
|
302
|
+
def get_data_shape(self):
    """Return the shape of the loaded data array."""
    loaded = self._data
    return loaded.shape
|
|
304
|
+
|
|
305
|
+
def load_data(
    self,
    data_config,
    datasplit_type,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
):
    """Load one data split into ``self._data`` and optionally add synthetic noise.

    Steps:
      1. Load the split with ``get_train_val_data``.
      2. If ``self._datausage_fraction < 1``, keep only about that fraction of the
         frames (a single remaining frame is cropped spatially instead).
      3. If ``poisson_noise_factor > 0`` in the config, replace the data by a
         Poisson-noisy version. This cannot be undone later.
      4. If ``enable_gaussian_noise`` is set, pre-sample Gaussian noise into
         ``self._noise_data`` with one extra channel. Channel 0 is the input noise
         and channels 1: are the target noise.
      5. Set ``self.N`` to the number of frames.

    Raises:
        AssertionError: If the number of channels in the data differs from
            ``self._num_channels``.
    """
    self._data = get_train_val_data(
        data_config,
        self._fpath,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )

    old_shape = self._data.shape
    if self._datausage_fraction < 1.0:
        framepixelcount = np.prod(self._data.shape[1:3])
        pixelcount = int(
            len(self._data) * framepixelcount * self._datausage_fraction
        )
        frame_count = int(np.ceil(pixelcount / framepixelcount))
        last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size(
            self._data.shape[:3], self._datausage_fraction
        )
        self._data = self._data[:frame_count].copy()
        if frame_count == 1:
            self._data = self._data[
                :, :last_frame_reduced_size, :last_frame_reduced_size
            ].copy()
        print(
            f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}"
        )

    msg = ""
    # Always defined so that ``disable_noise`` can inspect it.
    self._poisson_noise_factor = None
    poisson_factor = data_config.get("poisson_noise_factor", None)
    if poisson_factor is not None and poisson_factor > 0:
        self._poisson_noise_factor = poisson_factor
        msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
        # np.random.poisson yields integers; cast back to the float32 dtype that
        # get_train_val_data returns so later clipping/normalisation is unchanged.
        self._data = (
            np.random.poisson(self._data / self._poisson_noise_factor)
            * self._poisson_noise_factor
        ).astype(np.float32)

    if data_config.get("enable_gaussian_noise", False):
        synthetic_scale = data_config.get("synthetic_gaussian_scale", 0.1)
        msg += f"Adding Gaussian noise with scale {synthetic_scale}"
        # 0 => noise for input. 1: => noise for all targets.
        shape = self._data.shape
        self._noise_data = np.random.normal(
            0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
        )
        if data_config.get("input_has_dependant_noise", False):
            msg += ". Moreover, input has dependent noise"
            self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
    if msg:
        print(msg)

    self.N = len(self._data)
    assert (
        self._data.shape[-1] == self._num_channels
    ), "Number of channels in data and config do not match."
|
|
367
|
+
|
|
368
|
+
def save_background(self, channel_idx, frame_idx, background_value):
    """Record the background value subtracted from one frame/channel."""
    store = self._background_values
    store[frame_idx, channel_idx] = background_value
|
|
370
|
+
|
|
371
|
+
def get_background(self, channel_idx, frame_idx):
    """Return the background value recorded for ``frame_idx`` / ``channel_idx``."""
    store = self._background_values
    return store[frame_idx, channel_idx]
|
|
373
|
+
|
|
374
|
+
def remove_background(self):
    """Subtract a per-frame, per-channel background quantile from the data.

    ``self._background_values`` is reset to zeros of shape (frames, channels). It
    is then filled with the integer background subtracted from each frame/channel.
    Nothing is subtracted when ``self._background_quantile`` is 0.

    Raises:
        AssertionError: If clipping is requested without background removal, or a
            background value is too small (|q| <= 20) to be safely truncated to int.
    """
    n_frames, n_channels = self._data.shape[0], self._data.shape[-1]
    self._background_values = np.zeros((n_frames, n_channels))

    if self._background_quantile == 0.0:
        assert (
            self._clip_background_noise_to_zero is False
        ), "This operation currently happens later in this function."
        return

    if np.issubdtype(self._data.dtype, np.unsignedinteger):
        # Unsigned integers wrap around on subtraction; switch to a signed type
        # wide enough for the values (int32 for <=16-bit, int64 otherwise).
        signed = np.int32 if self._data.dtype.itemsize <= 2 else np.int64
        self._data = self._data.astype(signed)

    for ch in range(n_channels):
        for idx in range(n_frames):
            qval = np.quantile(self._data[idx, ..., ch], self._background_quantile)
            assert (
                np.abs(qval) > 20
            ), "We are truncating the qval to an integer which will only make sense if it is large enough"
            # NOTE: this can be an issue for already-normalized data.
            qval = int(qval)
            self.save_background(ch, idx, qval)
            self._data[idx, ..., ch] -= qval

    if self._clip_background_noise_to_zero:
        self._data[self._data < 0] = 0
|
|
401
|
+
|
|
402
|
+
def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
    """Remove the background, set the clipping value, then clip the data from above.

    The order matters: when ``max_val`` is computed from the data, it is
    computed on the background-corrected values.
    """
    steps = (
        self.remove_background,
        lambda: self.set_max_val(max_val, datasplit_type),
        self.upperclip_data,
    )
    for step in steps:
        step()
|
|
406
|
+
|
|
407
|
+
def upperclip_data(self):
    """Clip ``self._data`` in place from above at ``self.max_val``.

    ``self.max_val`` is either a scalar applied to all values, or a list holding
    one limit per channel (last axis).
    """
    limits = self.max_val
    if not isinstance(limits, list):
        self._data[self._data > limits] = limits
        return

    n_channels = self._data.shape[-1]
    assert n_channels == len(limits)
    for channel, limit in enumerate(limits):
        # Basic indexing gives a view, so assigning through it updates self._data.
        channel_view = self._data[..., channel]
        channel_view[channel_view > limit] = limit
|
|
418
|
+
|
|
419
|
+
def compute_max_val(self):
    """Return the ``self._quantile`` quantile of the data.

    Returns a list with one value per channel (last axis) when
    ``self._channelwise_quantile`` is set, otherwise a single scalar.
    """
    q = self._quantile
    if not self._channelwise_quantile:
        return np.quantile(self._data, q)

    per_channel = []
    for channel in range(self._data.shape[-1]):
        per_channel.append(np.quantile(self._data[..., channel], q))
    return per_channel
|
|
428
|
+
|
|
429
|
+
def set_max_val(self, max_val, datasplit_type):
    """Set ``self.max_val``.

    An explicit ``max_val`` is used as-is. When it is None, the value is computed
    from the data, which is only allowed for the training split.
    """
    if max_val is not None:
        self.max_val = max_val
        return
    assert datasplit_type == DataSplitType.Train
    self.max_val = self.compute_max_val()
|
|
437
|
+
|
|
438
|
+
def get_max_val(self):
    """Return the upper clipping value(s) applied to the data."""
    return getattr(self, "max_val")
|
|
440
|
+
|
|
441
|
+
def get_img_sz(self):
    """Return the patch size."""
    size = self._img_sz
    return size
|
|
443
|
+
|
|
444
|
+
def reduce_data(
    self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
    """Keep a subset of frames and a spatial crop of the data.

    Args:
        t_list: Frame indices to keep (default: all frames).
        h_start: First row to keep (default 0).
        h_end: End row, exclusive (default: full height).
        w_start: First column to keep (default 0).
        w_end: End column, exclusive (default: full width).

    The noise array, if present, is cropped identically. ``self.N`` is updated
    and the grid index manager is rebuilt via ``set_img_sz``.
    """
    n_frames, height, width = self._data.shape[:3]
    t_list = list(range(n_frames)) if t_list is None else t_list
    h_start = 0 if h_start is None else h_start
    h_end = height if h_end is None else h_end
    w_start = 0 if w_start is None else w_start
    w_end = width if w_end is None else w_end

    crop = (t_list, slice(h_start, h_end), slice(w_start, w_end), slice(None))
    self._data = self._data[crop].copy()
    if self._noise_data is not None:
        self._noise_data = self._noise_data[crop].copy()

    self.N = len(t_list)
    self.set_img_sz(self._img_sz, self._grid_sz)
    print(
        f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
    )
|
|
469
|
+
|
|
470
|
+
def set_img_sz(self, image_size, grid_size):
    """
    Change the patch size and grid size, e.g. on the fly.

    Args:
        image_size: Side length of one patch.
        grid_size: Each frame is divided into square grid cells of this size; a
            patch of size ``image_size`` is associated with each cell.

    Rebuilds ``self.idx_manager`` and recomputes ``self._repeat_factor``.
    """
    self._img_sz, self._grid_sz = image_size, grid_size
    manager = GridIndexManager(
        self._data.shape, self._grid_sz, self._img_sz, self._grid_alignment
    )
    self.idx_manager = manager
    self.set_repeat_factor()
|
|
483
|
+
|
|
484
|
+
def set_repeat_factor(self):
    """Set ``self._repeat_factor`` to the number of patches per frame.

    That is grid rows times grid columns. The step is the grid size when it is
    larger than 1, and the patch size otherwise.
    """
    step = self._grid_sz if self._grid_sz > 1 else self._img_sz
    rows = self.idx_manager.grid_rows(step)
    cols = self.idx_manager.grid_cols(step)
    self._repeat_factor = rows * cols
|
|
493
|
+
|
|
494
|
+
def _init_msg(
    self,
):
    """Build a one-line summary of the dataset configuration."""
    parts = [
        f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}",
        f" N:{self.N} NumPatchPerN:{self._repeat_factor}",
        f" Rot:{self._enable_rotation}",
        f" RandCrop:{self._enable_random_cropping}",
        f" Channel:{self._num_channels}",
    ]
    if self._input_is_sum:
        parts.append(f" SummedInput:{self._input_is_sum}")
    if self._empty_patch_replacement_enabled:
        parts.append(
            f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
        )
    if self._uncorrelated_channels:
        parts.append(f" Uncorr:{self._uncorrelated_channels}")
    if self._empty_patch_replacement_enabled:
        parts.append(
            f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
        )
    if self._background_quantile > 0.0:
        parts.append(f" BckQ:{self._background_quantile}")
    if self._start_alpha_arr is not None:
        parts.append(f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]")
    return "".join(parts)
|
|
522
|
+
|
|
523
|
+
def _crop_imgs(self, index, *img_tuples: np.ndarray):
    """Crop the same patch from every image in ``img_tuples``.

    The patch is placed at a random position when random cropping is enabled,
    and at the deterministic position for ``index`` otherwise.

    Returns:
        The cropped images, followed by a dict with the ``h``/``w`` ranges and
        ``hflip``/``wflip`` flags (always False). When ``self._img_sz`` is None,
        the images are returned unchanged together with their full extent.
    """
    height, width = img_tuples[0].shape[-2:]
    size = self._img_sz
    if size is None:
        info = {"h": [0, height], "w": [0, width], "hflip": False, "wflip": False}
        return (*img_tuples, info)

    h_start, w_start = (
        self._get_random_hw(height, width)
        if self._enable_random_cropping
        else self._get_deterministic_hw(index)
    )
    cropped = tuple(
        self._crop_flip_img(img, h_start, w_start, False, False) for img in img_tuples
    )
    info = {
        "h": [h_start, h_start + size],
        "w": [w_start, w_start + size],
        "hflip": False,
        "wflip": False,
    }
    return (*cropped, info)
|
|
550
|
+
|
|
551
|
+
def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
    """Crop an ``_img_sz`` x ``_img_sz`` patch with top-left corner (h_start, w_start).

    ``GridAlignement.LeftTop`` (used in training) only needs a plain slice.
    With ``GridAlignement.Center`` (used in evaluation), the corner can be
    negative or the patch can overhang the frame, so the padded variant is used.
    Any other alignment yields None.
    """
    alignment = self._grid_alignment
    if alignment == GridAlignement.LeftTop:
        size = self._img_sz
        return img[..., h_start : h_start + size, w_start : w_start + size]
    if alignment == GridAlignement.Center:
        return self._crop_img_with_padding(img, h_start, w_start)
    return None
|
|
564
|
+
|
|
565
|
+
def get_begin_end_padding(self, start_pos, max_len):
    """
    Return ``(pad_start, pad_end)`` along one axis for a patch of size
    ``self._img_sz`` that starts at ``start_pos`` in an axis of length ``max_len``.

    ``pad_start`` compensates a negative start and ``pad_end`` an overhang past
    ``max_len``. As a result, the grid cell sits in the middle of a full-size patch.
    """
    pad_start = max(0, -start_pos)
    pad_end = max(0, start_pos + self._img_sz - max_len)
    return pad_start, pad_end
|
|
578
|
+
|
|
579
|
+
def _crop_img_with_padding(self, img: np.ndarray, h_start: int, w_start: int):
    """Crop a patch that may extend past the frame, padding the missing part.

    ``img`` must be 3-D (leading axis, H, W). Regions outside the frame are filled
    by ``np.pad`` with ``self._overlapping_padding_kwargs``.
    """
    _, H, W = img.shape
    size = self._img_sz
    h_outside = self.on_boundary(h_start, H)
    w_outside = self.on_boundary(w_start, W)

    assert h_start < H
    assert w_start < W
    assert h_start + size <= H or h_outside
    assert w_start + size <= W or w_outside

    # Negative starts are clamped to 0 here; np.pad restores the missing rows/cols.
    patch = img[
        ...,
        max(0, h_start) : h_start + size,
        max(0, w_start) : w_start + size,
    ]

    padding = np.zeros((3, 2), dtype=int)
    if h_outside:
        padding[1] = self.get_begin_end_padding(h_start, H)
    if w_outside:
        padding[2] = self.get_begin_end_padding(w_start, W)

    if padding.any():
        patch = np.pad(patch, padding, **self._overlapping_padding_kwargs)
    return patch
|
|
608
|
+
|
|
609
|
+
def _crop_flip_img(
    self, img: np.ndarray, h_start: int, w_start: int, h_flip: bool, w_flip: bool
):
    """Crop a patch via ``self._crop_img``, optionally flip it, and return float32.

    ``h_flip`` reverses the second-to-last axis; ``w_flip`` reverses the last axis.
    """
    patch = self._crop_img(img, h_start, w_start)
    for do_flip, axis in ((h_flip, -2), (w_flip, -1)):
        if do_flip:
            patch = np.flip(patch, axis=axis)
    return patch.astype(np.float32)
|
|
619
|
+
|
|
620
|
+
def __len__(self):
    """Total number of patches: frames times patches per frame."""
    frames, per_frame = self.N, self._repeat_factor
    return frames * per_frame
|
|
622
|
+
|
|
623
|
+
def _load_img(
    self, index: Union[int, Tuple[int, int]]
) -> Tuple[Tuple[np.ndarray, ...], Tuple[np.ndarray, ...]]:
    """
    Return the channels of the frame for ``index`` and the matching noise channels.

    Args:
        index: A patch index (any Python or NumPy integer) or a tuple whose first
            element is the patch index.

    Returns:
        ``(channels, noise)``. Each element of ``channels`` has shape (1, H, W).
        ``noise`` holds the pre-sampled noise channels in the same layout, or is
        empty when there is no noise or noise is disabled.
    """
    # np.integer covers every NumPy integer width (int32, int64, ...).
    idx = index if isinstance(index, (int, np.integer)) else index[0]
    t = self.idx_manager.get_t(idx)

    imgs = self._data[t]
    loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
    noise = []
    if self._noise_data is not None and not self._disable_noise:
        frame_noise = self._noise_data[t]
        noise = [frame_noise[None, ..., i] for i in range(frame_noise.shape[-1])]
    return tuple(loaded_imgs), tuple(noise)
|
|
643
|
+
|
|
644
|
+
def get_mean_std(self):
    """Return the stored normalization statistics as ``(mean, std)``."""
    stats = (self._mean, self._std)
    return stats
|
|
646
|
+
|
|
647
|
+
def set_mean_std(self, mean_val, std_val):
    """Store ``mean_val`` / ``std_val`` as the dataset's normalization statistics."""
    self._mean, self._std = mean_val, std_val
|
|
650
|
+
|
|
651
|
+
def normalize_img(self, *img_tuples):
    """Normalize each image with the per-channel *target* statistics.

    The i-th positional image is shifted by entry i of the squeezed
    ``mean["target"]`` and divided by entry i of the squeezed
    ``std["target"]``.

    Returns:
        A tuple with one normalized image per input image.
    """
    mean_all, std_all = self.get_mean_std()
    tar_mean = mean_all["target"].squeeze()
    tar_std = std_all["target"].squeeze()
    return tuple(
        (img - tar_mean[ch]) / tar_std[ch] for ch, img in enumerate(img_tuples)
    )
|
|
662
|
+
|
|
663
|
+
def get_grid_size(self):
    """Return ``self._grid_sz``, the default grid size for deterministic patches."""
    grid_size = self._grid_sz
    return grid_size
|
|
665
|
+
|
|
666
|
+
def get_idx_manager(self):
    """Return the index manager used to resolve frames and patch corners."""
    manager = self.idx_manager
    return manager
|
|
668
|
+
|
|
669
|
+
def per_side_overlap_pixelcount(self):
    """Pixels by which a patch extends past its grid cell on each side.

    Computed as ``(img_sz - grid_sz) // 2``.
    """
    total_overlap = self._img_sz - self._grid_sz
    return total_overlap // 2
|
|
671
|
+
|
|
672
|
+
def on_boundary(self, cur_loc, frame_size):
    """Tell whether a patch of size ``self._img_sz`` at ``cur_loc`` leaves the frame.

    True when the start is negative or the patch end exceeds ``frame_size``.
    """
    return cur_loc < 0 or cur_loc + self._img_sz > frame_size
|
|
674
|
+
|
|
675
|
+
def _get_deterministic_hw(self, index: Union[int, Tuple[int, int]]):
    """
    It returns the top-left corner of the patch corresponding to index.

    Args:
        index: a sample index (Python int or any numpy integer scalar), or an
            ``(index, grid_size)`` tuple whose grid size overrides
            ``self._grid_sz``.

    Returns:
        ``(h_start, w_start)`` from ``self.idx_manager.get_deterministic_hw``.
        With ``GridAlignement.Center`` both are shifted by
        ``self.per_side_overlap_pixelcount()``, so they can become negative.

    Raises:
        ValueError: if ``self._grid_alignment`` is not a supported alignment.
    """
    # Accept every numpy integer dtype, not only np.int64.
    if isinstance(index, (int, np.integer)):
        idx, grid_size = index, self._grid_sz
    else:
        idx, grid_size = index

    h_start, w_start = self.idx_manager.get_deterministic_hw(
        idx, grid_size=grid_size
    )
    if self._grid_alignment == GridAlignement.LeftTop:
        return h_start, w_start
    if self._grid_alignment == GridAlignement.Center:
        pad = self.per_side_overlap_pixelcount()
        return h_start - pad, w_start - pad
    # Previously this fell through and returned None, which only failed later
    # when the caller tried to unpack the result.
    raise ValueError(f"Unsupported grid alignment: {self._grid_alignment}")
|
|
693
|
+
|
|
694
|
+
def compute_individual_mean_std(self):
    """Compute per-channel mean and std of ``self._data`` (channel = last axis).

    When ``self._skip_normalization_using_mean`` is set every mean is 0.0.
    When noise data is present, the std of channel ``c`` is taken over
    ``data[..., c] + noise[..., c + 1]`` (noise channel 0 is not used here).

    Returns:
        ``(mean, std)`` arrays of shape ``(1, C, 1, 1)``.
    """
    # Statistics are computed channel by channel instead of one np.mean/np.std
    # call over all axes: numpy 1.19.2 has issues in computing for large arrays.
    # https://github.com/numpy/numpy/issues/8869
    n_channels = self._data.shape[-1]

    def _channel_stats(c):
        channel = self._data[..., c]
        mu = 0.0 if self._skip_normalization_using_mean else channel.mean()
        if self._noise_data is None:
            sigma = channel.std()
        else:
            sigma = (channel + self._noise_data[..., c + 1]).std()
        return mu, sigma

    stats = [_channel_stats(c) for c in range(n_channels)]
    mean = np.array([mu for mu, _ in stats])
    std = np.array([sigma for _, sigma in stats])
    return mean[None, :, None, None], std[None, :, None, None]
|
|
720
|
+
|
|
721
|
+
def compute_mean_std(self, allow_for_validation_data=False):
    """
    Note that we must compute this only for training data.

    Compute the normalization statistics of the dataset.

    Args:
        allow_for_validation_data: bypass the training-data-only assertion.

    Returns:
        ``(mean_dict, std_dict)`` with ``"input"`` and ``"target"`` entries.
        If ``self._input_idx`` is set, both come from
        ``compute_individual_mean_std`` (input channel / target channels).
        Otherwise ``"input"`` is one statistic repeated over the channel axis.
        ``"target"`` comes from ``compute_individual_mean_std`` when
        ``self._target_separate_normalization`` is truthy; otherwise it is the
        same as ``"input"``.

    Raises:
        AssertionError: outside training data (unless allowed), when
            ``self._use_one_mu_std`` is not True, or for unsupported option
            combinations (noise data with ``input_idx`` / ``input_is_sum``).
    """
    assert (
        self._is_train is True or allow_for_validation_data
    ), "This is just allowed for training data"
    assert self._use_one_mu_std is True, "This is the only supported case"

    # An explicit input channel exists in the data: input and target
    # statistics are taken per channel from the channels they come from.
    if self._input_idx is not None:
        assert (
            self._tar_idx_list is not None
        ), "tar_idx_list must be set if input_idx is set."
        assert self._noise_data is None, "This is not supported with noise"
        assert (
            self._target_separate_normalization is True
        ), "This is not supported with target_separate_normalization=False"

        mean, std = self.compute_individual_mean_std()
        mean_dict = {
            "input": mean[:, self._input_idx : self._input_idx + 1],
            "target": mean[:, self._tar_idx_list],
        }
        std_dict = {
            "input": std[:, self._input_idx : self._input_idx + 1],
            "target": std[:, self._tar_idx_list],
        }
        return mean_dict, std_dict

    if self._input_is_sum:
        assert self._noise_data is None, "This is not supported with noise"
        # Input is the sum of the channels: mean = sum of channel means,
        # std = L2 norm of the channel stds (NOTE(review): this formula treats
        # channels as independent).
        mean = [
            np.mean(self._data[..., k : k + 1], keepdims=True)
            for k in range(self._num_channels)
        ]
        mean = np.sum(mean, keepdims=True)[0]
        std = np.linalg.norm(
            [
                np.std(self._data[..., k : k + 1], keepdims=True)
                for k in range(self._num_channels)
            ],
            keepdims=True,
        )[0]
    else:
        # A single global mean/std over all channels, shaped (1, 1, 1, 1).
        mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1)
        if self._noise_data is not None:
            # Std is measured on data plus noise channels 1: (target noise).
            std = np.std(
                self._data + self._noise_data[..., 1:], keepdims=True
            ).reshape(1, 1, 1, 1)
        else:
            std = np.std(self._data, keepdims=True).reshape(1, 1, 1, 1)

    # Broadcast the single statistic to every channel along axis 1.
    mean = np.repeat(mean, self._num_channels, axis=1)
    std = np.repeat(std, self._num_channels, axis=1)

    if self._skip_normalization_using_mean:
        mean = np.zeros_like(mean)

    mean_dict = {"input": mean}  # , 'target':mean}
    std_dict = {"input": std}  # , 'target':std}

    # Without separate target normalization the target reuses the input stats.
    if self._target_separate_normalization:
        mean, std = self.compute_individual_mean_std()

    mean_dict["target"] = mean
    std_dict["target"] = std
    return mean_dict, std_dict
|
|
788
|
+
|
|
789
|
+
def _get_random_hw(self, h: int, w: int):
    """
    Random starting position for a ``self._img_sz`` x ``self._img_sz`` crop.

    Height and width starts are sampled independently and uniformly from all
    valid positions ``[0, h - img_sz]`` and ``[0, w - img_sz]`` (inclusive).
    An axis whose extent equals the crop size always starts at 0.

    Raises:
        ValueError: if the frame is smaller than the crop along an axis.
    """

    def _sample_start(extent: int) -> int:
        if extent == self._img_sz:
            return 0
        if extent < self._img_sz:
            raise ValueError(
                f"Frame extent {extent} is smaller than the patch size {self._img_sz}."
            )
        # randint's upper bound is exclusive, so +1 keeps the last valid start
        # (np.random.choice(extent - img_sz) used to skip it). Each axis is
        # sampled on its own; previously w was not sampled when h == img_sz,
        # and w == img_sz with h != img_sz crashed in np.random.choice(0).
        return np.random.randint(0, extent - self._img_sz + 1)

    h_start = _sample_start(h)
    w_start = _sample_start(w)
    return h_start, w_start
|
|
800
|
+
|
|
801
|
+
def _get_img(self, index: Union[int, Tuple[int, int]]):
    """Load the sample at ``index`` and crop images and noise with one crop.

    All image channels and noise channels are passed together to
    ``self._crop_imgs`` so they share the same crop. The last element
    ``_crop_imgs`` returns is dropped.

    Returns:
        ``(cropped_images, cropped_noise)``, split at the number of image
        channels returned by ``self._load_img``.
    """
    img_tuples, noise_tuples = self._load_img(index)
    n_img = len(img_tuples)
    cropped_all = self._crop_imgs(index, *img_tuples, *noise_tuples)[:-1]
    return cropped_all[:n_img], cropped_all[n_img:]
|
|
811
|
+
|
|
812
|
+
def replace_with_empty_patch(self, img_tuples):
    """Replace one channel of ``img_tuples`` with that channel of an "empty" patch.

    An index is drawn from ``self._empty_patch_fetcher.sample()`` and loaded
    with ``self._get_img``. Channel ``self._empty_patch_replacement_channel_idx``
    is taken from that patch; all other channels are kept from ``img_tuples``.

    Returns:
        A tuple with as many entries as ``img_tuples``.
    """
    empty_index = self._empty_patch_fetcher.sample()
    # _get_img returns (image_tuples, noise_tuples). Indexing that pair
    # directly (as before) picked the noise tuple instead of a channel image.
    empty_img_tuples, _ = self._get_img(empty_index)
    replace_idx = self._empty_patch_replacement_channel_idx
    return tuple(
        empty_img_tuples[i] if i == replace_idx else img
        for i, img in enumerate(img_tuples)
    )
|
|
822
|
+
|
|
823
|
+
def get_mean_std_for_input(self):
    """Return the ``"input"`` entries of the stored mean and std dicts."""
    mean_dict, std_dict = self.get_mean_std()
    input_mean, input_std = mean_dict["input"], std_dict["input"]
    return input_mean, input_std
|
|
826
|
+
|
|
827
|
+
def _compute_target(self, img_tuples, alpha):
    """Build the target array from the per-channel images.

    Args:
        img_tuples: one image per channel.
        alpha: per-channel weights, indexed like ``img_tuples``.

    Returns:
        ``img_tuples[tar_idx_list]`` when ``self._tar_idx_list`` is an int.
        Otherwise the selected channels (all of them when ``tar_idx_list`` is
        None), each multiplied by its own alpha when
        ``self._alpha_weighted_target`` is set, concatenated along axis 0.
    """
    tar_idx = self._tar_idx_list
    if isinstance(tar_idx, int):
        return img_tuples[tar_idx]

    if tar_idx is None:
        channel_ids = list(range(len(img_tuples)))
    else:
        assert isinstance(tar_idx, list) or isinstance(tar_idx, tuple)
        channel_ids = list(tar_idx)

    if self._alpha_weighted_target:
        assert self._input_is_sum is False
        # Weight each channel with the alpha of *that* channel. Previously alpha
        # was indexed by the position in the filtered list, so the wrong weight
        # was used whenever tar_idx_list selected or reordered channels.
        parts = [img_tuples[c] * alpha[c] for c in channel_ids]
    else:
        parts = [img_tuples[c] for c in channel_ids]
    return np.concatenate(parts, axis=0)
|
|
846
|
+
|
|
847
|
+
def _compute_input_with_alpha(self, img_tuples, alpha_list):
    """Form the network input from the channel images.

    If ``self._input_idx`` is set, that channel is used directly; otherwise the
    input is the alpha-weighted sum of the channels. Unless
    ``self._normalized_input`` is exactly ``False``, the input is normalized
    with the input mean/std, which must hold the same value in every entry.

    Returns:
        The input as a ``np.float32`` array.
    """
    if self._input_idx is None:
        inp = 0
        for weight, channel_img in zip(alpha_list, img_tuples):
            inp += channel_img * weight
    else:
        inp = img_tuples[self._input_idx]

    if self._normalized_input is False:
        return inp.astype(np.float32)

    mean, std = self.get_mean_std_for_input()
    # A fully squeezed statistic is 0-d; atleast_1d turns it into shape (1,).
    mean = np.atleast_1d(mean.squeeze())
    std = np.atleast_1d(std.squeeze())
    # The input is a single image, so all entries must share one statistic.
    for i, mean_i in enumerate(mean):
        assert mean[0] == mean_i
        assert std[0] == std[i]

    normalized = (inp - mean[0]) / std[0]
    return normalized.astype(np.float32)
|
|
876
|
+
|
|
877
|
+
def _sample_alpha(self):
    """Draw one mixing weight per channel.

    Channel ``c`` gets ``start[c] + u * (end[c] - start[c])`` with
    ``u = np.random.rand()``, drawn in channel order.

    Returns:
        A list of ``self._num_channels`` weights.
    """
    starts, ends = self._start_alpha_arr, self._end_alpha_arr
    return [
        starts[c] + np.random.rand() * (ends[c] - starts[c])
        for c in range(self._num_channels)
    ]
|
|
886
|
+
|
|
887
|
+
def _compute_input(self, img_tuples):
    """Mix the channel images into the network input.

    Weights are sampled with ``self._sample_alpha`` when start alphas are
    configured; otherwise every channel gets ``1 / n_channels``. When
    ``self._input_is_sum`` is set, the mixed input is multiplied by the
    channel count.

    Returns:
        ``(inp, alpha)``.
    """
    n_channels = len(img_tuples)
    if self._start_alpha_arr is None:
        alpha = [1 / n_channels for _ in range(n_channels)]
    else:
        alpha = self._sample_alpha()

    inp = self._compute_input_with_alpha(img_tuples, alpha)
    if self._input_is_sum:
        inp = n_channels * inp
    return inp, alpha
|
|
896
|
+
|
|
897
|
+
def _get_index_from_valid_target_logic(self, index):
    """Possibly swap ``index`` for one from the train-index switcher.

    When ``self._validtarget_rand_fract`` is None, ``index`` is returned
    unchanged. Otherwise a valid-target index is returned with that
    probability and an invalid-target index otherwise.
    """
    fract = self._validtarget_rand_fract
    if fract is None:
        return index
    switcher = self._train_index_switcher
    if np.random.rand() < fract:
        return switcher.get_valid_target_index()
    return switcher.get_invalid_target_index()
|
|
904
|
+
|
|
905
|
+
def _rotate(self, img_tuples, noise_tuples):
    """Apply the rotation augmentation; this loader only implements the 2D variant."""
    rotated = self._rotate2D(img_tuples, noise_tuples)
    return rotated
|
|
907
|
+
|
|
908
|
+
def _rotate2D(self, img_tuples, noise_tuples):
    """Pass every 2D plane of the images and noise through one rotation call.

    Plane ``k`` of image ``i`` is registered on ``self._rotation_transform`` as
    target ``img{i}_{k}`` (noise planes as ``noise{i}_{k}``), all with type
    ``"image"``. All planes go through a single transform call, presumably so
    they receive the same random rotation (confirm against the transform
    library). The outputs are then restacked along axis 0.

    Returns:
        ``(rotated_img_tuples, rotated_noise_tuples)`` as lists.
    """

    def _flatten(prefix, arrays):
        return {
            f"{prefix}{i}_{k}": arr[k]
            for i, arr in enumerate(arrays)
            for k in range(len(arr))
        }

    img_kwargs = _flatten("img", img_tuples)
    noise_kwargs = _flatten("noise", noise_tuples)
    targets = {name: "image" for name in [*img_kwargs, *noise_kwargs]}
    self._rotation_transform.add_targets(targets)
    rot_dic = self._rotation_transform(
        image=img_tuples[0][0], **img_kwargs, **noise_kwargs
    )

    def _restack(prefix, arrays):
        stacked = []
        for i, arr in enumerate(arrays):
            if len(arr) == 1:
                stacked.append(rot_dic[f"{prefix}{i}_0"][None])
            else:
                planes = [rot_dic[f"{prefix}{i}_{k}"][None] for k in range(len(arr))]
                stacked.append(np.concatenate(planes, axis=0))
        return stacked

    rotated_img_tuples = _restack("img", img_tuples)
    rotated_noise_tuples = _restack("noise", noise_tuples)
    return rotated_img_tuples, rotated_noise_tuples
|
|
948
|
+
|
|
949
|
+
def get_uncorrelated_img_tuples(self, index):
    """Build a sample whose channels come from different, random samples.

    Channel 0 comes from ``index``. Every other channel ``c`` comes from
    channel ``c`` of an independently drawn sample index in
    ``[0, len(self))``. Noise is not supported.

    Returns:
        ``(img_tuples, noise_tuples)`` with ``img_tuples`` as a list.
    """
    img_tuples, noise_tuples = self._get_img(index)
    assert len(noise_tuples) == 0
    # Count the channels *before* rebuilding the list. The old code looped over
    # the rebuilt one-element list, so no channel was ever replaced.
    n_channels = len(img_tuples)
    uncorrelated = [img_tuples[0]]
    for ch_idx in range(1, n_channels):
        new_index = np.random.randint(len(self))
        other_img_tuples, _ = self._get_img(new_index)
        uncorrelated.append(other_img_tuples[ch_idx])
    return uncorrelated, noise_tuples
|
|
958
|
+
|
|
959
|
+
def __getitem__(
    self, index: Union[int, Tuple[int, int]]
) -> Tuple[np.ndarray, np.ndarray]:
    """Return one sample.

    Args:
        index: a sample index (Python int or any numpy integer scalar) or an
            ``(index, grid_size)`` tuple.

    Returns:
        ``(inp, target[, alpha][, index][, grid_size])``. ``alpha`` is included
        when ``self._return_alpha`` is set, ``index`` when
        ``self._return_index`` is set, and ``grid_size`` whenever ``index`` is a
        tuple.
    """
    if self._train_index_switcher is not None:
        index = self._get_index_from_valid_target_logic(index)

    if self._uncorrelated_channels:
        img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        img_tuples, noise_tuples = self._get_img(index)

    # NOTE(review): this rejects `_empty_patch_replacement_enabled == True`
    # whether or not noise is present (despite the message), so the replacement
    # branch below does not run for that value.
    assert (
        self._empty_patch_replacement_enabled != True
    ), "This is not supported with noise"

    if self._empty_patch_replacement_enabled:
        if np.random.rand() < self._empty_patch_replacement_probab:
            img_tuples = self.replace_with_empty_patch(img_tuples)

    if self._enable_rotation:
        img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    # add noise to input: noise channel 0, scaled by sqrt(2) for summed inputs.
    if len(noise_tuples) > 0:
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
    else:
        input_tuples = img_tuples
    inp, alpha = self._compute_input(input_tuples)

    # add noise to target: channel i gets noise channel i + 1.
    if len(noise_tuples) >= 1:
        img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]

    target = self._compute_target(img_tuples, alpha)

    output = [inp, target]
    if self._return_alpha:
        output.append(alpha)
    if self._return_index:
        output.append(index)

    # Any integer scalar counts as a plain index. Previously only np.int64 was
    # recognized, and e.g. np.int32 crashed when unpacked as a tuple below.
    if isinstance(index, (int, np.integer)):
        return tuple(output)

    _, grid_size = index
    output.append(grid_size)
    return tuple(output)
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
class LCMultiChDloader(MultiChDloader):
|
|
1012
|
+
|
|
1013
|
+
def __init__(
    self,
    data_config,
    fpath: str,
    datasplit_type: DataSplitType = None,
    val_fraction=None,
    test_fraction=None,
    normalized_input=None,
    enable_rotation_aug: bool = False,
    use_one_mu_std=None,
    num_scales: int = None,
    enable_random_cropping=False,
    padding_kwargs: dict = None,
    allow_generation: bool = False,
    lowres_supervision=None,
    max_val=None,
    grid_alignment=GridAlignement.LeftTop,
    overlapping_padding_kwargs=None,
    print_vars=True,
):
    """
    Multi-channel loader that also builds a pyramid of downsampled inputs.

    Args:
        num_scales: The number of resolutions at which we want the input. Note
            that the target is formed at the highest resolution. Each extra
            scale halves the two spatial axes of the previous one.
        padding_kwargs: padding options; must be a dict with a ``"mode"`` key.
        overlapping_padding_kwargs: must equal ``padding_kwargs`` when given;
            defaults to ``padding_kwargs``.
        lowres_supervision: stored as ``self._lowres_supervision``.

    The remaining arguments are forwarded to the parent constructor.
    """
    self._padding_kwargs = (
        padding_kwargs  # mode=padding_mode, constant_values=constant_value
    )
    if overlapping_padding_kwargs is not None:
        assert self._padding_kwargs == overlapping_padding_kwargs, (
            "During evaluation, overlapping_padding_kwargs should be same as padding_args. "
            "It should be so since we just use overlapping_padding_kwargs when it is not None"
        )
    else:
        overlapping_padding_kwargs = padding_kwargs

    super().__init__(
        data_config,
        fpath,
        datasplit_type=datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        normalized_input=normalized_input,
        enable_rotation_aug=enable_rotation_aug,
        enable_random_cropping=enable_random_cropping,
        use_one_mu_std=use_one_mu_std,
        allow_generation=allow_generation,
        max_val=max_val,
        grid_alignment=grid_alignment,
        overlapping_padding_kwargs=overlapping_padding_kwargs,
        print_vars=print_vars,
    )
    self.num_scales = num_scales
    assert self.num_scales is not None
    self._scaled_data = [self._data]
    self._scaled_noise_data = [self._noise_data]

    assert isinstance(self.num_scales, int) and self.num_scales >= 1
    self._lowres_supervision = lowres_supervision
    assert isinstance(self._padding_kwargs, dict)
    assert "mode" in self._padding_kwargs

    for _ in range(1, self.num_scales):
        shape = self._scaled_data[-1].shape
        assert len(shape) == 4
        # Halve the two spatial axes; keep the first and last axes.
        new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
        ds_data = resize(
            self._scaled_data[-1].astype(np.float32), new_shape
        ).astype(self._scaled_data[-1].dtype)
        # NOTE: These asserts are important. the resize method expects np.float32. otherwise, one gets weird results.
        assert (
            ds_data.max() / self._scaled_data[-1].max() < 5
        ), "Downsampled image should not have very different values"
        assert (
            ds_data.max() / self._scaled_data[-1].max() > 0.2
        ), "Downsampled image should not have very different values"

        self._scaled_data.append(ds_data)
        # do the same for noise, including the float32 round trip that the
        # NOTE above says resize needs (previously skipped for noise).
        if self._noise_data is not None:
            noise_src = self._scaled_noise_data[-1]
            noise_data = resize(noise_src.astype(np.float32), new_shape).astype(
                noise_src.dtype
            )
            self._scaled_noise_data.append(noise_data)
|
|
1096
|
+
|
|
1097
|
+
def _init_msg(self):
    """Extend the parent's init message with the padding configuration."""
    base_msg = super()._init_msg()
    return base_msg + f" Pad:{self._padding_kwargs}"
|
|
1101
|
+
|
|
1102
|
+
def _load_scaled_img(
    self, scaled_index, index: Union[int, Tuple[int, int]]
) -> Tuple[np.ndarray, ...]:
    """Load every channel of a frame at pyramid level ``scaled_index``.

    Args:
        scaled_index: position in ``self._scaled_data`` (0 = full resolution).
        index: a sample index (Python int or numpy integer scalar) or an
            ``(index, grid_size)`` tuple; the frame used is ``index % self.N``.

    Returns:
        One array per channel with a leading axis of size 1. When noise data
        is available and not disabled, the input noise (channel 0 of the same
        level, times sqrt(2) for summed inputs) is added, because these levels
        are only used as network input.
    """
    # Accept numpy integer scalars too (plain-int-only checks crashed on np.int64).
    if isinstance(index, (int, np.integer)):
        idx = index
    else:
        idx, _ = index
    frame = self._scaled_data[scaled_index][idx % self.N]
    imgs = tuple(frame[None, :, :, i] for i in range(frame.shape[-1]))
    # Honor _disable_noise like the full-resolution loader (_load_img) does;
    # otherwise only the low-resolution levels would receive input noise.
    if self._noise_data is not None and not self._disable_noise:
        noise_frame = self._scaled_noise_data[scaled_index][idx % self.N]
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        # since we are using this lowres images for just the input, we need to add the noise of the input.
        assert self._lowres_supervision is None or self._lowres_supervision is False
        input_noise = noise_frame[None, :, :, 0]
        imgs = tuple(img + input_noise * factor for img in imgs)
    return imgs
|
|
1121
|
+
|
|
1122
|
+
def _crop_img(self, img: np.ndarray, h_start: int, w_start: int):
    """Crop ``img`` at ``(h_start, w_start)`` via ``self._crop_img_with_padding``.

    ``h_start`` / ``w_start`` may be negative or run past the frame edge.
    NOTE(review): the padding helper appears to pad the out-of-frame part so
    the patch keeps full size. Confirm against ``_crop_img_with_padding``.
    """
    cropped = self._crop_img_with_padding(img, h_start, w_start)
    return cropped
|
|
1128
|
+
|
|
1129
|
+
def _get_img(self, index: int):
    """Return the primary patch plus lower-resolution patches centred on it.

    The full-resolution patch is cropped at a random corner
    (``_enable_random_cropping``) or a deterministic one. For each further
    scale the patch centre is halved, and a patch of the same size
    ``self._img_sz`` is cropped around it from ``self._load_scaled_img``.

    Returns:
        ``(images, noise)``. ``images`` is a tuple with one array per channel
        holding all resolutions concatenated along axis 0, full resolution
        first. ``noise`` is a list of full-resolution noise crops only.
    """
    img_tuples, noise_tuples = self._load_img(index)
    assert self._img_sz is not None
    h, w = img_tuples[0].shape[-2:]
    if self._enable_random_cropping:
        h_start, w_start = self._get_random_hw(h, w)
    else:
        h_start, w_start = self._get_deterministic_hw(index)

    def crop(arr, top, left):
        return self._crop_flip_img(arr, top, left, False, False)

    n_channels = len(img_tuples)
    per_channel = [[crop(img, h_start, w_start)] for img in img_tuples]
    cropped_noise_tuples = [crop(noise, h_start, w_start) for noise in noise_tuples]

    half = self._img_sz // 2
    h_center, w_center = h_start + half, w_start + half
    for scale_idx in range(1, self.num_scales):
        scaled_imgs = self._load_scaled_img(scale_idx, index)
        # Each level halves the resolution, so the centre coordinates halve too.
        h_center //= 2
        w_center //= 2
        top, left = h_center - half, w_center - half
        scaled_crops = [crop(img, top, left) for img in scaled_imgs]
        for ch_idx in range(n_channels):
            per_channel[ch_idx].append(scaled_crops[ch_idx])

    output_img_tuples = tuple(np.concatenate(levels) for levels in per_channel)
    return output_img_tuples, cropped_noise_tuples
|
|
1177
|
+
|
|
1178
|
+
def __getitem__(self, index: Union[int, Tuple[int, int]]):
    """Return one multi-resolution sample ``(inp, target[, alpha][, grid_size])``.

    Each channel image holds the resolutions stacked along axis 0 (see
    ``_get_img``); the target uses only the first, highest-resolution entry.
    ``alpha`` is appended when ``self._return_alpha`` is set and ``grid_size``
    when ``index`` is an ``(index, grid_size)`` tuple.
    """
    if self._uncorrelated_channels:
        img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        img_tuples, noise_tuples = self._get_img(index)

    if self._enable_rotation:
        img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    assert self._lowres_supervision != True
    # add noise to input
    if len(noise_tuples) > 0:
        factor = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = []
        for x in img_tuples:
            # Copy first: writing into x in place (as before) also added the
            # input noise to img_tuples, which the target is sliced from below.
            x = x.copy()
            # NOTE: other LC levels already have noise added (in
            # _load_scaled_img). So, we just need to add noise to the highest
            # resolution.
            x[0] = x[0] + noise_tuples[0] * factor
            input_tuples.append(x)
    else:
        input_tuples = img_tuples

    inp, alpha = self._compute_input(input_tuples)
    # Target: the highest-resolution level of each channel.
    target_tuples = [img[:1] for img in img_tuples]
    # add noise to target: channel i gets noise channel i + 1.
    if len(noise_tuples) >= 1:
        target_tuples = [
            x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
        ]

    target = self._compute_target(target_tuples, alpha)

    output = [inp, target]
    if self._return_alpha:
        output.append(alpha)

    # Any integer scalar (including numpy ints such as np.int64) is a plain index.
    if isinstance(index, (int, np.integer)):
        return tuple(output)

    _, grid_size = index
    output.append(grid_size)
    return tuple(output)
|