careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +20 -4
- careamics/config/configuration.py +10 -5
- careamics/config/data/data_model.py +38 -1
- careamics/config/optimizer_models.py +1 -3
- careamics/config/training_model.py +0 -2
- careamics/dataset/dataset_utils/running_stats.py +7 -3
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +233 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +408 -0
- careamics/dataset_ng/legacy_interoperability.py +168 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/lightning/dataset_ng/data_module.py +488 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
- careamics/lightning/lightning_module.py +3 -0
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/transforms/compose.py +1 -0
- careamics/transforms/normalize.py +18 -7
- careamics/utils/lightning_utils.py +25 -11
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -340,25 +340,54 @@ class MultiChDloader:
|
|
|
340
340
|
return self._data.shape[0]
|
|
341
341
|
|
|
342
342
|
def reduce_data(
|
|
343
|
-
self,
|
|
343
|
+
self,
|
|
344
|
+
t_list=None,
|
|
345
|
+
z_start=None,
|
|
346
|
+
z_end=None,
|
|
347
|
+
h_start=None,
|
|
348
|
+
h_end=None,
|
|
349
|
+
w_start=None,
|
|
350
|
+
w_end=None,
|
|
344
351
|
):
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
352
|
+
if self._5Ddata:
|
|
353
|
+
if t_list is None:
|
|
354
|
+
t_list = list(range(self._data.shape[0]))
|
|
355
|
+
if z_start is None:
|
|
356
|
+
z_start = 0
|
|
357
|
+
if z_end is None:
|
|
358
|
+
z_end = self._data.shape[1]
|
|
359
|
+
if h_start is None:
|
|
360
|
+
h_start = 0
|
|
361
|
+
if h_end is None:
|
|
362
|
+
h_end = self._data.shape[2]
|
|
363
|
+
if w_start is None:
|
|
364
|
+
w_start = 0
|
|
365
|
+
if w_end is None:
|
|
366
|
+
w_end = self._data.shape[3]
|
|
367
|
+
self._data = self._data[
|
|
368
|
+
t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
|
|
361
369
|
].copy()
|
|
370
|
+
if self._noise_data is not None:
|
|
371
|
+
self._noise_data = self._noise_data[
|
|
372
|
+
t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
|
|
373
|
+
].copy()
|
|
374
|
+
else:
|
|
375
|
+
if t_list is None:
|
|
376
|
+
t_list = list(range(self._data.shape[0]))
|
|
377
|
+
if h_start is None:
|
|
378
|
+
h_start = 0
|
|
379
|
+
if h_end is None:
|
|
380
|
+
h_end = self._data.shape[1]
|
|
381
|
+
if w_start is None:
|
|
382
|
+
w_start = 0
|
|
383
|
+
if w_end is None:
|
|
384
|
+
w_end = self._data.shape[2]
|
|
385
|
+
|
|
386
|
+
self._data = self._data[t_list, h_start:h_end, w_start:w_end, :].copy()
|
|
387
|
+
if self._noise_data is not None:
|
|
388
|
+
self._noise_data = self._noise_data[
|
|
389
|
+
t_list, h_start:h_end, w_start:w_end, :
|
|
390
|
+
].copy()
|
|
362
391
|
# TODO where tf is self._img_sz defined?
|
|
363
392
|
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
|
|
364
393
|
print(
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Here, we have multiple folders, each containing images of a single channel.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import cache
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .types import DataSplitType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def l2(x):
    """Return the root-mean-square of the values in ``x``.

    ``x`` is converted with ``np.array`` and the result is ``sqrt(mean(x ** 2))``.
    In this module it aggregates per-crop standard deviations into one value.
    """
    values = np.array(x)
    mean_square = (values**2).mean()
    return np.sqrt(mean_square)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MultiCropDset:
    """Dataset whose inputs are sums of random crops drawn independently per channel.

    Each channel is a list of images that may differ in number and shape between
    channels (e.g. one folder per channel). Every item draws one random crop of
    ``image_size`` from every channel (cropping the first two axes of an image);
    the input is the normalized sum of the crops and the target is the normalized
    stack of the crops.

    Parameters
    ----------
    data_config
        Configuration object. Attributes read here: ``input_is_sum`` (must be
        truthy), ``image_size``, ``enable_rotation_aug``, ``background_values``,
        ``max_val`` and ``datasplit_type``.
    fpath : str
        Data location, forwarded to ``load_data_fn``.
    load_data_fn : callable
        Called as ``load_data_fn(data_config, fpath, datasplit_type,
        val_fraction, test_fraction)``; must return one list of arrays per channel.
    val_fraction, test_fraction
        Forwarded unchanged to ``load_data_fn``.

    Raises
    ------
    ValueError
        If ``load_data_fn`` is not provided.
    """

    def __init__(
        self,
        data_config,
        fpath: str,
        load_data_fn=None,
        val_fraction=None,
        test_fraction=None,
    ):
        assert data_config.input_is_sum, "This dataset is designed for sum of images"

        self._img_sz = data_config.image_size
        self._enable_rotation = data_config.enable_rotation_aug
        # NOTE(review): ``self._rotation_transform`` (used by ``_rotate2D``) is never
        # created by this class, so enabling rotation fails in ``__getitem__`` with
        # AttributeError. Confirm where the transform is supposed to be set.

        self._background_values = data_config.background_values
        if load_data_fn is None:
            raise ValueError("load_data_fn must be provided to load the per-channel data")
        self._data = load_data_fn(
            data_config, fpath, data_config.datasplit_type, val_fraction, test_fraction
        )

        # Clip upper intensities per channel (in place); crucial for removing puncta.
        self.max_val = data_config.max_val
        if self.max_val is not None:
            for ch_idx, data in enumerate(self._data):
                ch_max = self.max_val[ch_idx]
                if ch_max is None:
                    continue
                for img in data:
                    img[img > ch_max] = ch_max

        # Subtract a per-channel background after converting to float32.
        if self._background_values is not None:
            self._data = [
                [img.astype(np.float32) - self._background_values[ch_idx] for img in data]
                for ch_idx, data in enumerate(self._data)
            ]

        print(
            f"{self.__class__.__name__} N:{len(self)} Rot:{self._enable_rotation} Ch:{len(self._data)} MaxVal:{self.max_val} Bg:{self._background_values}"
        )

    def get_max_val(self):
        """Return the per-channel clipping values from the config (may be None)."""
        return self.max_val

    def compute_mean_std(self, num_samples=30000):
        """Estimate normalization statistics from randomly sampled crops.

        Parameters
        ----------
        num_samples : int, optional
            Number of crop sets to draw (one crop per channel per set).

        Returns
        -------
        tuple of dict
            ``(mean, std)``. ``"target"`` holds arrays shaped ``(n_channels, 1, 1)``
            and ``"input"`` holds arrays shaped ``(1, 1, 1)``. Means are averages of
            per-crop means. Stds are the root-mean-square of per-crop stds.
        """
        mean_tar_dict = defaultdict(list)
        std_tar_dict = defaultdict(list)
        mean_inp = []
        std_inp = []
        n_channels = len(self._data)
        for _ in range(num_samples):
            crops = []
            for ch_idx in range(n_channels):
                crop = self.sample_crop(ch_idx)
                mean_tar_dict[ch_idx].append(np.mean(crop))
                std_tar_dict[ch_idx].append(np.std(crop))
                crops.append(crop)

            # The input is the (unnormalized) sum of the channel crops.
            inp = sum(crops)
            mean_inp.append(np.mean(inp))
            std_inp.append(np.std(inp))

        output_mean = {
            "target": np.array(
                [np.mean(mean_tar_dict[ch_idx]) for ch_idx in range(n_channels)]
            ).reshape(n_channels, 1, 1),
            "input": np.array([np.mean(mean_inp)]).reshape(1, 1, 1),
        }
        output_std = {
            "target": np.array(
                [l2(std_tar_dict[ch_idx]) for ch_idx in range(n_channels)]
            ).reshape(n_channels, 1, 1),
            "input": np.array([l2(std_inp)]).reshape(1, 1, 1),
        }
        return output_mean, output_std

    def set_mean_std(self, mean_dict, std_dict):
        """Store normalization statistics (dicts with ``"input"``/``"target"`` keys)."""
        self._data_mean = mean_dict
        self._data_std = std_dict

    def get_mean_std(self):
        """Return the stored ``(mean_dict, std_dict)``."""
        return self._data_mean, self._data_std

    def get_num_frames(self):
        """Return ``len(self._data)``, i.e. the number of channels."""
        return len(self._data)

    def crop_probablities(self, ch_idx):
        """Return per-image sampling probabilities for channel ``ch_idx``.

        Probabilities are proportional to each image's number of elements.
        Results are memoized per instance; a per-instance dict is used instead
        of ``functools.cache`` so the cache does not keep instances alive.
        """
        memo = self.__dict__.setdefault("_crop_probability_cache", {})
        if ch_idx not in memo:
            sizes = np.array([np.prod(img.shape) for img in self._data[ch_idx]])
            memo[ch_idx] = sizes / sizes.sum()
        return memo[ch_idx]

    def sample_crop(self, ch_idx):
        """Return a random ``image_size`` crop from a random image of channel ``ch_idx``.

        Images are drawn (via the global NumPy RNG) with probability proportional to
        their size; images smaller than the crop are re-drawn, up to 101 draws.

        Raises
        ------
        ValueError
            If no sufficiently large image was drawn.
        """
        crop_h, crop_w = self._img_sz[0], self._img_sz[1]
        images = self._data[ch_idx]
        probabilities = self.crop_probablities(ch_idx)
        for _ in range(101):
            data = images[np.random.choice(len(images), p=probabilities)]
            if data.shape[0] >= crop_h and data.shape[1] >= crop_w:
                # randint's upper bound is exclusive; +1 lets images of exactly the
                # crop size be used and makes the last valid offset reachable.
                h = np.random.randint(0, data.shape[0] - crop_h + 1)
                w = np.random.randint(0, data.shape[1] - crop_w + 1)
                return data[h : h + crop_h, w : w + crop_w]
        raise ValueError("Cannot find a valid crop")

    def len_per_channel(self, ch_idx):
        """Return the total element count of channel ``ch_idx`` divided by the crop size."""
        return np.sum([np.prod(x.shape) for x in self._data[ch_idx]]) / np.prod(
            self._img_sz
        )

    def imgs_for_patch(self):
        """Draw one random crop from every channel."""
        return [self.sample_crop(ch_idx) for ch_idx in range(len(self._data))]

    def __len__(self):
        """Nominal length: the largest ``len_per_channel`` over all channels.

        Items are sampled randomly, so this only sets the epoch size.
        """
        len_per_channel = [
            self.len_per_channel(ch_idx) for ch_idx in range(len(self._data))
        ]
        return int(np.max(len_per_channel))

    def _rotate(self, img_tuples):
        return self._rotate2D(img_tuples)

    def _rotate2D(self, img_tuples):
        # NOTE(review): relies on ``self._rotation_transform`` (never set in this
        # class). Each entry of ``img_tuples`` is split along its first axis, so
        # plain 2D crops from ``imgs_for_patch`` would be split row by row. Verify
        # the expected input layout before enabling rotation.
        img_kwargs = {}
        for i, img in enumerate(img_tuples):
            for k in range(len(img)):
                img_kwargs[f"img{i}_{k}"] = img[k]

        keys = list(img_kwargs.keys())
        self._rotation_transform.add_targets({k: "image" for k in keys})
        rot_dic = self._rotation_transform(image=img_tuples[0][0], **img_kwargs)

        rotated_img_tuples = []
        for i, img in enumerate(img_tuples):
            if len(img) == 1:
                rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
            else:
                rotated_img_tuples.append(
                    np.concatenate(
                        [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
                    )
                )

        return rotated_img_tuples

    def _compute_input(self, imgs):
        """Sum the crops, add a leading axis and normalize with the input stats."""
        inp = sum(imgs)[None]
        return (inp - self._data_mean["input"]) / (self._data_std["input"])

    def _compute_target(self, imgs):
        """Stack the crops along a new first axis and normalize with the target stats."""
        stacked = np.stack(imgs)
        return (stacked - self._data_mean["target"]) / (self._data_std["target"])

    def __getitem__(self, idx):
        """Return a random ``(input, target)`` pair; ``idx`` is ignored."""
        imgs = self.imgs_for_patch()
        if self._enable_rotation:
            imgs = self._rotate(imgs)

        inp = self._compute_input(imgs)
        target = self._compute_target(imgs)
        return inp, target
|
|
@@ -2,9 +2,9 @@ from enum import Enum
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class DataType(Enum):
|
|
5
|
-
|
|
5
|
+
HTH24Data = 0
|
|
6
6
|
HTLIF24Data = 1
|
|
7
|
-
|
|
7
|
+
PaviaP24Data = 2
|
|
8
8
|
TavernaSox2GolgiV2 = 3
|
|
9
9
|
Dao3ChannelWithInput = 4
|
|
10
10
|
ExpMicroscopyV1 = 5
|
|
@@ -15,7 +15,7 @@ class DataType(Enum):
|
|
|
15
15
|
OptiMEM100_014 = 10
|
|
16
16
|
SeparateTiffData = 11
|
|
17
17
|
BioSR_MRC = 12
|
|
18
|
-
|
|
18
|
+
HTH23BData = 13 # puncta, in case we have differently sized crops for each channel.
|
|
19
19
|
Care3D = 14
|
|
20
20
|
|
|
21
21
|
|
|
@@ -230,3 +230,262 @@ class GridIndexManager:
|
|
|
230
230
|
new_idx = dataset_idx - self.grid_count(dim)
|
|
231
231
|
if new_idx < 0:
|
|
232
232
|
return None
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@dataclass
class GridIndexManagerRef:
    """Grid/patch index bookkeeping for channels that hold lists of differently shaped images.

    Input images are split on a regular grid whose cells have size ``grid_shape``.
    Each cell is read through a larger ``patch_shape`` window, so
    ``patch_shape - grid_shape`` controls the overlap; it is validated to be
    non-negative and even in every dimension. Unlike the single-array manager,
    every channel may contain a different number of images and every image may
    have its own shape.

    Attributes
    ----------
    data_shapes : tuple
        One entry per channel, each a sequence of per-image shapes. The first axis
        of every image shape is treated as a channel axis by
        ``get_location_from_patch_idx``.
    grid_shape : tuple
        Grid cell size per data dimension.
    patch_shape : tuple
        Patch size per data dimension.
    tiling_mode : TilingMode
        How grid cells are placed at image boundaries.
    """

    data_shapes: tuple
    grid_shape: tuple
    patch_shape: tuple
    tiling_mode: TilingMode

    def __post_init__(self):
        """Validate the configured shapes and cache per-image patch counts.

        Raises
        ------
        ValueError
            If ``patch_shape`` is smaller than ``grid_shape`` in some dimension, or
            their difference is odd.
        """
        if len(self.data_shapes) > 1:
            assert {len(ds) for ds in self.data_shapes[0]}.pop() == {
                len(ds) for ds in self.data_shapes[1]
            }.pop(), "Data shape for all channels must be the same"  # TODO better way to assert this
        assert {len(ds) for ds in self.data_shapes[0]}.pop() == len(
            self.grid_shape
        ), "Data shape and grid size must have the same dimension"
        assert {len(ds) for ds in self.data_shapes[0]}.pop() == len(
            self.patch_shape
        ), "Data shape and patch shape must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape must be greater than or equal to grid shape in dimension {dim}"
                )
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape must have even padding in dimension {dim}"
                )
        # Holds the per-image patch counts of every channel (a list of lists),
        # not a single total per channel.
        self.num_patches_per_channel = self.total_grid_count()[1]

    def patch_offset(self):
        """Return how far a patch extends before its grid cell in each dimension."""
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, shape: tuple, dim: int):
        """
        Returns the number of the grid in the specified dimension, ignoring all other dimensions.

        A dimension with ``grid_shape == patch_shape == 1`` contributes one cell
        per element.
        """
        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(np.ceil((shape[dim] - excess_size) / self.grid_shape[dim]))
        else:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(np.floor((shape[dim] - excess_size) / self.grid_shape[dim]))

    def total_grid_count(self):
        """Return ``(len_per_channel, num_patches_per_sample)``.

        ``len_per_channel[c]`` is the total number of patches in channel ``c``;
        ``num_patches_per_sample[c][i]`` is the patch count of image ``i`` of that
        channel.
        """
        len_per_channel = []
        num_patches_per_sample = []
        for channel_data in self.data_shapes:
            num_patches = []
            for file_shape in channel_data:
                num_patches.append(np.prod(self.grid_count_per_sample(file_shape)))
            len_per_channel.append(np.sum(num_patches))
            num_patches_per_sample.append(num_patches)

        return len_per_channel, num_patches_per_sample

    def grid_count_per_sample(self, shape: tuple):
        """Return the number of grid cells along every dimension of an image of ``shape``."""
        grid_count = []
        for dim in range(len(shape)):
            grid_count.append(self.get_individual_dim_grid_count(shape, dim))
        return grid_count

    def get_grid_index(self, shape, dim: int, coordinate: int):
        """Returns the index of the patch in the specified dimension."""
        assert dim < len(
            shape
        ), f"Dimension {dim} is out of bounds for data shape {shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        assert (
            coordinate < shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:  # self.trim_boundary is False:
            return np.floor(coordinate / self.grid_shape[dim])
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # The last cell is shifted so that its patch ends exactly at the image
            # border; compare against this image's extent along ``dim``.
            if coordinate + self.grid_shape[dim] + excess_size == shape[dim]:
                return self.get_individual_dim_grid_count(shape, dim) - 1
            else:
                # can be <0 if coordinate is in [0,grid_shape[dim]]
                return max(
                    0, np.floor((coordinate - excess_size) / self.grid_shape[dim])
                )

        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def patch_idx_from_grid_idx(self, shape: tuple, grid_idx: tuple):
        """Returns the index of the patch in the dataset."""
        # NOTE(review): ``self.grid_count`` is not defined on this class, so this
        # raises AttributeError. The inverse ordering must also match
        # ``get_location_from_patch_idx``, where the first spatial axis varies
        # fastest. Confirm the intended layout before implementing.
        assert len(grid_idx) == len(
            shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {shape}"
        index = 0
        for dim in range(len(grid_idx)):
            index += grid_idx[dim] * self.grid_count(shape, dim)
        return index

    def get_patch_location_from_patch_idx(self, ch_idx: int, patch_idx: int):
        """Return ``(sample_idx, *patch_start)`` for a channel-level patch index.

        This is the grid-cell start from ``get_location_from_patch_idx`` shifted
        back by ``patch_offset``. The sample index is left untouched.
        """
        grid_location = self.get_location_from_patch_idx(ch_idx, patch_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.concatenate((np.array((0,)), offset)))

    def get_patch_idx_from_grid_location(self, shape, location: tuple):
        """Map a coordinate inside an image of ``shape`` to a patch index.

        Depends on ``patch_idx_from_grid_idx``; see the note there.
        """
        assert len(location) == len(
            shape
        ), f"Location {location} must have the same dimension as data shape {shape}"
        grid_idx = [
            self.get_grid_index(shape, dim, location[dim]) for dim in range(len(location))
        ]
        return self.patch_idx_from_grid_idx(shape, tuple(grid_idx))

    def get_gridstart_location_from_dim_index(
        self, shape: tuple, dim_idx: int, dim: int
    ):
        """Returns the grid-start coordinate of the grid in the specified dimension.

        shape: tuple
            Shape of the image the grid is laid over.
        dim_idx: int
            Index of the dimension (axis) in the data shape.
        dim: int
            Grid index along that axis (relative to the number of patches in it).

        Only ``TilingMode.ShiftBoundary`` is supported for non-trivial axes.
        """
        if self.grid_shape[dim_idx] == 1 and self.patch_shape[dim_idx] == 1:
            # NOTE(review): this returns the axis index, not the grid index ``dim``.
            # For the leading channel axis (dim_idx == 0) it therefore always yields
            # 0, which callers may rely on when that axis has size 1. Revisit before
            # changing it.
            return dim_idx
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim_idx] - self.grid_shape[dim_idx]) // 2
            if dim < self.get_individual_dim_grid_count(shape, dim_idx) - 1:
                return dim * self.grid_shape[dim_idx] + excess_size
            else:
                # on boundary. grid should be placed such that the patch covers the entire data.
                return shape[dim_idx] - self.grid_shape[dim_idx] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_patch_idx(self, channel_idx: int, patch_idx: int):
        """
        Returns the start location of the grid in the dataset. Per channel!.

        Parameters
        ----------
        channel_idx : int
            Channel whose list of images is indexed.
        patch_idx : int
            The index of the patch in a list of samples within a channel. Channels can
            be different in length.

        Returns
        -------
        tuple
            ``(sample_idx, *grid_start)``, where ``grid_start`` has one entry per
            axis of the selected image (its leading axis included).

        Raises
        ------
        IndexError
            If ``patch_idx`` is outside ``[0, number of patches in the channel)``.
        """
        cumulative_indices = np.cumsum(self.total_grid_count()[1][channel_idx])
        total = cumulative_indices[-1] if len(cumulative_indices) else 0
        if not 0 <= patch_idx < total:
            raise IndexError(
                f"Patch index {patch_idx} is out of range for channel {channel_idx} "
                f"with {total} patches"
            )
        # Locate the image that contains this patch ...
        sample_idx = int(np.searchsorted(cumulative_indices, patch_idx, side="right"))
        # ... and convert the channel-level index into an index local to that image.
        if sample_idx > 0:
            patch_idx = patch_idx - cumulative_indices[sample_idx - 1]
        sample_shape = self.data_shapes[channel_idx][sample_idx]
        # Ignore the leading channel axis because it is indexed explicitly below.
        grid_count = self.grid_count_per_sample(sample_shape)[1:]

        # Decompose the local index; the first spatial axis varies fastest.
        grid_idx = []
        for i in range(len(grid_count) - 1, -1, -1):
            stride = np.prod(grid_count[:i]) if i > 0 else 1
            grid_idx.insert(0, patch_idx // stride)
            patch_idx %= stride
        # TODO check for 3D !
        grid_idx = [channel_idx] + grid_idx
        location = [sample_idx] + [
            self.get_gridstart_location_from_dim_index(
                shape=sample_shape, dim_idx=dim_idx, dim=grid_idx[dim_idx]
            )
            for dim_idx in range(len(grid_idx))
        ]
        return tuple(location)

    def get_location_from_patch_idx_o(self, dataset_idx: int):
        """
        Returns the start location of the grid in the dataset.
        """
        # NOTE(review): legacy single-array implementation. ``self.data_shape`` and
        # ``self.grid_count`` do not exist on this class, so calling this raises
        # AttributeError.
        grid_idx = []
        for dim in range(len(self.data_shape)):
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
        """
        Returns True if the grid is on the boundary in the specified dimension.
        """
        # NOTE(review): legacy single-array implementation. It relies on
        # ``self.grid_count`` (undefined here) and on the old one-argument
        # ``get_individual_dim_grid_count``.
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

        if dim > 0:
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return (
            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
        )

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the index of the grid in the specified dimension in the specified direction.
        """
        # NOTE(review): legacy single-array implementation. ``self.grid_count`` is
        # undefined here, and ``total_grid_count()`` now returns a tuple.
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the index of the grid in the specified dimension in the specified direction.
        """
        # NOTE(review): legacy single-array implementation. ``self.grid_count`` is
        # undefined here.
        assert dim < len(
            self.data_shapes
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
|
|
@@ -14,10 +14,11 @@ import matplotlib.pyplot as plt
|
|
|
14
14
|
import numpy as np
|
|
15
15
|
import torch
|
|
16
16
|
from matplotlib.gridspec import GridSpec
|
|
17
|
-
from torch.utils.data import DataLoader, Dataset
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset
|
|
18
18
|
from tqdm import tqdm
|
|
19
19
|
|
|
20
20
|
from careamics.lightning import VAEModule
|
|
21
|
+
from careamics.lvae_training.dataset import MultiChDloaderRef
|
|
21
22
|
from careamics.utils.metrics import scale_invariant_psnr
|
|
22
23
|
|
|
23
24
|
|
|
@@ -542,7 +543,9 @@ def get_predictions(
|
|
|
542
543
|
mmse_count=mmse_count,
|
|
543
544
|
num_workers=num_workers,
|
|
544
545
|
)
|
|
546
|
+
# TODO stitching still not working properly for weirdly shaped images
|
|
545
547
|
# get filename without extension and path
|
|
548
|
+
# TODO in the ref ds this is the name of a folder not file :(
|
|
546
549
|
filename = dset._fpath.name
|
|
547
550
|
return (
|
|
548
551
|
{filename: stitched_predictions},
|
|
@@ -656,8 +659,14 @@ def get_single_file_mmse(
|
|
|
656
659
|
|
|
657
660
|
tiles_arr = np.concatenate(tile_mmse, axis=0)
|
|
658
661
|
tile_stds = np.concatenate(tile_stds, axis=0)
|
|
659
|
-
|
|
660
|
-
|
|
662
|
+
# TODO temporary hack, because of the stupid jupyter!
|
|
663
|
+
# If a user reruns a cell with class definition, isinstance will return False
|
|
664
|
+
if str(MultiChDloaderRef).split(".")[-1] == str(dset.__class__).split(".")[-1]:
|
|
665
|
+
stitch_func = stitch_predictions_general
|
|
666
|
+
else:
|
|
667
|
+
stitch_func = stitch_predictions_new
|
|
668
|
+
stitched_predictions = stitch_func(tiles_arr, dset)
|
|
669
|
+
stitched_stds = stitch_func(tile_stds, dset)
|
|
661
670
|
return stitched_predictions, stitched_stds
|
|
662
671
|
|
|
663
672
|
|
|
@@ -873,3 +882,84 @@ def stitch_predictions_new(predictions, dset):
|
|
|
873
882
|
raise ValueError(f"Unsupported shape {output.shape}")
|
|
874
883
|
|
|
875
884
|
return output
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def stitch_predictions_general(predictions, dset):
    """Stitching for the dataset with multiple files of different shape.

    Tile placement comes from ``dset.idx_manager.get_location_from_patch_idx`` for
    channel 0, because prediction uses a single input channel.

    Parameters
    ----------
    predictions : numpy.ndarray
        Tiled predictions, indexed as ``predictions[patch_idx][ch_idx, *spatial]``.
        ``predictions.shape[1]`` is the number of predicted channels.
    dset
        Dataset exposing ``idx_manager`` and ``get_data_shapes()``.

    Returns
    -------
    list of numpy.ndarray
        One stitched array per input file (channel 0's file list), shaped
        ``(predictions.shape[1], *file_shape[1:])``.

    Raises
    ------
    ValueError
        If a stitched array is neither 3D nor 4D.
    """
    mng = dset.idx_manager

    # TODO assert all shapes are equal len
    # Output arrays keep each input file's shape but take the channel count from
    # the predictions.
    shapes = [
        (predictions.shape[1],) + tuple(shape[1:]) for shape in dset.get_data_shapes()[0]
    ]
    output = [np.zeros(shape, dtype=predictions.dtype) for shape in shapes]
    patch_offset = mng.patch_offset()

    for patch_idx in range(predictions.shape[0]):
        # Entry 0 is the file index in ``output``; the rest are grid-start
        # coordinates relative to that file (the first of them is its channel axis).
        grid_coords = np.array(
            mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
            dtype=int,
        )
        sample_idx = grid_coords[0]
        sample_shape = shapes[sample_idx]
        grid_start = grid_coords[1:]
        grid_end = grid_start + mng.grid_shape

        patch_start = grid_start - patch_offset
        patch_end = patch_start + mng.patch_shape

        # Clip the grid cell to the file's extent.
        valid_grid_start = np.array([max(0, x) for x in grid_start], dtype=int)
        valid_grid_end = np.array(
            [min(x, y) for x, y in zip(grid_end, sample_shape)], dtype=int
        )

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # Patches touching the border own the full margin up to that border.
            for dim in range(len(valid_grid_start)):
                if patch_start[dim] == 0:
                    valid_grid_start[dim] = 0
                if patch_end[dim] == sample_shape[dim]:
                    valid_grid_end[dim] = sample_shape[dim]

        # Region of the predicted tile that corresponds to the valid grid cell.
        relative_start = valid_grid_start - patch_start
        relative_end = relative_start + (valid_grid_end - valid_grid_start)

        target = output[sample_idx]
        if target.ndim not in (3, 4):
            raise ValueError(f"Unsupported shape {target.shape}")

        # Axis 0 of the coordinates is the input channel axis; the output channel
        # axis comes from the model, so it is taken whole and only the spatial axes
        # (index 1 onward) are sliced.
        out_region = tuple(
            slice(s, e) for s, e in zip(valid_grid_start[1:], valid_grid_end[1:])
        )
        pred_region = tuple(
            slice(s, e) for s, e in zip(relative_start[1:], relative_end[1:])
        )
        target[(slice(None),) + out_region] = predictions[patch_idx][
            (slice(None),) + pred_region
        ]

    return output
|
careamics/transforms/compose.py
CHANGED
|
@@ -86,6 +86,7 @@ class Compose:
|
|
|
86
86
|
*params, _ = t(*params) # ignore additional_arrays dict
|
|
87
87
|
|
|
88
88
|
# avoid None values that create problems for collating
|
|
89
|
+
# TODO: removing None should be handled in dataset, not here
|
|
89
90
|
return tuple(p for p in params if p is not None)
|
|
90
91
|
|
|
91
92
|
def _chain_transforms_additional_arrays(
|