careamics 0.0.4.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics has been flagged by the registry as potentially problematic.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +35 -22
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/dataset_utils/running_stats.py
CHANGED

@@ -34,36 +34,35 @@ def update_iterative_stats(
     Parameters
     ----------
     count : NDArray
-        Number of elements in the array.
+        Number of elements in the array. Shape: (C,).
     mean : NDArray
-        Mean of the array.
+        Mean of the array. Shape: (C,).
     m2 : NDArray
-        Variance of the array.
+        Variance of the array. Shape: (C,).
     new_values : NDArray
-        New values to add to the mean and variance.
+        New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X).
 
     Returns
    -------
     tuple[NDArray, NDArray, NDArray]
         Updated count, mean, and variance.
     """
-
-    # newvalues - oldMean
-    delta = [
-        np.subtract(v.flatten(), [m] * len(v.flatten()))
-        for v, m in zip(new_values, mean)
-    ]
+    num_channels = len(new_values)
 
-
-
-    delta2 = [
-        np.subtract(v.flatten(), [m] * len(v.flatten()))
-        for v, m in zip(new_values, mean)
-    ]
+    # --- update channel-wise counts ---
+    count += np.ones_like(count) * np.prod(new_values.shape[1:])
 
-
+    # --- update channel-wise mean ---
+    # compute (new_values - old_mean) -> shape: (C, Z*Y*X)
+    delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
+    mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)
 
-
+    # --- update channel-wise SoS ---
+    # compute (new_values - new_mean) -> shape: (C, Z*Y*X)
+    delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
+    m2 += np.sum(delta * delta2, axis=1)
+
+    return count, mean, m2
 
 
 def finalize_iterative_stats(
@@ -74,18 +73,18 @@ def finalize_iterative_stats(
     Parameters
     ----------
     count : NDArray
-        Number of elements in the array.
+        Number of elements in the array. Shape: (C,).
     mean : NDArray
-        Mean of the array.
+        Mean of the array. Shape: (C,).
     m2 : NDArray
-        Variance of the array.
+        Variance of the array. Shape: (C,).
 
     Returns
     -------
     tuple[NDArray, NDArray]
-        Final mean and standard deviation.
+        Final channel-wise mean and standard deviation.
     """
-    std = np.
+    std = np.sqrt(m2 / count)
     if any(c < 2 for c in count):
         return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
     else:
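The rewritten functions are the vectorized, channel-wise form of Welford's online algorithm: `count` grows by the number of elements per channel, the mean is shifted by `sum(delta) / count`, and `m2` accumulates `sum((x - old_mean) * (x - new_mean))`, which is the exact chunked update for the sum of squared deviations; `finalize_iterative_stats` then takes `std = sqrt(m2 / count)`. A minimal, self-contained sketch of driving this API over a stream of batches (the synthetic data, shapes, and driver loop are illustrative, not CAREamics API):

```python
import numpy as np

# two channels with true means (1, 5) and true std 2
C, rng = 2, np.random.default_rng(0)
count, mean, m2 = np.zeros(C), np.zeros(C), np.zeros(C)

for _ in range(100):
    # build a (C, 1, 1, Z, Y, X) batch, as documented above (here Z = 1)
    batch = rng.normal(loc=[1.0, 5.0], scale=2.0, size=(32, 32, C))
    batch = batch.transpose(2, 0, 1).reshape(C, 1, 1, 1, 32, 32)
    count, mean, m2 = update_iterative_stats(count, mean, m2, batch)

final_mean, final_std = finalize_iterative_stats(count, mean, m2)
# final_mean ≈ [1.0, 5.0] and final_std ≈ [2.0, 2.0], up to sampling noise
```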
careamics/lightning/lightning_module.py
CHANGED

@@ -14,8 +14,8 @@ from careamics.config.support import (
     SupportedOptimizer,
     SupportedScheduler,
 )
+from careamics.config.tile_information import TileInformation
 from careamics.losses import loss_factory
-from careamics.losses.loss_factory import LVAELossParameters
 from careamics.models.lvae.likelihoods import (
     GaussianLikelihood,
     NoiseModelLikelihood,
@@ -164,7 +164,17 @@ class FCNModule(L.LightningModule):
         Any
             Model output.
         """
-
+        # TODO refactor when redoing datasets
+        # hacky way to determine if it is PredictDataModule, otherwise there is a
+        # circular import to solve with isinstance
+        from_prediction = hasattr(self._trainer.datamodule, "tiled")
+        is_tiled = (
+            len(batch) > 1
+            and isinstance(batch[1], list)
+            and isinstance(batch[1][0], TileInformation)
+        )
+
+        if is_tiled:
             x, *aux = batch
         else:
             x = batch
@@ -172,7 +182,10 @@
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
-        if
+        if (
+            from_prediction
+            and self._trainer.datamodule.prediction_config.tta_transforms
+        ):
             tta = ImageRestorationTTA()
             augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
@@ -184,9 +197,18 @@
             output = self.model(x)
 
         # Denormalize the output
+        # TODO incompatible API between predict and train datasets
         denorm = Denormalize(
-            image_means=
-
+            image_means=(
+                self._trainer.datamodule.predict_dataset.image_means
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.means
+            ),
+            image_stds=(
+                self._trainer.datamodule.predict_dataset.image_stds
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.stds
+            ),
         )
         denormalized_output = denorm(patch=output.cpu().numpy())
 
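The predict path now distinguishes tiled from untiled batches by inspecting the batch itself, and detects a `PredictDataModule` by duck typing (`hasattr(..., "tiled")`) to dodge a circular import. A small standalone mirror of that unpacking logic, where the helper name and the `TileInformation` stub are hypothetical, not CAREamics API:

```python
from typing import Any


class TileInformation:
    """Stub standing in for careamics.config.tile_information.TileInformation."""


def unpack_predict_batch(batch: Any) -> tuple[Any, list]:
    """Split a batch into the input and optional tile metadata."""
    is_tiled = (
        len(batch) > 1
        and isinstance(batch[1], list)
        and isinstance(batch[1][0], TileInformation)
    )
    if is_tiled:
        x, *aux = batch
        return x, aux
    return batch, []


# a tiled batch pairs the data with a list of TileInformation objects
x, aux = unpack_predict_batch(("data", [TileInformation()]))
assert aux  # tile metadata detected
x, aux = unpack_predict_batch(("data",))  # untiled: the whole batch is the input
assert aux == []
```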
@@ -266,29 +288,27 @@ class VAEModule(L.LightningModule):
         # TODO: log algorithm config
         # self.save_hyperparameters(self.algorithm_config.model_dump())
 
-        # create model
+        # create model
         self.model: nn.Module = model_factory(self.algorithm_config.model)
-
+
+        # create loss function
+        self.noise_model: Optional[NoiseModel] = noise_model_factory(
             self.algorithm_config.noise_model
         )
-
-
-
-
-
-
-        self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
-            self.algorithm_config.noise_model_likelihood_model
+
+        self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
+            likelihood_factory(
+                config=self.algorithm_config.noise_model_likelihood,
+                noise_model=self.noise_model,
+            )
         )
-
-
+
+        self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
+            self.algorithm_config.gaussian_likelihood
         )
-
-
-
-        # TODO: musplit/denoisplit weights ?
-        ) # type: ignore
-        self.loss_func = loss_factory(self.algorithm_config.loss)
+
+        self.loss_parameters = self.algorithm_config.loss
+        self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)
 
         # save optimizer and lr_scheduler names and parameters
         self.optimizer_name = self.algorithm_config.optimizer.name
@@ -344,11 +364,16 @@
         out = self.model(x)
 
         # Update loss parameters
-
-        self.loss_parameters.current_epoch = self.current_epoch
+        self.loss_parameters.kl_params.current_epoch = self.current_epoch
 
         # Compute loss
-        loss = self.loss_func(
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # TODO: implement a separate logging method?
@@ -376,7 +401,13 @@
         out = self.model(x)
 
         # Compute loss
-        loss = self.loss_func(
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # Rename val_loss dict
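On the `VAEModule` side, loss parameters no longer travel through an `LVAELossParameters` object built inside the module: they come straight from `algorithm_config.loss`, and the loss functions take an explicit keyword signature. A runnable sketch of the new call shape, where everything except the keyword names shown in the diff is a hypothetical stand-in:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


# hypothetical stand-ins: only the keyword names mirror the diff above
@dataclass
class KLParams:
    current_epoch: int = 0


@dataclass
class LossConfig:
    loss_type: str = "musplit"
    kl_params: KLParams = field(default_factory=KLParams)


def example_loss(
    model_outputs: Any,
    targets: Any,
    config: LossConfig,
    gaussian_likelihood: Optional[Any] = None,
    noise_model_likelihood: Optional[Any] = None,
) -> float:
    # a real LVAE loss combines reconstruction and KL terms, reading
    # config.kl_params.current_epoch for KL annealing
    return 0.0


config = LossConfig()
config.kl_params.current_epoch = 3  # updated each step, as in training_step
loss = example_loss(
    model_outputs=None,  # would be self.model(x)
    targets=None,        # would be the target batch
    config=config,
    gaussian_likelihood=None,
    noise_model_likelihood=None,
)
```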
careamics/lightning/train_data_module.py
CHANGED

@@ -2,11 +2,12 @@
 
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from warnings import warn
 
 import numpy as np
 import pytorch_lightning as L
 from numpy.typing import NDArray
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 
 from careamics.config import DataConfig
 from careamics.config.support import SupportedData
@@ -446,6 +447,19 @@ class TrainDataModule(L.LightningDataModule):
         Any
             Training dataloader.
         """
+        # check because iterable dataset cannot be shuffled
+        if not isinstance(self.train_dataset, IterableDataset):
+            if ("shuffle" in self.dataloader_params) and (
+                not self.dataloader_params["shuffle"]
+            ):
+                warn(
+                    "Dataloader parameters include `shuffle=False`, this will be "
+                    "passed to the training dataloader and may result in bad results.",
+                    stacklevel=1,
+                )
+            else:
+                self.dataloader_params["shuffle"] = True
+
         return DataLoader(
             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
         )
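The new guard makes shuffling the default for map-style training datasets while leaving `IterableDataset` alone (it cannot be shuffled by `DataLoader`), and warns instead of silently overriding an explicit `shuffle=False`. The same logic as a standalone function, for illustration (the helper name is hypothetical):

```python
import warnings

from torch.utils.data import Dataset, IterableDataset


def resolve_shuffle(train_dataset, dataloader_params: dict) -> dict:
    """Default to shuffle=True for map-style datasets, warn on explicit False."""
    if not isinstance(train_dataset, IterableDataset):
        if ("shuffle" in dataloader_params) and not dataloader_params["shuffle"]:
            warnings.warn(
                "Dataloader parameters include `shuffle=False`, this will be "
                "passed to the training dataloader and may result in bad results.",
                stacklevel=1,
            )
        else:
            dataloader_params["shuffle"] = True
    return dataloader_params


class MapStyle(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return i


print(resolve_shuffle(MapStyle(), {}))                  # {'shuffle': True}
print(resolve_shuffle(MapStyle(), {"shuffle": False}))  # warns, leaves False
```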
careamics/losses/loss_factory.py
CHANGED

@@ -7,7 +7,7 @@ This module contains a factory function for creating loss functions.
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import
+from typing import Callable, Union
 
 from torch import Tensor as tensor
 
@@ -15,18 +15,6 @@ from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
 
-if TYPE_CHECKING:
-    from careamics.models.lvae.likelihoods import (
-        GaussianLikelihood,
-        NoiseModelLikelihood,
-    )
-    from careamics.models.lvae.noise_models import (
-        GaussianMixtureNoiseModel,
-        MultiChannelNoiseModel,
-    )
-
-    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
-
 
 @dataclass
 class FCNLossParameters:
@@ -40,78 +28,6 @@ class FCNLossParameters:
     loss_weight: float
 
 
-@dataclass  # TODO why not pydantic?
-class LVAELossParameters:
-    """Dataclass for LVAE loss."""
-
-    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
-    # e.g., - weights, - kl_params, ...
-
-    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
-    """Noise model likelihood instance."""
-    gaussian_likelihood: Optional[GaussianLikelihood] = None
-    """Gaussian likelihood instance."""
-    current_epoch: int = 0
-    """Current epoch in the training loop."""
-    reconstruction_weight: float = 1.0
-    """Weight for the reconstruction loss in the total net loss
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.1
-    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
-    denoisplit_weight: float = 0.9
-    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
-    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
-    """Type of KL divergence used as KL loss."""
-    kl_weight: float = 1.0
-    """Weight for the KL loss in the total net loss.
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    kl_annealing: bool = False
-    """Whether to apply KL loss annealing."""
-    kl_start: int = -1
-    """Epoch at which KL loss annealing starts."""
-    kl_annealtime: int = 10
-    """Number of epochs for which KL loss annealing is applied."""
-    non_stochastic: bool = False
-    """Whether to sample latents and compute KL."""
-
-
-# TODO: really needed?
-# like it is now, it is difficult to use, we need a way to specify the
-# loss parameters in a more user-friendly way.
-def loss_parameters_factory(
-    type: SupportedLoss,
-) -> Union[FCNLossParameters, LVAELossParameters]:
-    """Return loss parameters.
-
-    Parameters
-    ----------
-    type : SupportedLoss
-        Requested loss.
-
-    Returns
-    -------
-    Union[FCNLossParameters, LVAELossParameters]
-        Loss parameters.
-
-    Raises
-    ------
-    NotImplementedError
-        If the loss is unknown.
-    """
-    if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
-        return FCNLossParameters
-
-    elif type in [
-        SupportedLoss.MUSPLIT,
-        SupportedLoss.DENOISPLIT,
-        SupportedLoss.DENOISPLIT_MUSPLIT,
-    ]:
-        return LVAELossParameters  # it returns the class, not an instance
-
-    else:
-        raise NotImplementedError(f"Loss {type} is not yet supported.")
-
-
 def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     """Return loss function.
 
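`LVAELossParameters` and `loss_parameters_factory` are removed outright; judging by the new `careamics/config/loss_model.py` in the file list and the `VAEModule` changes above, the loss configuration now lives on the algorithm config, and `loss_factory` only ever receives a loss identifier. A minimal usage sketch (assuming `"n2v"` is among the supported identifiers, as the `SupportedLoss` and `n2v_loss` imports suggest):

```python
from careamics.losses import loss_factory

# the factory maps an identifier to a loss callable, e.g. n2v_loss for "n2v"
n2v = loss_factory("n2v")
print(callable(n2v))  # True
```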