careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release: this version of careamics has been flagged; see the registry page for details.
- careamics/careamist.py +25 -17
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/architectures/lvae_model.py +0 -4
- careamics/config/configuration_factory.py +480 -177
- careamics/config/configuration_model.py +1 -2
- careamics/config/data_model.py +1 -15
- careamics/config/fcn_algorithm_model.py +14 -9
- careamics/config/likelihood_model.py +21 -4
- careamics/config/nm_model.py +31 -5
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +2 -36
- careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
- careamics/lightning/lightning_module.py +10 -8
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/loss_factory.py +3 -3
- careamics/losses/lvae/losses.py +2 -2
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
- careamics/lvae_training/dataset/lc_dataset.py +28 -20
- careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +4 -2
- careamics/model_io/bmz_io.py +6 -5
- careamics/models/lvae/likelihoods.py +18 -9
- careamics/models/lvae/lvae.py +12 -16
- careamics/models/lvae/noise_models.py +1 -1
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +204 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
- careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
- careamics/lvae_training/dataset/data_utils.py +0 -701
- careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0
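
For users of the LVAE training code, the most visible change in this release is the reorganisation of the dataset package: `vae_data_config.py` becomes `config.py`, `vae_dataset.py` becomes `multich_dataset.py`, the 701-line `dataset/data_utils.py` is removed in favour of a new `utils/` subpackage, and `multifile_dataset.py` and `types.py` are added. A minimal sketch of the new import paths, inferred from the file list above and from the new modules' own imports (not an official migration guide):

```python
# Module layout in 0.0.4.1, inferred from this diff:
from careamics.lvae_training.dataset.config import DatasetConfig             # was dataset/vae_data_config.py
from careamics.lvae_training.dataset.multich_dataset import MultiChDloader   # was dataset/vae_dataset.py
from careamics.lvae_training.dataset.multifile_dataset import MultiFileDset  # new in this release
from careamics.lvae_training.dataset.types import DataSplitType, TilingMode  # new in this release
```

The new dataset files are shown in full below.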
careamics/lvae_training/dataset/multifile_dataset.py
@@ -0,0 +1,334 @@
+from typing import Union, Callable, Sequence
+
+import numpy as np
+from numpy.typing import NDArray
+
+from .config import DatasetConfig
+from .multich_dataset import MultiChDloader
+from .types import DataSplitType
+from .lc_dataset import LCMultiChDloader
+
+
+class TwoChannelData(Sequence):
+    """
+    each element in data_arr should be a N*H*W array
+    """
+
+    def __init__(self, data_arr1, data_arr2, paths_data1=None, paths_data2=None):
+        assert len(data_arr1) == len(data_arr2)
+        self.paths1 = paths_data1
+        self.paths2 = paths_data2
+
+        self._data = []
+        for i in range(len(data_arr1)):
+            assert data_arr1[i].shape == data_arr2[i].shape
+            assert (
+                len(data_arr1[i].shape) == 3
+            ), f"Each element in data arrays should be a N*H*W, but {data_arr1[i].shape}"
+            self._data.append(
+                np.concatenate(
+                    [data_arr1[i][..., None], data_arr2[i][..., None]], axis=-1
+                )
+            )
+
+    def __len__(self):
+        n = 0
+        for x in self._data:
+            n += x.shape[0]
+        return n
+
+    def __getitem__(self, idx):
+        n = 0
+        for dataidx, x in enumerate(self._data):
+            if idx < n + x.shape[0]:
+                if self.paths1 is None:
+                    return x[idx - n], None
+                else:
+                    return x[idx - n], (self.paths1[dataidx], self.paths2[dataidx])
+            n += x.shape[0]
+        raise IndexError("Index out of range")
+
+
+class MultiChannelData(Sequence):
+    """
+    each element in data_arr should be a N*H*W array
+    """
+
+    def __init__(self, data_arr, paths=None):
+        self.paths = paths
+
+        self._data = data_arr
+
+    def __len__(self):
+        n = 0
+        for x in self._data:
+            n += x.shape[0]
+        return n
+
+    def __getitem__(self, idx):
+        n = 0
+        for dataidx, x in enumerate(self._data):
+            if idx < n + x.shape[0]:
+                if self.paths is None:
+                    return x[idx - n], None
+                else:
+                    return x[idx - n], (self.paths[dataidx])
+            n += x.shape[0]
+        raise IndexError("Index out of range")
+
+
+class SingleFileLCDset(LCMultiChDloader):
+    def __init__(
+        self,
+        preloaded_data: NDArray,
+        data_config: DatasetConfig,
+        fpath: str,
+        load_data_fn: Callable,
+        val_fraction=None,
+        test_fraction=None,
+    ):
+        self._preloaded_data = preloaded_data
+        super().__init__(
+            data_config,
+            fpath,
+            load_data_fn=load_data_fn,
+            val_fraction=val_fraction,
+            test_fraction=test_fraction,
+        )
+
+    @property
+    def data_path(self):
+        return self._fpath
+
+    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+        pass
+
+    def load_data(
+        self,
+        data_config: DatasetConfig,
+        datasplit_type: DataSplitType,
+        load_data_fn: Callable,
+        val_fraction=None,
+        test_fraction=None,
+        allow_generation=None,
+    ):
+        self._data = self._preloaded_data
+        assert "channel_1" not in data_config or isinstance(data_config.channel_1, str)
+        assert "channel_2" not in data_config or isinstance(data_config.channel_2, str)
+        assert "channel_3" not in data_config or isinstance(data_config.channel_3, str)
+        self._loaded_data_preprocessing(data_config)
+
+
+class SingleFileDset(MultiChDloader):
+    def __init__(
+        self,
+        preloaded_data: NDArray,
+        data_config: DatasetConfig,
+        fpath: str,
+        load_data_fn: Callable,
+        val_fraction=None,
+        test_fraction=None,
+    ):
+        self._preloaded_data = preloaded_data
+        super().__init__(
+            data_config,
+            fpath,
+            load_data_fn=load_data_fn,
+            val_fraction=val_fraction,
+            test_fraction=test_fraction,
+        )
+
+    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+        pass
+
+    @property
+    def data_path(self):
+        return self._fpath
+
+    def load_data(
+        self,
+        data_config: DatasetConfig,
+        datasplit_type: DataSplitType,
+        load_data_fn: Callable[..., NDArray],
+        val_fraction=None,
+        test_fraction=None,
+        allow_generation=None,
+    ):
+        self._data = self._preloaded_data
+        assert (
+            "channel_1" not in data_config
+        ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+        assert (
+            "channel_2" not in data_config
+        ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+        assert (
+            "channel_3" not in data_config
+        ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+        self._loaded_data_preprocessing(data_config)
+
+
+class MultiFileDset:
+    """
+    Here, we handle dataset having multiple files. Each file can have a different spatial dimension and number of frames (Z stack).
+    """
+
+    def __init__(
+        self,
+        data_config: DatasetConfig,
+        fpath: str,
+        load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
+        val_fraction=None,
+        test_fraction=None,
+    ):
+        self._fpath = fpath
+        data: Union[TwoChannelData, MultiChannelData] = load_data_fn(
+            data_config,
+            self._fpath,
+            data_config.datasplit_type,
+            val_fraction=val_fraction,
+            test_fraction=test_fraction,
+        )
+        self.dsets = []
+
+        for i in range(len(data)):
+            prefetched_data, fpath_tuple = data[i]
+            if (
+                data_config.multiscale_lowres_count is not None
+                and data_config.multiscale_lowres_count > 1
+            ):
+
+                self.dsets.append(
+                    SingleFileLCDset(
+                        prefetched_data[None],
+                        data_config,
+                        fpath_tuple,
+                        load_data_fn,
+                        val_fraction=val_fraction,
+                        test_fraction=test_fraction,
+                    )
+                )
+
+            else:
+                self.dsets.append(
+                    SingleFileDset(
+                        prefetched_data[None],
+                        data_config,
+                        fpath_tuple,
+                        load_data_fn,
+                        val_fraction=val_fraction,
+                        test_fraction=test_fraction,
+                    )
+                )
+
+        self.rm_bkground_set_max_val_and_upperclip_data(
+            data_config.max_val, data_config.datasplit_type
+        )
+        count = 0
+        avg_height = 0
+        avg_width = 0
+        for dset in self.dsets:
+            shape = dset.get_data_shape()
+            avg_height += shape[1]
+            avg_width += shape[2]
+            count += shape[0]
+
+        avg_height = int(avg_height / len(self.dsets))
+        avg_width = int(avg_width / len(self.dsets))
+        print(
+            f"{self.__class__.__name__} avg height: {avg_height}, avg width: {avg_width}, count: {count}"
+        )
+
+    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+        self.set_max_val(max_val, datasplit_type)
+        self.upperclip_data()
+
+    def set_mean_std(self, mean_val, std_val):
+        for dset in self.dsets:
+            dset.set_mean_std(mean_val, std_val)
+
+    def get_mean_std(self):
+        return self.dsets[0].get_mean_std()
+
+    def compute_max_val(self):
+        max_val_arr = []
+        for dset in self.dsets:
+            max_val_arr.append(dset.compute_max_val())
+        return np.max(max_val_arr)
+
+    def set_max_val(self, max_val, datasplit_type):
+        if datasplit_type == DataSplitType.Train:
+            assert max_val is None
+            max_val = self.compute_max_val()
+        for dset in self.dsets:
+            dset.set_max_val(max_val, datasplit_type)
+
+    def upperclip_data(self):
+        for dset in self.dsets:
+            dset.upperclip_data()
+
+    def get_max_val(self):
+        return self.dsets[0].get_max_val()
+
+    def get_img_sz(self):
+        return self.dsets[0].get_img_sz()
+
+    def set_img_sz(self, image_size, grid_size):
+        for dset in self.dsets:
+            dset.set_img_sz(image_size, grid_size)
+
+    def compute_mean_std(self):
+        cur_mean = {"target": 0, "input": 0}
+        cur_std = {"target": 0, "input": 0}
+        for dset in self.dsets:
+            mean, std = dset.compute_mean_std()
+            cur_mean["target"] += mean["target"]
+            cur_mean["input"] += mean["input"]
+
+            cur_std["target"] += std["target"]
+            cur_std["input"] += std["input"]
+
+        cur_mean["target"] /= len(self.dsets)
+        cur_mean["input"] /= len(self.dsets)
+        cur_std["target"] /= len(self.dsets)
+        cur_std["input"] /= len(self.dsets)
+        return cur_mean, cur_std
+
+    def compute_individual_mean_std(self):
+        cum_mean = 0
+        cum_std = 0
+        for dset in self.dsets:
+            mean, std = dset.compute_individual_mean_std()
+            cum_mean += mean
+            cum_std += std
+        return cum_mean / len(self.dsets), cum_std / len(self.dsets)
+
+    def get_num_frames(self):
+        return len(self.dsets)
+
+    def reduce_data(
+        self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
+    ):
+        assert h_start is None
+        assert h_end is None
+        assert w_start is None
+        assert w_end is None
+        self.dsets = [self.dsets[t] for t in t_list]
+        print(
+            f"[{self.__class__.__name__}] Data reduced. New data count: {len(self.dsets)}"
+        )
+
+    def __len__(self):
+        out = 0
+        for dset in self.dsets:
+            out += len(dset)
+        return out
+
+    def __getitem__(self, idx):
+        cum_len = 0
+        for dset in self.dsets:
+            cum_len += len(dset)
+            if idx < cum_len:
+                rel_idx = idx - (cum_len - len(dset))
+                return dset[rel_idx]
+
+        raise IndexError("Index out of range")
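
`TwoChannelData`, `MultiChannelData`, and `MultiFileDset` all share the same flat-indexing pattern: `__getitem__` walks a cumulative frame count across the per-file arrays until the global index falls inside one of them, then returns the frame together with its source path(s). A small sketch of that behaviour using `MultiChannelData` from the diff above (the array shapes and paths are illustrative):

```python
import numpy as np

from careamics.lvae_training.dataset.multifile_dataset import MultiChannelData

# Two "files" with different frame counts and spatial sizes.
a = np.zeros((4, 64, 64))   # 4 frames of 64x64
b = np.ones((2, 128, 128))  # 2 frames of 128x128

data = MultiChannelData([a, b], paths=["a.tif", "b.tif"])

assert len(data) == 6            # 4 + 2 frames in total
frame, path = data[5]            # global index 5 -> second frame of the second file
assert frame.shape == (128, 128)
assert path == "b.tif"
```

Note that the lookup is linear in the number of files; for the handful of files these loaders target, that is not a concern.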
careamics/lvae_training/dataset/types.py
@@ -0,0 +1,43 @@
+from enum import Enum
+
+
+class DataType(Enum):
+    MNIST = 0
+    Places365 = 1
+    NotMNIST = 2
+    OptiMEM100_014 = 3
+    CustomSinosoid = 4
+    Prevedel_EMBL = 5
+    AllenCellMito = 6
+    SeparateTiffData = 7
+    CustomSinosoidThreeCurve = 8
+    SemiSupBloodVesselsEMBL = 9
+    Pavia2 = 10
+    Pavia2VanillaSplitting = 11
+    ExpansionMicroscopyMitoTub = 12
+    ShroffMitoEr = 13
+    HTIba1Ki67 = 14
+    BSD68 = 15
+    BioSR_MRC = 16
+    TavernaSox2Golgi = 17
+    Dao3Channel = 18
+    ExpMicroscopyV2 = 19
+    Dao3ChannelWithInput = 20
+    TavernaSox2GolgiV2 = 21
+    TwoDset = 22
+    PredictedTiffData = 23
+    Pavia3SeqData = 24
+    NicolaData = 25
+
+
+class DataSplitType(Enum):
+    All = 0
+    Train = 1
+    Val = 2
+    Test = 3
+
+
+class TilingMode(Enum):
+    TrimBoundary = 0
+    PadBoundary = 1
+    ShiftBoundary = 2
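
These are plain `Enum` subclasses rather than `IntEnum`, so members compare by identity and are not interchangeable with their integer values; a brief sketch of the consequences (values taken from the definitions above):

```python
from careamics.lvae_training.dataset.types import DataSplitType, TilingMode

split = DataSplitType.Train
assert split == DataSplitType.Train             # identity comparison, as used in MultiFileDset.set_max_val
assert split != 1                               # a plain Enum member is not equal to its raw value
assert DataSplitType(1) is DataSplitType.Train  # but it can be recovered from the value

mode = TilingMode.ShiftBoundary                 # alternatives: TrimBoundary, PadBoundary
```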
careamics/lvae_training/dataset/utils/__init__.py
File without changes (new, empty file).
careamics/lvae_training/dataset/utils/data_utils.py
@@ -0,0 +1,114 @@
+"""
+Utility functions needed by dataloader & co.
+"""
+
+from typing import List
+
+import numpy as np
+from skimage.io import imread, imsave
+
+
+def load_tiff(path):
+    """
+    Returns a 4d numpy array: num_imgs*h*w*num_channels
+    """
+    data = imread(path, plugin="tifffile")
+    return data
+
+
+def save_tiff(path, data):
+    imsave(path, data, plugin="tifffile")
+
+
+def load_tiffs(paths):
+    data = [load_tiff(path) for path in paths]
+    return np.concatenate(data, axis=0)
+
+
+def split_in_half(s, e):
+    n = e - s
+    s1 = list(np.arange(n // 2))
+    s2 = list(np.arange(n // 2, n))
+    return [x + s for x in s1], [x + s for x in s2]
+
+
+def adjust_for_imbalance_in_fraction_value(
+    val: List[int],
+    test: List[int],
+    val_fraction: float,
+    test_fraction: float,
+    total_size: int,
+):
+    """
+    here, val and test are divided almost equally. Here, we need to take into account their respective fractions
+    and pick elements rendomly from one array and put in the other array.
+    """
+    if val_fraction == 0:
+        test += val
+        val = []
+    elif test_fraction == 0:
+        val += test
+        test = []
+    else:
+        diff_fraction = test_fraction - val_fraction
+        if diff_fraction > 0:
+            imb_count = int(diff_fraction * total_size / 2)
+            val = list(np.random.RandomState(seed=955).permutation(val))
+            test += val[:imb_count]
+            val = val[imb_count:]
+        elif diff_fraction < 0:
+            imb_count = int(-1 * diff_fraction * total_size / 2)
+            test = list(np.random.RandomState(seed=955).permutation(test))
+            val += test[:imb_count]
+            test = test[imb_count:]
+    return val, test
+
+
+def get_datasplit_tuples(
+    val_fraction: float,
+    test_fraction: float,
+    total_size: int,
+    starting_test: bool = False,
+):
+    if starting_test:
+        # test => val => train
+        test = list(range(0, int(total_size * test_fraction)))
+        val = list(range(test[-1] + 1, test[-1] + 1 + int(total_size * val_fraction)))
+        train = list(range(val[-1] + 1, total_size))
+    else:
+        # {test,val}=> train
+        test_val_size = int((val_fraction + test_fraction) * total_size)
+        train = list(range(test_val_size, total_size))
+
+        if test_val_size == 0:
+            test = []
+            val = []
+            return train, val, test
+
+        # Split the test and validation in chunks.
+        chunksize = max(1, min(3, test_val_size // 2))
+
+        nchunks = test_val_size // chunksize
+
+        test = []
+        val = []
+        s = 0
+        for i in range(nchunks):
+            if i % 2 == 0:
+                val += list(np.arange(s, s + chunksize))
+            else:
+                test += list(np.arange(s, s + chunksize))
+            s += chunksize
+
+        if i % 2 == 0:
+            test += list(np.arange(s, test_val_size))
+        else:
+            p1, p2 = split_in_half(s, test_val_size)
+            test += p1
+            val += p2
+
+        val, test = adjust_for_imbalance_in_fraction_value(
+            val, test, val_fraction, test_fraction, total_size
+        )
+
+    return train, val, test
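
One behaviour of `get_datasplit_tuples` worth spelling out: with `starting_test=False` (the default), train always gets the trailing `total_size - test_val_size` frames, while val and test interleave over the leading frames in chunks of at most 3; `adjust_for_imbalance_in_fraction_value` then moves indices between the two (with the fixed seed 955) only when the fractions differ. A quick sketch, tracing the code above (module path from the file list; the concrete indices follow from the chunking logic):

```python
from careamics.lvae_training.dataset.utils.data_utils import get_datasplit_tuples

# 10% val, 10% test over a stack of 100 frames.
train, val, test = get_datasplit_tuples(
    val_fraction=0.1, test_fraction=0.1, total_size=100
)

assert train == list(range(20, 100))     # train is the trailing 80 frames
assert len(val) == 10 and len(test) == 10
# val/test interleave over frames 0-19 in chunks of 3:
# val  starts [0, 1, 2, 6, 7, 8, ...]
# test starts [3, 4, 5, 9, 10, 11, ...]
```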
careamics/lvae_training/dataset/utils/empty_patch_fetcher.py
@@ -0,0 +1,65 @@
+import numpy as np
+from tqdm import tqdm
+
+
+class EmptyPatchFetcher:
+    """
+    The idea is to fetch empty patches so that real content can be replaced with this.
+    """
+
+    def __init__(self, idx_manager, patch_size, data_frames, max_val_threshold=None):
+        self._frames = data_frames
+        self._idx_manager = idx_manager
+        self._max_val_threshold = max_val_threshold
+        self._idx_list = []
+        self._patch_size = patch_size
+        self._grid_size = 1
+        self.set_empty_idx()
+
+        print(f"[{self.__class__.__name__}] MaxVal:{self._max_val_threshold}")
+
+    def compute_max(self, window):
+        """
+        Rolling compute.
+        """
+        N, H, W = self._frames.shape
+        randnum = -954321
+        assert self._grid_size == 1
+        max_data = np.zeros((N, H - window, W - window)) * randnum
+
+        for h in tqdm(range(H - window)):
+            for w in range(W - window):
+                max_data[:, h, w] = self._frames[:, h : h + window, w : w + window].max(
+                    axis=(1, 2)
+                )
+
+        assert (max_data != 954321).any()
+        return max_data
+
+    def set_empty_idx(self):
+        max_data = self.compute_max(self._patch_size)
+        empty_loc = np.where(
+            np.logical_and(max_data >= 0, max_data < self._max_val_threshold)
+        )
+        # print(max_data.shape, len(empty_loc))
+        self._idx_list = []
+        for idx in range(len(empty_loc[0])):
+            n_idx = empty_loc[0][idx]
+            h_start = empty_loc[1][idx]
+            w_start = empty_loc[2][idx]
+            # print(n_idx,h_start,w_start)
+            # channel_idx = self._idx_manager.get_location_from_dataset_idx(0)[-1]
+            loc = (n_idx, h_start, w_start, 0)
+            idx = self._idx_manager.get_dataset_idx_from_location(loc)
+            t, h, w, _ = self._idx_manager.get_location_from_dataset_idx(idx)
+            assert h == h_start, f"{h} != {h_start}"
+            assert w == w_start, f"{w} != {w_start}"
+            assert t == n_idx, f"{t} != {n_idx}"
+            self._idx_list.append(idx)
+
+        self._idx_list = np.array(self._idx_list)
+
+        assert len(self._idx_list) > 0
+
+    def sample(self):
+        return (np.random.choice(self._idx_list), self._grid_size)
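
Two observations on `compute_max`: first, `np.zeros((N, H - window, W - window)) * randnum` is still an all-zeros array, so the `-954321` sentinel and the final `assert (max_data != 954321).any()` do not actually guard anything; second, the double Python loop issues one NumPy call per pixel position and can be slow on large frames. A vectorized equivalent is possible with `numpy.lib.stride_tricks.sliding_window_view`; this is a sketch of an alternative, not code from the release:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def compute_max_vectorized(frames: np.ndarray, window: int) -> np.ndarray:
    """Rolling patch maximum over an N*H*W stack, matching EmptyPatchFetcher.compute_max.

    Returns shape (N, H - window, W - window); like the original loops, the last
    valid window position along each spatial axis is dropped.
    """
    # Shape: (N, H - window + 1, W - window + 1, window, window)
    windows = sliding_window_view(frames, (window, window), axis=(1, 2))
    full = windows.max(axis=(-2, -1))
    return full[:, : frames.shape[1] - window, : frames.shape[2] - window]
```

(Note: `sliding_window_view` returns a strided view, so no expanded array is materialised, and the reduction still does window-squared work per output pixel; the win is removing the Python-level loops.)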