careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/utils.py
@@ -0,0 +1,395 @@
+"""
+Script for utility functions needed by the LVAE model.
+"""
+
+from typing import Iterable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from torch.distributions.normal import Normal
+
+
+def torch_nanmean(inp):
+    return torch.mean(inp[~inp.isnan()])
+
+
+def compute_batch_mean(x):
+    N = len(x)
+    return x.view(N, -1).mean(dim=1)
+
+
+def power_of_2(self, x):
+    assert isinstance(x, int)
+    if x == 1:
+        return True
+    if x == 0:
+        # happens with validation
+        return False
+    if x % 2 == 1:
+        return False
+    return self.power_of_2(x // 2)
+
+
+class Enum:
+    @classmethod
+    def name(cls, enum_type):
+        for key, value in cls.__dict__.items():
+            if enum_type == value:
+                return key
+
+    @classmethod
+    def contains(cls, enum_type):
+        for key, value in cls.__dict__.items():
+            if enum_type == value:
+                return True
+        return False
+
+    @classmethod
+    def from_name(cls, enum_type_str):
+        for key, value in cls.__dict__.items():
+            if key == enum_type_str:
+                return value
+        assert f"{cls.__name__}:{enum_type_str} doesnot exist."
+
+
+class LossType(Enum):
+    Elbo = 0
+    ElboWithCritic = 1
+    ElboMixedReconstruction = 2
+    MSE = 3
+    ElboWithNbrConsistency = 4
+    ElboSemiSupMixedReconstruction = 5
+    ElboCL = 6
+    ElboRestrictedReconstruction = 7
+    DenoiSplitMuSplit = 8
+
+
+class ModelType(Enum):
+    LadderVae = 3
+    LadderVaeTwinDecoder = 4
+    LadderVAECritic = 5
+    # Separate vampprior: two optimizers
+    LadderVaeSepVampprior = 6
+    # one encoder for mixed input, two for separate inputs.
+    LadderVaeSepEncoder = 7
+    LadderVAEMultiTarget = 8
+    LadderVaeSepEncoderSingleOptim = 9
+    UNet = 10
+    BraveNet = 11
+    LadderVaeStitch = 12
+    LadderVaeSemiSupervised = 13
+    LadderVaeStitch2Stage = 14  # Note that previously trained models will have issue.
+    # since earlier, LadderVaeStitch2Stage = 13, LadderVaeSemiSupervised = 14
+    LadderVaeMixedRecons = 15
+    LadderVaeCL = 16
+    LadderVaeTwoDataSet = (
+        17  # on one subdset, apply disentanglement, on other apply reconstruction
+    )
+    LadderVaeTwoDatasetMultiBranch = 18
+    LadderVaeTwoDatasetMultiOptim = 19
+    LVaeDeepEncoderIntensityAug = 20
+    AutoRegresiveLadderVAE = 21
+    LadderVAEInterleavedOptimization = 22
+    Denoiser = 23
+    DenoiserSplitter = 24
+    SplitterDenoiser = 25
+    LadderVAERestrictedReconstruction = 26
+    LadderVAETwoDataSetRestRecon = 27
+    LadderVAETwoDataSetFinetuning = 28
+
+
+def _pad_crop_img(x, size, mode) -> torch.Tensor:
+    """Pads or crops a tensor.
+    Pads or crops a tensor of shape (batch, channels, h, w) to new height
+    and width given by a tuple.
+    Args:
+        x (torch.Tensor): Input image
+        size (list or tuple): Desired size (height, width)
+        mode (str): Mode, either 'pad' or 'crop'
+    Returns:
+        The padded or cropped tensor
+    """
+    assert x.dim() == 4 and len(size) == 2
+    size = tuple(size)
+    x_size = x.size()[2:4]
+    if mode == "pad":
+        cond = x_size[0] > size[0] or x_size[1] > size[1]
+    elif mode == "crop":
+        cond = x_size[0] < size[0] or x_size[1] < size[1]
+    else:
+        raise ValueError(f"invalid mode '{mode}'")
+    if cond:
+        raise ValueError(f"trying to {mode} from size {x_size} to size {size}")
+    dr, dc = (abs(x_size[0] - size[0]), abs(x_size[1] - size[1]))
+    dr1, dr2 = dr // 2, dr - (dr // 2)
+    dc1, dc2 = dc // 2, dc - (dc // 2)
+    if mode == "pad":
+        return nn.functional.pad(x, [dc1, dc2, dr1, dr2, 0, 0, 0, 0])
+    elif mode == "crop":
+        return x[:, :, dr1 : x_size[0] - dr2, dc1 : x_size[1] - dc2]
+
+
+def pad_img_tensor(x, size) -> torch.Tensor:
+    """Pads a tensor.
+    Pads a tensor of shape (batch, channels, h, w) to a desired height and width.
+    Args:
+        x (torch.Tensor): Input image
+        size (list or tuple): Desired size (height, width)
+
+    Returns
+    -------
+    The padded tensor
+    """
+    return _pad_crop_img(x, size, "pad")
+
+
+def crop_img_tensor(x, size) -> torch.Tensor:
+    """Crops a tensor.
+    Crops a tensor of shape (batch, channels, h, w) to a desired height and width
+    given by a tuple.
+    Args:
+        x (torch.Tensor): Input image
+        size (list or tuple): Desired size (height, width)
+
+    Returns
+    -------
+    The cropped tensor
+    """
+    return _pad_crop_img(x, size, "crop")
+
+
+class StableExponential:
+    """
+    Class that redefines the definition of exp() to increase numerical stability.
+    Naturally, also the definition of log() must change accordingly.
+    However, it is worth noting that the two operations remain one the inverse of the other,
+    meaning that x = log(exp(x)) and x = exp(log(x)) are always true.
+
+    Definition:
+        exp(x) = {
+            exp(x) if x<=0
+            x+1    if x>0
+        }
+
+        log(x) = {
+            x        if x<=0
+            log(1+x) if x>0
+        }
+
+    NOTE 1:
+        Within the class everything is done on the tensor given as input to the constructor.
+        Therefore, when exp() is called, self._tensor.exp() is computed.
+        When log() is called, torch.log(self._tensor.exp()) is computed instead.
+
+    NOTE 2:
+        Given the output from exp(), torch.log() or the log() method of the class give identical results.
+    """
+
+    def __init__(self, tensor):
+        self._raw_tensor = tensor
+        posneg_dic = self.posneg_separation(self._raw_tensor)
+        self.pos_f, self.neg_f = posneg_dic["filter"]
+        self.pos_data, self.neg_data = posneg_dic["value"]
+
+    def posneg_separation(self, tensor):
+        pos = tensor > 0
+        pos_tensor = torch.clip(tensor, min=0)
+
+        neg = tensor <= 0
+        neg_tensor = torch.clip(tensor, max=0)
+
+        return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]}
+
+    def exp(self):
+        return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f
+
+    def log(self):
+        return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f
+
+
+class StableLogVar:
+    """
+    Class that provides a numerically stable implementation of Log-Variance.
+    Specifically, it uses the exp() and log() formulas defined in `StableExponential` class.
+    """
+
+    def __init__(
+        self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
+    ):
+        """
+        Contructor.
+
+        Parameters
+        ----------
+        logvar: torch.Tensor
+            The input (true) logvar vector, to be converted in the Stable version.
+        enable_stable: bool, optional
+            Whether to compute the stable version of log-variance. Default is `True`.
+        var_eps: float, optional
+            The minimum value attainable by the variance. Default is `1e-6`.
+        """
+        self._lv = logvar
+        self._enable_stable = enable_stable
+        self._eps = var_eps
+
+    def get(self) -> torch.Tensor:
+        if self._enable_stable is False:
+            return self._lv
+
+        return torch.log(self.get_var())
+
+    def get_var(self) -> torch.Tensor:
+        """
+        Get Variance from Log-Variance.
+        """
+        if self._enable_stable is False:
+            return torch.exp(self._lv)
+        return StableExponential(self._lv).exp() + self._eps
+
+    def get_std(self) -> torch.Tensor:
+        return torch.sqrt(self.get_var())
+
+    def centercrop_to_size(self, size: Iterable[int]) -> None:
+        """
+        Centercrop the log-variance tensor to the desired size.
+
+        Parameters
+        ----------
+        size: torch.Tensor
+            The desired size of the log-variance tensor.
+        """
+        if self._lv.shape[-1] == size:
+            return
+
+        diff = self._lv.shape[-1] - size
+        assert diff > 0 and diff % 2 == 0
+        self._lv = F.center_crop(self._lv, (size, size))
+
+
+class StableMean:
+
+    def __init__(self, mean):
+        self._mean = mean
+
+    def get(self) -> torch.Tensor:
+        return self._mean
+
+    def centercrop_to_size(self, size: Iterable[int]) -> None:
+        """
+        Centercrop the mean tensor to the desired size.
+
+        Parameters
+        ----------
+        size: torch.Tensor
+            The desired size of the log-variance tensor.
+        """
+        if self._mean.shape[-1] == size:
+            return
+
+        diff = self._mean.shape[-1] - size
+        assert diff > 0 and diff % 2 == 0
+        self._mean = F.center_crop(self._mean, (size, size))
+
+
+def allow_numpy(func):
+    """
+    All optional arguements are passed as is. positional arguments are checked. if they are numpy array,
+    they are converted to torch Tensor.
+    """
+
+    def numpy_wrapper(*args, **kwargs):
+        new_args = []
+        for arg in args:
+            if isinstance(arg, np.ndarray):
+                arg = torch.Tensor(arg)
+            new_args.append(arg)
+        new_args = tuple(new_args)
+
+        output = func(*new_args, **kwargs)
+        return output
+
+    return numpy_wrapper
+
+
+class Interpolate(nn.Module):
+    """Wrapper for torch.nn.functional.interpolate."""
+
+    def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
+        super().__init__()
+        assert (size is None) == (scale is not None)
+        self.size = size
+        self.scale = scale
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        out = F.interpolate(
+            x,
+            size=self.size,
+            scale_factor=self.scale,
+            mode=self.mode,
+            align_corners=self.align_corners,
+        )
+        return out
+
+
+def kl_normal_mc(z, p_mulv, q_mulv):
+    """
+    One-sample estimation of element-wise KL between two diagonal
+    multivariate normal distributions. Any number of dimensions,
+    broadcasting supported (be careful).
+    :param z:
+    :param p_mulv:
+    :param q_mulv:
+    :return:
+    """
+    assert isinstance(p_mulv, tuple)
+    assert isinstance(q_mulv, tuple)
+    p_mu, p_lv = p_mulv
+    q_mu, q_lv = q_mulv
+
+    p_std = p_lv.get_std()
+    q_std = q_lv.get_std()
+
+    p_distrib = Normal(p_mu.get(), p_std)
+    q_distrib = Normal(q_mu.get(), q_std)
+    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
+
+
+def free_bits_kl(
+    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Computes free-bits version of KL divergence.
+    Ensures that the KL doesn't go to zero for any latent dimension.
+    Hence, it contributes to use latent variables more efficiently,
+    leading to better representation learning.
+
+    NOTE:
+        Takes in the KL with shape (batch size, layers), returns the KL with
+        free bits (for optimization) with shape (layers,), which is the average
+        free-bits KL per layer in the current batch.
+        If batch_average is False (default), the free bits are per layer and
+        per batch element. Otherwise, the free bits are still per layer, but
+        are assigned on average to the whole batch. In both cases, the batch
+        average is returned, so it's simply a matter of doing mean(clamp(KL))
+        or clamp(mean(KL)).
+
+    Args:
+        kl (torch.Tensor)
+        free_bits (float)
+        batch_average (bool, optional))
+        eps (float, optional)
+
+    Returns
+    -------
+    The KL with free bits
+    """
+    assert kl.dim() == 2
+    if free_bits < eps:
+        return kl.mean(0)
+    if batch_average:
+        return kl.mean(0).clamp(min=free_bits)
+    return kl.clamp(min=free_bits).mean(0)
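The `StableExponential` pair above replaces exp(x) with x + 1 for positive inputs, so large log-variances grow linearly instead of overflowing, while log() is adjusted so the two operations stay mutual inverses. A minimal sketch of that property, written against plain `torch` and re-implementing the piecewise formulas rather than importing the new module (the import path `careamics.models.lvae.utils` is only inferred from the file list above):

    import torch

    def stable_exp(x: torch.Tensor) -> torch.Tensor:
        # exp(x) for x <= 0, x + 1 for x > 0; the clamp keeps the
        # unselected branch of torch.where from overflowing on large inputs
        return torch.where(x > 0, x + 1, torch.exp(torch.clamp(x, max=0)))

    def stable_log(x: torch.Tensor) -> torch.Tensor:
        # x for x <= 0, log(1 + x) for x > 0, mirroring StableExponential.log()
        return torch.where(x > 0, torch.log1p(torch.clamp(x, min=0)), x)

    x = torch.tensor([-50.0, -1.0, 0.0, 1.0, 1000.0])
    assert torch.allclose(torch.log(stable_exp(x)), stable_log(x))
    print(stable_exp(x))  # the 1000.0 entry maps to 1001.0 instead of inf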
careamics/prediction_utils/__init__.py
@@ -0,0 +1,10 @@
+"""Package to house various prediction utilies."""
+
+__all__ = [
+    "stitch_prediction",
+    "stitch_prediction_single",
+    "convert_outputs",
+]
+
+from .prediction_outputs import convert_outputs
+from .stitch_prediction import stitch_prediction, stitch_prediction_single
careamics/prediction_utils/prediction_outputs.py
@@ -0,0 +1,137 @@
+"""Module containing functions to convert prediction outputs to desired form."""
+
+from typing import Any, List, Literal, Tuple, Union, overload
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..config.tile_information import TileInformation
+from .stitch_prediction import stitch_prediction
+
+
+def convert_outputs(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], NDArray]:
+    """
+    Convert the Lightning trainer outputs to the desired form.
+
+    This method allows stitching back together tiled predictions.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    list of numpy.ndarray or numpy.ndarray
+        List of arrays with the axes SC(Z)YX. If there is only 1 output it will not
+        be in a list.
+    """
+    if len(predictions) == 0:
+        return predictions
+
+    # this layout is to stop mypy complaining
+    if tiled:
+        predictions_comb = combine_batches(predictions, tiled)
+        predictions_output = stitch_prediction(*predictions_comb)
+    else:
+        predictions_output = combine_batches(predictions, tiled)
+
+    return predictions_output
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[True]
+) -> Tuple[List[NDArray], List[TileInformation]]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[False]
+) -> List[NDArray]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
+
+
+def combine_batches(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
+    """
+    If predictions are in batches, they will be combined.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    if tiled:
+        return _combine_tiled_batches(predictions)
+    else:
+        return _combine_array_batches(predictions)
+
+
+def _combine_tiled_batches(
+    predictions: List[Tuple[NDArray, List[TileInformation]]]
+) -> Tuple[List[NDArray], List[TileInformation]]:
+    """
+    Combine batches from tiled output.
+
+    Parameters
+    ----------
+    predictions : list of (numpy.ndarray, list of TileInformation)
+        Predictions that are output from `Trainer.predict`. For tiled batches, this is
+        a list of tuples. The first element of the tuples is the prediction output of
+        tiles with dimension (B, C, (Z), Y, X), where B is batch size. The second
+        element of the tuples is a list of TileInformation objects of length B.
+
+    Returns
+    -------
+    tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    # turn list of lists into single list
+    tile_infos = [
+        tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+    ]
+    prediction_tiles: List[NDArray] = _combine_array_batches(
+        [preds for preds, _ in predictions]
+    )
+    return prediction_tiles, tile_infos
+
+
+def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
+    """
+    Combine batches of arrays.
+
+    Parameters
+    ----------
+    predictions : list
+        Prediction arrays that are output from `Trainer.predict`. A list of arrays that
+        have dimensions (B, C, (Z), Y, X), where B is batch size.
+
+    Returns
+    -------
+    list of numpy.ndarray
+        A list of arrays with dimensions (1, C, (Z), Y, X).
+    """
+    prediction_concat: NDArray = np.concatenate(predictions, axis=0)
+    prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
+    return prediction_split
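For the untiled path, `combine_batches` boils down to `_combine_array_batches`: concatenate the per-batch arrays along the batch axis, then split back into single-sample arrays. A small numpy sketch of that step (re-implemented inline rather than imported, so it runs without this release installed):

    import numpy as np

    # Two prediction batches as Trainer.predict would return them,
    # shaped (B, C, Y, X) with batch sizes 2 and 1.
    batch_a = np.zeros((2, 1, 8, 8))
    batch_b = np.ones((1, 1, 8, 8))

    # Same logic as _combine_array_batches above.
    concat = np.concatenate([batch_a, batch_b], axis=0)   # (3, 1, 8, 8)
    samples = np.split(concat, concat.shape[0], axis=0)   # three (1, 1, 8, 8) arrays
    print([s.shape for s in samples])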
careamics/prediction_utils/stitch_prediction.py
@@ -0,0 +1,103 @@
+"""Prediction utility functions."""
+
+import builtins
+from typing import List, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from careamics.config.tile_information import TileInformation
+
+
+# TODO: why not allow input and output of torch.tensor ?
+def stitch_prediction(
+    tiles: List[np.ndarray],
+    tile_infos: List[TileInformation],
+) -> List[np.ndarray]:
+    """
+    Stitch tiles back together to form a full image(s).
+
+    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+    singleton dimension.
+
+    Parameters
+    ----------
+    tiles : list of numpy.ndarray
+        Cropped tiles and their respective stitching coordinates. Can contain tiles
+        from multiple images.
+    tile_infos : list of TileInformation
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    list of numpy.ndarray
+        Full image(s).
+    """
+    # Find where to split the lists so that only info from one image is contained.
+    # Do this by locating the last tiles of each image.
+    last_tiles = [tile_info.last_tile for tile_info in tile_infos]
+    last_tile_position = np.where(last_tiles)[0]
+    image_slices = [
+        slice(
+            None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1
+        )
+        for i in range(len(last_tile_position))
+    ]
+    image_predictions = []
+    # slice the lists and apply stitch_prediction_single to each in turn.
+    for image_slice in image_slices:
+        image_predictions.append(
+            stitch_prediction_single(tiles[image_slice], tile_infos[image_slice])
+        )
+    return image_predictions
+
+
+def stitch_prediction_single(
+    tiles: List[NDArray],
+    tile_infos: List[TileInformation],
+) -> NDArray:
+    """
+    Stitch tiles back together to form a full image.
+
+    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+    singleton dimension.
+
+    Parameters
+    ----------
+    tiles : list of numpy.ndarray
+        Cropped tiles and their respective stitching coordinates.
+    tile_infos : list of TileInformation
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    numpy.ndarray
+        Full image, with dimensions SC(Z)YX.
+    """
+    # retrieve whole array size
+    input_shape = tile_infos[0].array_shape
+    predicted_image = np.zeros(input_shape, dtype=np.float32)
+
+    # reshape
+    # TODO: can be more elegantly solved if TileInformation allows singleton dims
+    singleton_dims = tuple(np.where(np.array(tiles[0].shape) == 1)[0])
+    predicted_image = np.expand_dims(predicted_image, singleton_dims)
+
+    for tile, tile_info in zip(tiles, tile_infos):
+
+        # Compute coordinates for cropping predicted tile
+        crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
+            ...,
+            *[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords],
+        )
+
+        # Crop predited tile according to overlap coordinates
+        cropped_tile = tile[crop_slices]
+
+        # Insert cropped tile into predicted image using stitch coordinates
+        image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
+        predicted_image[image_slices] = cropped_tile.astype(np.float32)
+
+    return predicted_image
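The core of `stitch_prediction_single` is its final loop: crop each tile by its `overlap_crop_coords`, then write the crop into the output at its `stitch_coords`. A toy run of that loop, using a stand-in object with only the attributes the function reads (the real `TileInformation` class lives in `careamics.config.tile_information`; this sketch also skips the singleton-dimension handling):

    from types import SimpleNamespace
    import numpy as np

    # One 1x1x4x4 image split into non-overlapping left and right halves.
    tiles = [np.full((1, 1, 4, 2), 1.0), np.full((1, 1, 4, 2), 2.0)]
    infos = [
        SimpleNamespace(array_shape=(1, 1, 4, 4), last_tile=False,
                        overlap_crop_coords=((0, 4), (0, 2)),
                        stitch_coords=((0, 4), (0, 2))),
        SimpleNamespace(array_shape=(1, 1, 4, 4), last_tile=True,
                        overlap_crop_coords=((0, 4), (0, 2)),
                        stitch_coords=((0, 4), (2, 4))),
    ]

    # Same crop-then-insert logic as the loop in the diff above.
    predicted = np.zeros(infos[0].array_shape, dtype=np.float32)
    for tile, info in zip(tiles, infos):
        cropped = tile[(..., *[slice(c[0], c[1]) for c in info.overlap_crop_coords])]
        predicted[(..., *[slice(c[0], c[1]) for c in info.stitch_coords])] = cropped
    print(predicted[0, 0])  # columns 0-1 are 1.0, columns 2-3 are 2.0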
careamics/transforms/n2v_manipulate.py
@@ -60,7 +60,7 @@ class N2VManipulate(Transform):
         remove_center: bool = True,
         struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
         struct_mask_span: int = 5,
-        seed: Optional[int] = None,
+        seed: Optional[int] = None,
     ):
         """Constructor.
 
@@ -127,6 +127,7 @@ class N2VManipulate(Transform):
                 subpatch_size=self.roi_size,
                 remove_center=self.remove_center,
                 struct_params=self.struct_mask,
+                rng=self.rng,
             )
         elif self.strategy == SupportedPixelManipulation.MEDIAN:
             # Iterate over the channels to apply manipulation separately
@@ -136,6 +137,7 @@ class N2VManipulate(Transform):
                 mask_pixel_percentage=self.masked_pixel_percentage,
                 subpatch_size=self.roi_size,
                 struct_params=self.struct_mask,
+                rng=self.rng,
            )
         else:
             raise ValueError(f"Unknown masking strategy ({self.strategy}).")
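The two `rng=self.rng` additions thread the transform's own random generator into the pixel-manipulation calls, which makes N2V masking reproducible when a `seed` is given. A sketch of the pattern with hypothetical names (not the careamics API):

    import numpy as np

    def manipulate(patch: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        # pick masked pixel indices from the caller's generator,
        # instead of the module-level np.random state
        return rng.choice(patch.size, size=4, replace=False)

    patch = np.zeros((8, 8))
    a = manipulate(patch, np.random.default_rng(seed=42))
    b = manipulate(patch, np.random.default_rng(seed=42))
    assert (a == b).all()  # identical masks for identical seeds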