careamics 0.0.3__py3-none-any.whl → 0.0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of careamics has been flagged by the registry as a potentially problematic release.
- careamics/careamist.py +25 -17
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/architectures/lvae_model.py +0 -4
- careamics/config/configuration_factory.py +480 -177
- careamics/config/configuration_model.py +1 -2
- careamics/config/data_model.py +1 -15
- careamics/config/fcn_algorithm_model.py +14 -9
- careamics/config/likelihood_model.py +21 -4
- careamics/config/nm_model.py +31 -5
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +2 -36
- careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
- careamics/lightning/lightning_module.py +10 -8
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/loss_factory.py +3 -3
- careamics/losses/lvae/losses.py +2 -2
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
- careamics/lvae_training/dataset/lc_dataset.py +28 -20
- careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +4 -2
- careamics/model_io/bmz_io.py +6 -5
- careamics/models/lvae/likelihoods.py +18 -9
- careamics/models/lvae/lvae.py +12 -16
- careamics/models/lvae/noise_models.py +1 -1
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +204 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/METADATA +4 -3
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/RECORD +54 -43
- careamics-0.0.4.1.dist-info/entry_points.txt +2 -0
- careamics/lvae_training/dataset/data_utils.py +0 -701
- careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/WHEEL +0 -0
- {careamics-0.0.3.dist-info → careamics-0.0.4.1.dist-info}/licenses/LICENSE +0 -0

careamics/lvae_training/dataset/utils/index_manager.py (new file)
@@ -0,0 +1,232 @@
+from dataclasses import dataclass
+
+import numpy as np
+
+from careamics.lvae_training.dataset.types import TilingMode
+
+
+@dataclass
+class GridIndexManager:
+    data_shape: tuple
+    grid_shape: tuple
+    patch_shape: tuple
+    tiling_mode: TilingMode
+
+    # Patch is centered on index in the grid, grid size not used in training,
+    # used only during val / test, grid size controls the overlap of the patches
+    # in training you only get random patches every time
+    # For borders - just cropped the data, so it perfectly divisible
+
+    def __post_init__(self):
+        assert len(self.data_shape) == len(
+            self.grid_shape
+        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
+        assert len(self.data_shape) == len(
+            self.patch_shape
+        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+        for dim, pad in enumerate(innerpad):
+            if pad < 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
+                )
+            if pad % 2 != 0:
+                raise ValueError(
+                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
+                )
+
+    def patch_offset(self):
+        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+    def get_individual_dim_grid_count(self, dim: int):
+        """
+        Returns the number of the grid in the specified dimension, ignoring all other dimensions.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return self.data_shape[dim]
+        elif self.tiling_mode == TilingMode.PadBoundary:
+            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+        elif self.tiling_mode == TilingMode.ShiftBoundary:
+            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+            return int(
+                np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+            )
+        else:
+            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+            return int(
+                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
+            )
+
+    def total_grid_count(self):
+        """
+        Returns the total number of grids in the dataset.
+        """
+        return self.grid_count(0) * self.get_individual_dim_grid_count(0)
+
+    def grid_count(self, dim: int):
+        """
+        Returns the total number of grids for one value in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        if dim == len(self.data_shape) - 1:
+            return 1
+
+        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
+
+    def get_grid_index(self, dim: int, coordinate: int):
+        """
+        Returns the index of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert (
+            coordinate < self.data_shape[dim]
+        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return coordinate
+        elif self.tiling_mode == TilingMode.PadBoundary:  # self.trim_boundary is False:
+            return np.floor(coordinate / self.grid_shape[dim])
+        elif self.tiling_mode == TilingMode.TrimBoundary:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            # can be <0 if coordinate is in [0,grid_shape[dim]]
+            return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
+        elif self.tiling_mode == TilingMode.ShiftBoundary:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
+                return self.get_individual_dim_grid_count(dim) - 1
+            else:
+                # can be <0 if coordinate is in [0,grid_shape[dim]]
+                return max(
+                    0, np.floor((coordinate - excess_size) / self.grid_shape[dim])
+                )
+
+        else:
+            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+        """
+        Returns the index of the grid in the dataset.
+        """
+        assert len(grid_idx) == len(
+            self.data_shape
+        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+        index = 0
+        for dim in range(len(grid_idx)):
+            index += grid_idx[dim] * self.grid_count(dim)
+        return index
+
+    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+        """
+        Returns the patch location of the grid in the dataset.
+        """
+        grid_location = self.get_location_from_dataset_idx(dataset_idx)
+        offset = self.patch_offset()
+        return tuple(np.array(grid_location) - np.array(offset))
+
+    def get_dataset_idx_from_grid_location(self, location: tuple):
+        assert len(location) == len(
+            self.data_shape
+        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+        grid_idx = [
+            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
+        ]
+        return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+        """
+        Returns the grid-start coordinate of the grid in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        assert dim_index < self.get_individual_dim_grid_count(
+            dim
+        ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+            return dim_index
+        elif self.tiling_mode == TilingMode.PadBoundary:
+            return dim_index * self.grid_shape[dim]
+        elif self.tiling_mode == TilingMode.TrimBoundary:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            return dim_index * self.grid_shape[dim] + excess_size
+        elif self.tiling_mode == TilingMode.ShiftBoundary:
+            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
+                return dim_index * self.grid_shape[dim] + excess_size
+            else:
+                # on boundary. grid should be placed such that the patch covers the entire data.
+                return self.data_shape[dim] - self.grid_shape[dim] - excess_size
+        else:
+            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+    def get_location_from_dataset_idx(self, dataset_idx: int):
+        """
+        Returns the start location of the grid in the dataset.
+        """
+        grid_idx = []
+        for dim in range(len(self.data_shape)):
+            grid_idx.append(dataset_idx // self.grid_count(dim))
+            dataset_idx = dataset_idx % self.grid_count(dim)
+        location = [
+            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
+            for dim in range(len(self.data_shape))
+        ]
+        return tuple(location)
+
+    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
+        """
+        Returns True if the grid is on the boundary in the specified dimension.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+        if dim > 0:
+            dataset_idx = dataset_idx % self.grid_count(dim - 1)
+
+        dim_index = dataset_idx // self.grid_count(dim)
+        if only_end:
+            return dim_index == self.get_individual_dim_grid_count(dim) - 1
+
+        return (
+            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+        )
+
+    def next_grid_along_dim(self, dataset_idx: int, dim: int):
+        """
+        Returns the index of the grid in the specified dimension in the specified direction.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx + self.grid_count(dim)
+        if new_idx >= self.total_grid_count():
+            return None
+        return new_idx
+
+    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
+        """
+        Returns the index of the grid in the specified dimension in the specified direction.
+        """
+        assert dim < len(
+            self.data_shape
+        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+        assert dim >= 0, "Dimension must be greater than or equal to 0"
+        new_idx = dataset_idx - self.grid_count(dim)
+        if new_idx < 0:
+            return None
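
For orientation, here is a minimal usage sketch of the new GridIndexManager. The import paths follow the file layout introduced in this diff, but the shapes and tiling mode are illustrative values chosen for the example, not taken from the package.

from careamics.lvae_training.dataset.types import TilingMode
from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager

# Illustrative shapes: one 256x256 frame with 2 channels, a 32-pixel grid and
# 64-pixel patches (patch - grid = 32 is an even inner padding, as required
# by __post_init__).
manager = GridIndexManager(
    data_shape=(1, 256, 256, 2),
    grid_shape=(1, 32, 32, 2),
    patch_shape=(1, 64, 64, 2),
    tiling_mode=TilingMode.ShiftBoundary,
)

print(manager.total_grid_count())               # number of tiles covering the data
loc = manager.get_location_from_dataset_idx(0)  # grid-start coordinates of tile 0
patch_loc = manager.get_patch_location_from_dataset_idx(0)  # shifted by patch_offset()
print(loc, patch_loc, manager.on_boundary(0, dim=1))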

careamics/lvae_training/dataset/utils/index_switcher.py (new file)
@@ -0,0 +1,165 @@
+import numpy as np
+
+
+class IndexSwitcher:
+    """
+    The idea is to switch from valid indices for target to invalid indices for target.
+    If index in invalid for the target, then we return all zero vector as target.
+    This combines both logic:
+    1. Using less amount of total data.
+    2. Using less amount of target data but using full data.
+    """
+
+    def __init__(self, idx_manager, data_config, patch_size) -> None:
+        self.idx_manager = idx_manager
+        self._data_shape = self.idx_manager.get_data_shape()
+        self._training_validtarget_fraction = data_config.get(
+            "training_validtarget_fraction", 1.0
+        )
+        self._validtarget_ceilT = int(
+            np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
+        )
+        self._patch_size = patch_size
+        assert (
+            data_config.deterministic_grid is True
+        ), "This only works when the dataset has deterministic grid. Needed randomness comes from this class."
+        assert (
+            "grid_size" in data_config and data_config.grid_size == 1
+        ), "We need a one to one mapping between index and h, w, t"
+
+        self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
+            self._data_shape[:3], self._training_validtarget_fraction
+        )
+        if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
+            print(
+                "WARNING: The valid target size is smaller than the patch size. This will result in all zero target. so, we are ignoring this frame for target."
+            )
+            self._h_validmax = 0
+            self._w_validmax = 0
+
+        print(
+            f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
+        )
+
+    def get_valid_target_index(self):
+        """
+        Returns an index which corresponds to a frame which is expected to have a target.
+        """
+        _, h, w, _ = self._data_shape
+        framepixelcount = h * w
+        targetpixels = np.array(
+            [framepixelcount] * (self._validtarget_ceilT - 1)
+            + [self._h_validmax * self._w_validmax]
+        )
+        targetpixels = targetpixels / np.sum(targetpixels)
+        t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
+        # t = np.random.randint(0, self._validtarget_ceilT) if self._validtarget_ceilT >= 1 else 0
+        h, w = self.get_valid_target_hw(t)
+        index = self.idx_manager.idx_from_hwt(h, w, t)
+        # print('Valid', index, h,w,t)
+        return index
+
+    def get_invalid_target_index(self):
+        # if self._validtarget_ceilT == 0:
+        # TODO: There may not be enough data for this to work. The better way is to skip using 0 for invalid target.
+        # t = np.random.randint(1, self._data_shape[0])
+        # elif self._validtarget_ceilT < self._data_shape[0]:
+        # t = np.random.randint(self._validtarget_ceilT, self._data_shape[0])
+        # else:
+        # t = self._validtarget_ceilT - 1
+        # 5
+        # 1.2 => 2
+        total_t, h, w, _ = self._data_shape
+        framepixelcount = h * w
+        available_h = h - self._h_validmax
+        if available_h < self._patch_size:
+            available_h = 0
+        available_w = w - self._w_validmax
+        if available_w < self._patch_size:
+            available_w = 0
+
+        targetpixels = np.array(
+            [available_h * available_w]
+            + [framepixelcount] * (total_t - self._validtarget_ceilT)
+        )
+        t_probab = targetpixels / np.sum(targetpixels)
+        t = np.random.choice(
+            np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
+        )
+
+        h, w = self.get_invalid_target_hw(t)
+        index = self.idx_manager.idx_from_hwt(h, w, t)
+        # print('Invalid', index, h,w,t)
+        return index
+
+    def get_valid_target_hw(self, t):
+        """
+        This is the opposite of get_invalid_target_hw. It returns a h,w which is valid for target.
+        This is only valid for single frame setup.
+        """
+        if t == self._validtarget_ceilT - 1:
+            h = np.random.randint(0, self._h_validmax - self._patch_size)
+            w = np.random.randint(0, self._w_validmax - self._patch_size)
+        else:
+            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+        return h, w
+
+    def get_invalid_target_hw(self, t):
+        """
+        This is the opposite of get_valid_target_hw. It returns a h,w which is not valid for target.
+        This is only valid for single frame setup.
+        """
+        if t == self._validtarget_ceilT - 1:
+            h = np.random.randint(
+                self._h_validmax, self._data_shape[1] - self._patch_size
+            )
+            w = np.random.randint(
+                self._w_validmax, self._data_shape[2] - self._patch_size
+            )
+        else:
+            h = np.random.randint(0, self._data_shape[1] - self._patch_size)
+            w = np.random.randint(0, self._data_shape[2] - self._patch_size)
+        return h, w
+
+    def _get_tidx(self, index):
+        if isinstance(index, int) or isinstance(index, np.int64):
+            idx = index
+        else:
+            idx = index[0]
+        return self.idx_manager.get_t(idx)
+
+    def index_should_have_target(self, index):
+        tidx = self._get_tidx(index)
+        if tidx < self._validtarget_ceilT - 1:
+            return True
+        elif tidx > self._validtarget_ceilT - 1:
+            return False
+        else:
+            h, w, _ = self.idx_manager.hwt_from_idx(index)
+            return (
+                h + self._patch_size < self._h_validmax
+                and w + self._patch_size < self._w_validmax
+            )
+
+    @staticmethod
+    def get_reduced_frame_size(data_shape_nhw, fraction):
+        n, h, w = data_shape_nhw
+
+        framepixelcount = h * w
+        targetpixelcount = int(n * framepixelcount * fraction)
+
+        # We are currently supporting this only when there is just one frame.
+        # if np.ceil(pixelcount / framepixelcount) > 1:
+        #     return None, None
+
+        lastframepixelcount = targetpixelcount % framepixelcount
+        assert data_shape_nhw[1] == data_shape_nhw[2]
+        if lastframepixelcount > 0:
+            new_size = int(np.sqrt(lastframepixelcount))
+            return new_size, new_size
+        else:
+            assert (
+                targetpixelcount / framepixelcount >= 1
+            ), "This is not possible in euclidean space :D (so this is a bug)"
+            return h, w
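
A rough sketch of how IndexSwitcher could be exercised in isolation. The stub index manager and config below are hypothetical stand-ins (the real dataset supplies objects exposing get_data_shape, idx_from_hwt, hwt_from_idx, get_t, and a config that supports both attribute access and get/in); the fraction and patch size are illustrative.

import numpy as np

from careamics.lvae_training.dataset.utils.index_switcher import IndexSwitcher


class StubIndexManager:
    """Hypothetical 1:1 (h, w, t) <-> flat-index mapping for a (T, H, W, C) stack."""

    def __init__(self, data_shape):
        self._shape = data_shape  # (T, H, W, C)

    def get_data_shape(self):
        return self._shape

    def idx_from_hwt(self, h, w, t):
        _, H, W, _ = self._shape
        return t * H * W + h * W + w

    def hwt_from_idx(self, idx):
        _, H, W, _ = self._shape
        t, rem = divmod(int(idx), H * W)
        h, w = divmod(rem, W)
        return h, w, t

    def get_t(self, idx):
        return self.hwt_from_idx(idx)[2]


class StubConfig(dict):
    """Hypothetical config allowing both attribute and dict-style access."""

    __getattr__ = dict.__getitem__


config = StubConfig(
    training_validtarget_fraction=0.4, deterministic_grid=True, grid_size=1
)
switcher = IndexSwitcher(StubIndexManager((1, 256, 256, 2)), config, patch_size=64)
idx = switcher.get_valid_target_index()
print(idx, switcher.index_should_have_target(idx))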

careamics/lvae_training/eval_utils.py
@@ -14,13 +14,19 @@ import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+from torch import nn
+from torch.utils.data import Dataset
 from matplotlib.gridspec import GridSpec
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from careamics.lightning import VAEModule
+from careamics.losses.lvae.losses import (
+    get_reconstruction_loss,
+    reconstruction_loss_musplit_denoisplit,
+)
 from careamics.models.lvae.utils import ModelType
-
-from .metrics import RangeInvariantPsnr, RunningPSNR
+from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
 
 
 # ------------------------------------------------------------------------------------------------

@@ -40,28 +46,26 @@ def clean_ax(ax):
     ax.tick_params(left=False, right=False, top=False, bottom=False)
 
 
-def
+def get_eval_output_dir(
     saveplotsdir: str, patch_size: int, mmse_count: int = 50
 ) -> str:
     """
     Given the path to a root directory to save plots, patch size, and mmse count,
     it returns the specific directory to save the plots.
     """
-
-    saveplotsdir, f"
+    eval_out_dir = os.path.join(
+        saveplotsdir, f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
     )
-    os.makedirs(
-    print(
-    return
+    os.makedirs(eval_out_dir, exist_ok=True)
+    print(eval_out_dir)
+    return eval_out_dir
 
 
 def get_psnr_str(tar_hsnr, pred, col_idx):
     """
     Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
     """
-    return (
-        f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
-    )
+    return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
 
 
 def add_psnr_str(ax_, psnr):

@@ -499,20 +503,40 @@ def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
 
 
 def get_dset_predictions(
-    model,
-    dset,
+    model: VAEModule,
+    dset: Dataset,
     batch_size: int,
-
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
     mmse_count: int = 1,
     num_workers: int = 4,
-):
-    """
-    Get predictions from a model for the entire dataset.
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
+    """Get patch-wise predictions from a model for the entire dataset.
 
     Parameters
     ----------
-
-
+    model : VAEModule
+        Lightning model used for prediction.
+    dset : Dataset
+        Dataset to predict on.
+    batch_size : int
+        Batch size to use for prediction.
+    loss_type :
+        Type of reconstruction loss used by the model, by default `None`.
+    mmse_count : int, optional
+        Number of samples to generate for each input and then to average over for
+        MMSE estimation, by default 1.
+    num_workers : int, optional
+        Number of workers to use for DataLoader, by default 4.
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]
+        Tuple containing:
+            - predictions: Predicted images for the dataset.
+            - predictions_std: Standard deviation of the predicted images.
+            - logvar_arr: Log variance of the predicted images.
+            - losses: Reconstruction losses for the predictions.
+            - psnr: PSNR values for the predictions.
     """
     dloader = DataLoader(
         dset,

@@ -521,69 +545,90 @@ def get_dset_predictions(
         shuffle=False,
         batch_size=batch_size,
     )
-
+
+    gauss_likelihood = model.gaussian_likelihood
+    nm_likelihood = model.noise_model_likelihood
+
     predictions = []
     predictions_std = []
     losses = []
     logvar_arr = []
-
+    num_channels = dset[0][1].shape[0]
+    patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
     with torch.no_grad():
-        for batch in tqdm(dloader):
-            inp, tar = batch
+        for batch in tqdm(dloader, desc="Predicting patches"):
+            inp, tar = batch
            inp = inp.cuda()
            tar = tar.cuda()
 
-
+            rec_img_list = []
             for mmse_idx in range(mmse_count):
-
-
-
-
-
-
-
-
-
+
+                # TODO: case of HDN left for future refactoring
+                # if model_type == ModelType.Denoiser:
+                #     assert model.denoise_channel in [
+                #         "Ch1",
+                #         "Ch2",
+                #         "input",
+                #     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
+
+                # x_normalized_new, tar_new = model.get_new_input_target(
+                #     (inp, tar, *batch[2:])
+                # )
+                # rec, _ = model(x_normalized_new)
+                # rec_loss, imgs = model.get_reconstruction_loss(
+                #     rec,
+                #     tar,
+                #     x_normalized_new,
+                #     return_predicted_img=True,
+                # )
+
+                # get model output
+                rec, _ = model(inp)
+
+                # get reconstructed img
+                if model.model.predict_logvar is None:
+                    rec_img = rec
+                    logvar = torch.tensor([-1])
+                else:
+                    rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+                logvar_arr.append(logvar.cpu().numpy())
+
+                # compute reconstruction loss
+                if loss_type == "musplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
                     )
-
-
-
-                        recon_normalized,
-                        tar_normalized,
-                        x_normalized_new,
-                        return_predicted_img=True,
+                elif loss_type == "denoisplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
                     )
-
-
-
-
-
-
+                elif loss_type == "denoisplit_musplit":
+                    rec_loss = reconstruction_loss_musplit_denoisplit(
+                        predictions=rec,
+                        targets=tar,
+                        gaussian_likelihood=gauss_likelihood,
+                        nm_likelihood=nm_likelihood,
+                        nm_weight=model.loss_parameters.denoisplit_weight,
+                        gaussian_weight=model.loss_parameters.musplit_weight,
                     )
+                    rec_loss = {"loss": rec_loss}  # hacky, but ok for now
 
+                # store rec loss values for first pred
                 if mmse_idx == 0:
-                    q_dic = (
-                        likelihood.distr_params(recon_normalized)
-                        if likelihood is not None
-                        else {"logvar": None}
-                    )
-                    if q_dic["logvar"] is not None:
-                        logvar_arr.append(q_dic["logvar"].cpu().numpy())
-                    else:
-                        logvar_arr.append(np.array([-1]))
-
                     try:
                         losses.append(rec_loss["loss"].cpu().numpy())
                     except:
                         losses.append(rec_loss["loss"])
 
-
-
+                # update running PSNR
+                for i in range(num_channels):
+                    patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
 
-
-
-
-            mmse_imgs = torch.mean(samples, dim=0)
+            # aggregate results
+            samples = torch.cat(rec_img_list, dim=0)
+            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
             mmse_std = torch.std(samples, dim=0)
             predictions.append(mmse_imgs.cpu().numpy())
             predictions_std.append(mmse_std.cpu().numpy())

@@ -591,10 +636,10 @@ def get_dset_predictions(
     psnr = [x.get() for x in patch_psnr_channels]
     return (
         np.concatenate(predictions, axis=0),
-        np.
+        np.concatenate(predictions_std, axis=0),
         np.concatenate(logvar_arr),
+        np.array(losses),
         psnr,
-        np.concatenate(predictions_std, axis=0),
     )
 
 
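
Because the signature and return order of get_dset_predictions changed in this release (the standard deviations now come right after the predictions, and the losses are returned as an array), downstream code would unpack it roughly as below. This is only a call sketch: vae_module and val_dset are assumed to be an already-trained VAEModule and a patch dataset on a CUDA machine.

from careamics.lvae_training.eval_utils import get_dset_predictions

# Assumed to exist already: `vae_module` (a trained VAEModule) and `val_dset`
# (a dataset yielding (input, target) patch pairs).
preds, preds_std, logvars, losses, psnr = get_dset_predictions(
    model=vae_module,
    dset=val_dset,
    batch_size=8,
    loss_type="denoisplit_musplit",  # or "musplit" / "denoisplit"
    mmse_count=10,  # samples averaged per input for the MMSE estimate
    num_workers=4,
)
print(preds.shape, preds_std.shape, losses.mean(), psnr)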

@@ -24,7 +24,7 @@ from careamics.lvae_training.dataset.data_modules import (
     LCMultiChDloader,
     MultiChDloader,
 )
-from careamics.lvae_training.dataset.data_utils import DataSplitType
+from careamics.lvae_training.dataset.utils.data_utils import DataSplitType
 from careamics.lvae_training.lightning_module import LadderVAELight
 from careamics.lvae_training.train_utils import *
 