careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +39 -17
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""Script containing the common basic blocks (nn.Module)
|
|
2
|
+
reused by the LadderVAE architecture.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torchvision.transforms.functional as F
|
|
10
|
+
from torch.distributions import kl_divergence
|
|
11
|
+
from torch.distributions.normal import Normal
|
|
12
|
+
|
|
13
|
+
from .utils import (
|
|
14
|
+
StableLogVar,
|
|
15
|
+
StableMean,
|
|
16
|
+
kl_normal_mc,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Type aliases for layer classes that may come in either a 2D or a 3D flavour
# (e.g. `ConvType` annotates the result of `getattr(nn, f"Conv{conv_dims}d")`).
ConvType = Union[nn.Conv2d, nn.Conv3d]
NormType = Union[nn.BatchNorm2d, nn.BatchNorm3d]
DropoutType = Union[nn.Dropout2d, nn.Dropout3d]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NormalStochasticBlock(nn.Module):
    """
    Stochastic block used in the Top-Down inference pass.

    Algorithm:
        - map input parameters to q(z) and (optionally) p(z) via convolution
        - sample a latent tensor z ~ q(z)
        - feed z to convolution and return.

    NOTE 1:
        If parameters for q are not given, sampling is done from p(z).

    NOTE 2:
        The restricted KL divergence is obtained by first computing the element-wise
        KL divergence (i.e., the KL computed for each element of the latent tensors).
        Then, the restricted version is computed by summing over the channels and the
        spatial dimensions associated only to the portion of the latent tensor that is
        used for prediction. If `vanilla_latent_hw` is `None`, the whole latent tensor
        is used, i.e. the restricted KL equals the sample-wise KL.
    """

    def __init__(
        self,
        c_in: int,
        c_vars: int,
        c_out: int,
        conv_dims: int = 2,
        kernel: int = 3,
        transform_p_params: bool = True,
        vanilla_latent_hw: Union[int, None] = None,
        use_naive_exponential: bool = False,
    ):
        """
        Parameters
        ----------
        c_in: int
            The number of channels of the input tensor.
        c_vars: int
            The number of channels of the latent space tensor.
        c_out: int
            The output of the stochastic layer.
            Note that this is different from the sampled latent z.
        conv_dims: int, optional
            The number of dimensions of the convolutional layers (2D or 3D).
            Default is 2.
        kernel: int, optional
            The size of the kernel used in convolutional layers. Must be odd.
            Default is 3.
        transform_p_params: bool, optional
            Whether a transformation should be applied to the `p_params` tensor.
            The transformation consists in a convolution (`conv_in_p()`) that
            maps the input to `2 * c_vars` channels.
            Default is `True`.
        vanilla_latent_hw: int, optional
            The spatial size of the latent tensor used for prediction (i.e., it
            influences the computation of the restricted KL). If `None`, the whole
            latent tensor is used. Default is `None`.
        use_naive_exponential: bool, optional
            If `False`, exponentials are computed according to the alternative definition
            provided by `StableExponential` class. This should improve numerical stability
            in the training process. Default is `False`.
        """
        super().__init__()
        # An odd kernel with `kernel // 2` padding keeps the spatial size unchanged.
        assert kernel % 2 == 1
        pad = kernel // 2
        self.transform_p_params = transform_p_params
        self.c_in = c_in
        self.c_out = c_out
        self.c_vars = c_vars
        self.conv_dims = conv_dims
        self._use_naive_exponential = use_naive_exponential
        self._vanilla_latent_hw = vanilla_latent_hw

        conv_layer: ConvType = getattr(nn, f"Conv{conv_dims}d")

        # The input convolutions output 2 * c_vars channels: mean and log-variance.
        if transform_p_params:
            self.conv_in_p = conv_layer(c_in, 2 * c_vars, kernel, padding=pad)
        self.conv_in_q = conv_layer(c_in, 2 * c_vars, kernel, padding=pad)
        self.conv_out = conv_layer(c_vars, c_out, kernel, padding=pad)

    def get_z(
        self,
        sampling_distrib: torch.distributions.normal.Normal,
        forced_latent: Union[torch.Tensor, None],
        mode_pred: bool,
        use_uncond_mode: bool,
    ) -> torch.Tensor:
        """Get the latent tensor, either from the given distribution or as a forced value.

        The latent tensor is obtained as follows:
        - If `forced_latent` is not `None`, it is returned as is (no sampling happens).
        - Otherwise, in prediction mode (`mode_pred==True`) with `use_uncond_mode==True`,
          the mode (mean) of `sampling_distrib` is returned.
        - In all other cases, a reparameterized sample is drawn from `sampling_distrib`.

        Parameters
        ----------
        sampling_distrib: torch.distributions.normal.Normal
            The Gaussian distribution from which the latent tensor is sampled.
        forced_latent: torch.Tensor
            A pre-defined latent tensor. If it is not `None`, then it is used as the
            actual latent tensor and, hence, sampling does not happen.
        mode_pred: bool
            Whether the model is in prediction mode.
        use_uncond_mode: bool
            In prediction mode, whether to take the mean of `sampling_distrib`
            instead of sampling from it.

        Returns
        -------
        torch.Tensor
            The latent tensor.
        """
        if forced_latent is not None:
            return forced_latent
        if mode_pred and use_uncond_mode:
            return sampling_distrib.mean
        return sampling_distrib.rsample()

    def sample_from_q(
        self, q_params: torch.Tensor, var_clip_max: float
    ) -> torch.Tensor:
        """
        Given an input parameter tensor defining q(z),
        it processes it by calling `process_q_params()` method and
        samples a latent tensor from the resulting distribution.

        Parameters
        ----------
        q_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped.

        Returns
        -------
        torch.Tensor
            A reparameterized sample from q(z).
        """
        _, _, q = self.process_q_params(q_params, var_clip_max)
        return q.rsample()

    def compute_kl_metrics(
        self,
        p: torch.distributions.normal.Normal,
        p_params: torch.Tensor,
        q: torch.distributions.normal.Normal,
        q_params: torch.Tensor,
        mode_pred: bool,
        analytical_kl: bool,
        z: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute KL (analytical or MC estimate) and then process it, extracting composed versions of the metric.
        Specifically, the different versions of the KL loss terms are:
            - `kl_elementwise`: KL term for each single element of the latent tensor [Shape: (batch, ch, h, w)].
            - `kl_samplewise`: KL term associated to each sample in the batch [Shape: (batch, )].
            - `kl_samplewise_restricted`: KL term only associated to the portion of the latent tensor that is
            used for prediction and summed over channel and spatial dimensions [Shape: (batch, )].
            If `vanilla_latent_hw` is `None`, this equals `kl_samplewise`.
            - `kl_channelwise`: KL term associated to each sample and each channel [Shape: (batch, ch, )].
            - `kl_spatial`: KL term summed over the channels, i.e., retaining the spatial dimensions [Shape: (batch, h, w)]

        In prediction mode, all the entries are `None`.

        Parameters
        ----------
        p: torch.distributions.normal.Normal
            The prior generative distribution p(z_i|z_{i+1}) (or p(z_L)).
        p_params: torch.Tensor
            The parameters of the prior generative distribution.
        q: torch.distributions.normal.Normal
            The inference distribution q(z_i|z_{i+1}) (or q(z_L|x)).
        q_params: torch.Tensor
            The parameters of the inference distribution.
        mode_pred: bool
            Whether the model is in prediction mode.
        analytical_kl: bool
            Whether to compute the KL divergence analytically or using Monte Carlo estimation.
        z: torch.Tensor
            The sampled latent tensor.

        Returns
        -------
        Dict[str, torch.Tensor]
            The KL metrics described above.
        """
        if mode_pred is False:  # if not predicting
            if analytical_kl:
                kl_elementwise = kl_divergence(q, p)
            else:
                kl_elementwise = kl_normal_mc(z, p_params, q_params)

            all_dims = tuple(range(len(kl_elementwise.shape)))
            kl_samplewise = kl_elementwise.sum(all_dims[1:])
            kl_channelwise = kl_elementwise.sum(all_dims[2:])

            # Compute KL only on the portion of the latent space that is used for
            # prediction. Without a configured size, the whole latent tensor is used.
            if self._vanilla_latent_hw is None:
                pad = 0
            else:
                pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2
            if pad > 0:
                tmp = kl_elementwise[..., pad:-pad, pad:-pad]
                kl_samplewise_restricted = tmp.sum(all_dims[1:])
            else:
                kl_samplewise_restricted = kl_samplewise

            # Compute spatial KL analytically (but conditioned on samples from
            # previous layers)
            kl_spatial = kl_elementwise.sum(1)
        else:  # if predicting, no need to compute KL
            kl_elementwise = kl_samplewise = kl_samplewise_restricted = None
            kl_spatial = kl_channelwise = None

        kl_dict = {
            "kl_elementwise": kl_elementwise,  # (batch, ch, h, w)
            "kl_samplewise": kl_samplewise,  # (batch, )
            "kl_samplewise_restricted": kl_samplewise_restricted,  # (batch, )
            "kl_spatial": kl_spatial,  # (batch, h, w)
            "kl_channelwise": kl_channelwise,  # (batch, ch)
        }  # TODO revisit, check dims
        return kl_dict

    def process_p_params(
        self, p_params: torch.Tensor, var_clip_max: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
        """Process the input parameters to get the prior distribution p(z_i|z_{i+1}) (or p(z_L)).

        Processing consists in:
            - (optionally) convolution on the input tensor to map it to `2 * c_vars` channels.
            - split the resulting tensor into two chunks, the mean and the log-variance.
            - (optionally) clip the log-variance to an upper threshold.
            - define the normal distribution p(z) given the parameter tensors above.

        Parameters
        ----------
        p_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped.

        Returns
        -------
        Tuple
            The wrapped mean (`StableMean`), the wrapped log-variance (`StableLogVar`)
            and the resulting `Normal` distribution.
        """
        if self.transform_p_params:
            p_params = self.conv_in_p(p_params)
        else:
            # Without the transformation, the input must already hold mean and log-var.
            assert p_params.size(1) == 2 * self.c_vars

        # Define p(z)
        p_mu, p_lv = p_params.chunk(2, dim=1)
        if var_clip_max is not None:
            p_lv = torch.clip(p_lv, max=var_clip_max)

        p_mu = StableMean(p_mu)
        p_lv = StableLogVar(p_lv, enable_stable=not self._use_naive_exponential)
        p = Normal(p_mu.get(), p_lv.get_std())
        return p_mu, p_lv, p

    def process_q_params(
        self, q_params: torch.Tensor, var_clip_max: float, allow_oddsizes: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.distributions.normal.Normal]:
        """
        Process the input parameters to get the inference distribution q(z_i|z_{i+1}) (or q(z|x)).

        Processing consists in:
            - convolution on the input tensor to map it to `2 * c_vars` channels.
            - split the resulting tensor into 2 chunks, respectively mean and log-var.
            - (optionally) clip the log-variance to an upper threshold.
            - (optionally) crop the resulting tensors to ensure that the last spatial dimension is even.
            - define the normal distribution q(z) given the parameter tensors above.

        Parameters
        ----------
        q_params: torch.Tensor
            The input tensor to be processed.
        var_clip_max: float
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped.
        allow_oddsizes: bool, optional
            If `False`, tensors whose last dimension is odd are center-cropped by one
            pixel. Default is `False`.

        Returns
        -------
        Tuple
            The wrapped mean (`StableMean`), the wrapped log-variance (`StableLogVar`)
            and the resulting `Normal` distribution.
        """
        q_params = self.conv_in_q(q_params)

        q_mu, q_lv = q_params.chunk(2, dim=1)
        if var_clip_max is not None:
            q_lv = torch.clip(q_lv, max=var_clip_max)

        if q_mu.shape[-1] % 2 == 1 and allow_oddsizes is False:
            # NOTE(review): with an int size, `center_crop` crops both of the last two
            # dimensions to a square of that size, so this assumes square latents.
            q_mu = F.center_crop(q_mu, q_mu.shape[-1] - 1)
            q_lv = F.center_crop(q_lv, q_lv.shape[-1] - 1)
            # TODO revisit ?!
        q_mu = StableMean(q_mu)
        q_lv = StableLogVar(q_lv, enable_stable=not self._use_naive_exponential)
        q = Normal(q_mu.get(), q_lv.get_std())
        return q_mu, q_lv, q

    def forward(
        self,
        p_params: torch.Tensor,
        q_params: Union[torch.Tensor, None] = None,
        forced_latent: Union[torch.Tensor, None] = None,
        force_constant_output: bool = False,
        analytical_kl: bool = False,
        mode_pred: bool = False,
        use_uncond_mode: bool = False,
        var_clip_max: Union[float, None] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Parameters
        ----------
        p_params: torch.Tensor
            The output tensor of the top-down layer above (i.e., mu_{p,i+1}, sigma_{p,i+1}).
        q_params: torch.Tensor, optional
            The tensor resulting from merging the bu_value tensor at the same hierarchical level
            from the bottom-up pass and the `p_params` tensor. Default is `None`.
        forced_latent: torch.Tensor, optional
            A pre-defined latent tensor. Currently it must be `None` (this is asserted).
            Default is `None`.
        force_constant_output: bool, optional
            Whether to copy the first sample (and rel. distrib parameters) over the whole batch.
            This is used when doing experiment from the prior - q is not used.
            Default is `False`.
        analytical_kl: bool, optional
            Whether to compute the KL divergence analytically or using Monte Carlo estimation.
            Default is `False`.
        mode_pred: bool, optional
            Whether the model is in prediction mode. Default is `False`.
        use_uncond_mode: bool, optional
            In prediction mode, whether to use the mean of the sampling distribution
            instead of a sample. Default is `False`.
        var_clip_max: float, optional
            The maximum value reachable by the log-variance of the latent distribution.
            Values exceeding this threshold are clipped. Default is `None`.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            The output of `conv_out` applied to the latent tensor, and a dictionary with
            the KL metrics (only when `q_params` is given) plus the entries "z",
            "p_params", "q_params", "logprob_q" and "qvar_max".
        """
        debug_qvar_max = 0

        # Check sampling options consistency
        assert forced_latent is None

        # Get generative distribution p(z_i|z_{i+1})
        p_mu, p_lv, p = self.process_p_params(p_params, var_clip_max)
        p_params = (p_mu, p_lv)

        if q_params is not None:
            # Get inference distribution q(z_i|z_{i+1})
            q_mu, q_lv, q = self.process_q_params(q_params, var_clip_max)
            q_params = (q_mu, q_lv)
            debug_qvar_max = torch.max(q_lv.get())
            sampling_distrib = q
            q_size = q_mu.get().shape[-1]
            if p_mu.get().shape[-1] != q_size and mode_pred is False:
                p_mu.centercrop_to_size(q_size)
                p_lv.centercrop_to_size(q_size)
                # `p` was built from the uncropped tensors: rebuild it so that the
                # analytical KL compares tensors of matching shape.
                p = Normal(p_mu.get(), p_lv.get_std())
        else:
            sampling_distrib = p

        # Sample latent variable
        z = self.get_z(sampling_distrib, forced_latent, mode_pred, use_uncond_mode)

        # TODO: not necessary, remove
        # Copy one sample (and distrib parameters) over the whole batch.
        # This is used when doing experiment from the prior - q is not used.
        # NOTE(review): `p_params` holds StableMean/StableLogVar wrappers here; confirm
        # they support slicing and `expand_as` before relying on this branch.
        if force_constant_output:
            z = z[0:1].expand_as(z).clone()
            p_params = (
                p_params[0][0:1].expand_as(p_params[0]).clone(),
                p_params[1][0:1].expand_as(p_params[1]).clone(),
            )

        # Pass the sampled latent through the output convolution of stochastic block
        out = self.conv_out(z)

        if q_params is not None:
            # Compute log q(z), summed over all non-batch dimensions
            logprob_q = q.log_prob(z).sum(tuple(range(1, z.dim())))
            # Compute KL divergence metrics
            kl_dict = self.compute_kl_metrics(
                p, p_params, q, q_params, mode_pred, analytical_kl, z
            )
        else:
            kl_dict = {}
            logprob_q = None

        # Store meaningful quantities for later computation
        data = kl_dict
        data["z"] = z  # sampled variable at this layer (B, C, [Z], Y, X)
        data["p_params"] = p_params  # (B, C, [Z], Y, X) where B is 1 or batch size
        data["q_params"] = q_params  # (B, C, [Z], Y, X)
        data["logprob_q"] = logprob_q  # (B, )
        data["qvar_max"] = debug_qvar_max
        return out, data
|
careamics/models/lvae/utils.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Script for utility functions needed by the LVAE model.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Literal, Sequence
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import torch
|
|
@@ -15,11 +15,6 @@ def torch_nanmean(inp):
|
|
|
15
15
|
return torch.mean(inp[~inp.isnan()])
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
def compute_batch_mean(x):
    """Average every sample of a batch over all its non-batch elements.

    The tensor is flattened to (batch, -1) and the mean is taken along the
    flattened axis, giving one value per sample.
    """
    batch_size = len(x)
    flattened = x.view(batch_size, -1)
    return flattened.mean(dim=1)
|
|
21
|
-
|
|
22
|
-
|
|
23
18
|
def power_of_2(self, x):
|
|
24
19
|
assert isinstance(x, int)
|
|
25
20
|
if x == 1:
|
|
@@ -100,46 +95,76 @@ class ModelType(Enum):
|
|
|
100
95
|
LadderVAETwoDataSetFinetuning = 28
|
|
101
96
|
|
|
102
97
|
|
|
103
|
-
def _pad_crop_img(
|
|
98
|
+
def _pad_crop_img(
|
|
99
|
+
x: torch.Tensor, size: Sequence[int], mode: Literal["crop", "pad"]
|
|
100
|
+
) -> torch.Tensor:
|
|
104
101
|
"""Pads or crops a tensor.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
102
|
+
|
|
103
|
+
Pads or crops a tensor of shape (B, C, [Z], Y, X) to new shape.
|
|
104
|
+
|
|
105
|
+
Parameters:
|
|
106
|
+
-----------
|
|
107
|
+
x: torch.Tensor
|
|
108
|
+
Input image of shape (B, C, [Z], Y, X)
|
|
109
|
+
size: Sequence[int]
|
|
110
|
+
Desired size ([Z*], Y*, X*)
|
|
111
|
+
mode: Literal["crop", "pad"]
|
|
112
|
+
Mode, either 'pad' or 'crop'
|
|
113
|
+
|
|
111
114
|
Returns:
|
|
115
|
+
--------
|
|
116
|
+
torch.Tensor:
|
|
112
117
|
The padded or cropped tensor
|
|
113
118
|
"""
|
|
114
|
-
|
|
119
|
+
# TODO: Support cropping/padding on selected dimensions
|
|
120
|
+
assert (x.dim() == 4 and len(size) == 2) or (x.dim() == 5 and len(size) == 3)
|
|
121
|
+
|
|
115
122
|
size = tuple(size)
|
|
116
|
-
x_size = x.size()[2:
|
|
123
|
+
x_size = x.size()[2:]
|
|
124
|
+
|
|
117
125
|
if mode == "pad":
|
|
118
|
-
cond = x_size[
|
|
126
|
+
cond = any(x_size[i] > size[i] for i in range(len(size)))
|
|
119
127
|
elif mode == "crop":
|
|
120
|
-
cond = x_size[
|
|
121
|
-
|
|
122
|
-
raise ValueError(f"invalid mode '{mode}'")
|
|
128
|
+
cond = any(x_size[i] < size[i] for i in range(len(size)))
|
|
129
|
+
|
|
123
130
|
if cond:
|
|
124
|
-
raise ValueError(f"
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
131
|
+
raise ValueError(f"Trying to {mode} from size {x_size} to size {size}")
|
|
132
|
+
|
|
133
|
+
diffs = [abs(x - s) for x, s in zip(x_size, size)]
|
|
134
|
+
d1 = [d // 2 for d in diffs]
|
|
135
|
+
d2 = [d - (d // 2) for d in diffs]
|
|
136
|
+
|
|
128
137
|
if mode == "pad":
|
|
129
|
-
|
|
138
|
+
if x.dim() == 4:
|
|
139
|
+
padding = [d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
|
|
140
|
+
elif x.dim() == 5:
|
|
141
|
+
padding = [d1[2], d2[2], d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
|
|
142
|
+
return nn.functional.pad(x, padding)
|
|
130
143
|
elif mode == "crop":
|
|
131
|
-
|
|
144
|
+
if x.dim() == 4:
|
|
145
|
+
return x[:, :, d1[0] : (x_size[0] - d2[0]), d1[1] : (x_size[1] - d2[1])]
|
|
146
|
+
elif x.dim() == 5:
|
|
147
|
+
return x[
|
|
148
|
+
:,
|
|
149
|
+
:,
|
|
150
|
+
d1[0] : (x_size[0] - d2[0]),
|
|
151
|
+
d1[1] : (x_size[1] - d2[1]),
|
|
152
|
+
d1[2] : (x_size[2] - d2[2]),
|
|
153
|
+
]
|
|
132
154
|
|
|
133
155
|
|
|
134
|
-
def pad_img_tensor(x: torch.Tensor, size: Sequence[int]) -> torch.Tensor:
    """Zero-pad a tensor of shape (B, C, [Z], Y, X) to the given spatial size.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : Sequence[int]
        Desired spatial size ([Z*], Y*, X*); no entry may be smaller than the
        corresponding spatial dimension of `x`.

    Returns
    -------
    torch.Tensor
        The padded tensor.
    """
    return _pad_crop_img(x, size, mode="pad")
|
|
@@ -251,7 +276,15 @@ class StableLogVar:
|
|
|
251
276
|
def get_std(self) -> torch.Tensor:
    """Return the standard deviation, i.e. the square root of `get_var()`."""
    variance = self.get_var()
    return variance.sqrt()
|
|
253
278
|
|
|
254
|
-
|
|
279
|
+
@property
def is_3D(self) -> bool:
    """Whether the stored log-variance tensor carries a depth axis.

    Tensors are laid out as (B, C, [Z], Y, X), so exactly 5 dimensions means 3D.
    """
    n_dims = len(self._lv.shape)
    return n_dims == 5
|
|
286
|
+
|
|
287
|
+
def centercrop_to_size(self, size: Sequence[int]) -> None:
|
|
255
288
|
"""
|
|
256
289
|
Centercrop the log-variance tensor to the desired size.
|
|
257
290
|
|
|
@@ -260,6 +293,8 @@ class StableLogVar:
|
|
|
260
293
|
size: torch.Tensor
|
|
261
294
|
The desired size of the log-variance tensor.
|
|
262
295
|
"""
|
|
296
|
+
assert not self.is_3D, "Centercrop is implemented only for 2D tensors."
|
|
297
|
+
|
|
263
298
|
if self._lv.shape[-1] == size:
|
|
264
299
|
return
|
|
265
300
|
|
|
@@ -276,15 +311,26 @@ class StableMean:
|
|
|
276
311
|
def get(self) -> torch.Tensor:
    """Return the wrapped mean tensor as stored, without any transformation."""
    mean_tensor = self._mean
    return mean_tensor
|
|
278
313
|
|
|
279
|
-
|
|
314
|
+
@property
def is_3D(self) -> bool:
    """Whether the stored mean tensor carries a depth axis.

    Tensors are laid out as (B, C, [Z], Y, X), so exactly 5 dimensions means 3D.
    """
    n_dims = len(self._mean.shape)
    return n_dims == 5
|
|
321
|
+
|
|
322
|
+
def centercrop_to_size(self, size: Sequence[int]) -> None:
|
|
323
|
+
"""Centercrop the mean tensor to the desired size.
|
|
324
|
+
|
|
325
|
+
Implemented only in the case of 2D tensors.
|
|
282
326
|
|
|
283
327
|
Parameters
|
|
284
328
|
----------
|
|
285
329
|
size: torch.Tensor
|
|
286
330
|
The desired size of the log-variance tensor.
|
|
287
331
|
"""
|
|
332
|
+
assert not self.is_3D, "Centercrop is implemented only for 2D tensors."
|
|
333
|
+
|
|
288
334
|
if self._mean.shape[-1] == size:
|
|
289
335
|
return
|
|
290
336
|
|
|
@@ -356,40 +402,3 @@ def kl_normal_mc(z, p_mulv, q_mulv):
|
|
|
356
402
|
p_distrib = Normal(p_mu.get(), p_std)
|
|
357
403
|
q_distrib = Normal(q_mu.get(), q_std)
|
|
358
404
|
return q_distrib.log_prob(z) - p_distrib.log_prob(z)
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute the free-bits version of the KL divergence.

    Free bits keep the KL of every latent layer from collapsing to zero, which
    encourages a more efficient use of the latent variables.

    NOTE:
        The input KL has shape (batch size, layers); the output has shape
        (layers,) and is the batch average of the free-bits KL per layer.
        With `batch_average=False` (default) the free-bits floor is applied per
        layer AND per batch element, i.e. mean(clamp(KL)). With
        `batch_average=True` it is applied per layer to the batch average,
        i.e. clamp(mean(KL)). If `free_bits < eps`, no floor is applied.

    Args:
        kl (torch.Tensor)
        free_bits (float)
        batch_average (bool, optional)
        eps (float, optional)

    Returns
    -------
    The KL with free bits
    """
    assert kl.dim() == 2
    if free_bits < eps:
        return kl.mean(0)
    return (
        kl.mean(0).clamp(min=free_bits)
        if batch_average
        else kl.clamp(min=free_bits).mean(0)
    )
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""PyTorch lightning utilities."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
    """Return the loss curves from the csv logs.

    The most recent run is used, i.e. the `version_<n>` sub-folder of
    `log_folder / experiment_name` with the largest `n`; sub-folders that do not
    follow the `version_<n>` naming are ignored. In its `metrics.csv`, the header
    row is skipped, blank lines are ignored, and each remaining row is expected
    to have five comma-separated fields, of which the first, third and fifth are
    read as epoch, training loss and validation loss. An empty field means the
    metric was not logged on that row.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    log_folder : Path or str
        Path to the folder containing the csv logs.

    Returns
    -------
    dict
        Dictionary containing the loss curves, with keys "train_epoch", "val_epoch",
        "train_loss" and "val_loss".

    Raises
    ------
    ValueError
        If no `version_<n>` sub-folder exists, or if a row cannot be parsed.
    """
    path = Path(log_folder) / experiment_name

    # Find the most recent of the version_* folders. Other folders (e.g.
    # checkpoints) are skipped instead of breaking the integer parsing.
    prefix = "version_"
    versions = [
        int(v.name[len(prefix) :])
        for v in path.iterdir()
        if v.is_dir()
        and v.name.startswith(prefix)
        and v.name[len(prefix) :].isdigit()
    ]
    if not versions:
        raise ValueError(f"No '{prefix}<n>' folder found in {path}.")

    path_log = path / f"{prefix}{max(versions)}" / "metrics.csv"

    train_epoch = []
    val_epoch = []
    train_losses = []
    val_losses = []
    with open(path_log) as f:
        next(f, None)  # skip the header row
        for single_line in f:
            if not single_line.strip():
                continue
            epoch, _, train_loss, _, val_loss = single_line.strip().split(",")

            # train and val are not logged on the same row and can have
            # different lengths: only keep the fields that are filled.
            if train_loss != "":
                train_epoch.append(int(epoch))
                train_losses.append(float(train_loss))
            if val_loss != "":
                val_epoch.append(int(epoch))
                val_losses.append(float(val_loss))

    return {
        "train_epoch": train_epoch,
        "val_epoch": val_epoch,
        "train_loss": train_losses,
        "val_loss": val_losses,
    }
|
careamics/utils/serializers.py
CHANGED
careamics/utils/torch_utils.py
CHANGED