careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (43)
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +39 -17
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
--- a/careamics/dataset/dataset_utils/running_stats.py
+++ b/careamics/dataset/dataset_utils/running_stats.py
@@ -34,36 +34,35 @@ def update_iterative_stats(
     Parameters
     ----------
     count : NDArray
-        Number of elements in the array.
+        Number of elements in the array. Shape: (C,).
     mean : NDArray
-        Mean of the array.
+        Mean of the array. Shape: (C,).
     m2 : NDArray
-        Variance of the array.
+        Variance of the array. Shape: (C,).
     new_values : NDArray
-        New values to add to the mean and variance.
+        New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X).
 
     Returns
     -------
     tuple[NDArray, NDArray, NDArray]
         Updated count, mean, and variance.
     """
-    count += np.array([np.prod(channel.shape) for channel in new_values])
-    # newvalues - oldMean
-    delta = [
-        np.subtract(v.flatten(), [m] * len(v.flatten()))
-        for v, m in zip(new_values, mean)
-    ]
+    num_channels = len(new_values)
 
-    mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
-    # newvalues - newMeant
-    delta2 = [
-        np.subtract(v.flatten(), [m] * len(v.flatten()))
-        for v, m in zip(new_values, mean)
-    ]
+    # --- update channel-wise counts ---
+    count += np.ones_like(count) * np.prod(new_values.shape[1:])
 
-    m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
+    # --- update channel-wise mean ---
+    # compute (new_values - old_mean) -> shape: (C, Z*Y*X)
+    delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
+    mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)
 
-    return (count, mean, m2)
+    # --- update channel-wise SoS ---
+    # compute (new_values - new_mean) -> shape: (C, Z*Y*X)
+    delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
+    m2 += np.sum(delta * delta2, axis=1)
+
+    return count, mean, m2
 
 
 def finalize_iterative_stats(

@@ -74,18 +73,18 @@ def finalize_iterative_stats(
     Parameters
     ----------
     count : NDArray
-        Number of elements in the array.
+        Number of elements in the array. Shape: (C,).
     mean : NDArray
-        Mean of the array.
+        Mean of the array. Shape: (C,).
     m2 : NDArray
-        Variance of the array.
+        Variance of the array. Shape: (C,).
 
     Returns
     -------
     tuple[NDArray, NDArray]
-        Final mean and standard deviation.
+        Final channel-wise mean and standard deviation.
     """
-    std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
+    std = np.sqrt(m2 / count)
     if any(c < 2 for c in count):
        return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
     else:
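The two hunks above replace per-channel Python loops with a vectorized, batched form of Welford's online algorithm. The following is a minimal, self-contained sketch (assuming float64 arrays shaped as documented, with the update function copied from the hunk) that checks the incremental statistics against NumPy's direct computation:

```python
import numpy as np

def update_iterative_stats(count, mean, m2, new_values):
    # batched Welford update, as introduced in this release (see hunk above)
    num_channels = len(new_values)
    count += np.ones_like(count) * np.prod(new_values.shape[1:])
    delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
    mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)
    delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
    m2 += np.sum(delta * delta2, axis=1)
    return count, mean, m2

rng = np.random.default_rng(seed=0)
batches = [rng.normal(size=(3, 1, 1, 4, 8, 8)) for _ in range(5)]  # C = 3

count, mean, m2 = np.zeros(3), np.zeros(3), np.zeros(3)
for batch in batches:
    count, mean, m2 = update_iterative_stats(count, mean, m2, batch)
std = np.sqrt(m2 / count)  # finalize_iterative_stats (population std)

# compare against the one-shot statistics over all batches
flat = np.concatenate([b.reshape(3, -1) for b in batches], axis=1)
assert np.allclose(mean, flat.mean(axis=1))
assert np.allclose(std, flat.std(axis=1))
```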
--- a/careamics/lightning/lightning_module.py
+++ b/careamics/lightning/lightning_module.py
@@ -14,8 +14,8 @@ from careamics.config.support import (
     SupportedOptimizer,
     SupportedScheduler,
 )
+from careamics.config.tile_information import TileInformation
 from careamics.losses import loss_factory
-from careamics.losses.loss_factory import LVAELossParameters
 from careamics.models.lvae.likelihoods import (
     GaussianLikelihood,
     NoiseModelLikelihood,

@@ -164,7 +164,17 @@ class FCNModule(L.LightningModule):
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        # TODO refactor when redoing datasets
+        # hacky way to determine if it is PredictDataModule, otherwise there is a
+        # circular import to solve with isinstance
+        from_prediction = hasattr(self._trainer.datamodule, "tiled")
+        is_tiled = (
+            len(batch) > 1
+            and isinstance(batch[1], list)
+            and isinstance(batch[1][0], TileInformation)
+        )
+
+        if is_tiled:
             x, *aux = batch
         else:
             x = batch

@@ -172,7 +182,10 @@
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
+        if (
+            from_prediction
+            and self._trainer.datamodule.prediction_config.tta_transforms
+        ):
             tta = ImageRestorationTTA()
             augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []

@@ -184,9 +197,18 @@
             output = self.model(x)
 
         # Denormalize the output
+        # TODO incompatible API between predict and train datasets
         denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            image_means=(
+                self._trainer.datamodule.predict_dataset.image_means
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.means
+            ),
+            image_stds=(
+                self._trainer.datamodule.predict_dataset.image_stds
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.stds
+            ),
         )
         denormalized_output = denorm(patch=output.cpu().numpy())
 
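The three `predict_step` hunks above let `FCNModule` run predictions from either data module by inferring, at runtime, whether the batch is tiled and whether it came from a `PredictDataModule`. Restated outside the module (a sketch; the helper name is ours, not careamics API), the new tile check is:

```python
from careamics.config.tile_information import TileInformation

def is_tiled_batch(batch) -> bool:
    # a tiled prediction batch arrives as (x, [TileInformation, ...], ...),
    # whereas a training batch is a plain tensor or an (x, target) tuple
    return (
        len(batch) > 1
        and isinstance(batch[1], list)
        and isinstance(batch[1][0], TileInformation)
    )
```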
@@ -266,29 +288,27 @@ class VAEModule(L.LightningModule):
         # TODO: log algorithm config
         # self.save_hyperparameters(self.algorithm_config.model_dump())
 
-        # create model and loss function
+        # create model
         self.model: nn.Module = model_factory(self.algorithm_config.model)
-        self.noise_model: NoiseModel = noise_model_factory(
+
+        # create loss function
+        self.noise_model: Optional[NoiseModel] = noise_model_factory(
             self.algorithm_config.noise_model
         )
-        # TODO: here we can add some code to check whether the noise model is not None
-        # and `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
-        # instead, None. In that case we could assign the noise model to the latter.
-        # This is particular useful when loading an algorithm config from file.
-        # Indeed, in that case the noise model in the nm likelihood is likely
-        # not available since excluded from serializaion.
-        self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
-            self.algorithm_config.noise_model_likelihood_model
+
+        self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
+            likelihood_factory(
+                config=self.algorithm_config.noise_model_likelihood,
+                noise_model=self.noise_model,
+            )
         )
-        self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
-            self.algorithm_config.gaussian_likelihood_model
+
+        self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
+            self.algorithm_config.gaussian_likelihood
         )
-        self.loss_parameters = LVAELossParameters(
-            noise_model_likelihood=self.noise_model_likelihood,
-            gaussian_likelihood=self.gaussian_likelihood,
-            # TODO: musplit/denoisplit weights ?
-        )  # type: ignore
-        self.loss_func = loss_factory(self.algorithm_config.loss)
+
+        self.loss_parameters = self.algorithm_config.loss
+        self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)
 
         # save optimizer and lr_scheduler names and parameters
         self.optimizer_name = self.algorithm_config.optimizer.name

@@ -344,11 +364,16 @@
         out = self.model(x)
 
         # Update loss parameters
-        # TODO rethink loss parameters
-        self.loss_parameters.current_epoch = self.current_epoch
+        self.loss_parameters.kl_params.current_epoch = self.current_epoch
 
         # Compute loss
-        loss = self.loss_func(out, target, self.loss_parameters)  # TODO ugly ?
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # TODO: implement a separate logging method?

@@ -376,7 +401,13 @@
         out = self.model(x)
 
         # Compute loss
-        loss = self.loss_func(out, target, self.loss_parameters)
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # Rename val_loss dict
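Together, the three `VAEModule` hunks above replace the old `LVAELossParameters` dataclass (deleted from `loss_factory.py` further down) with the algorithm config's own loss model, the new `careamics/config/loss_model.py` listed in the files-changed table. The schema itself is not part of this diff; the call sites shown here only rely on the following shape. This is a hypothetical outline inferred from those call sites, not the actual model:

```python
from pydantic import BaseModel

class KLParams(BaseModel):  # name assumed for illustration
    current_epoch: int = 0  # overwritten by training_step at every step

class LVAELossConfig(BaseModel):  # name assumed for illustration
    loss_type: str                 # forwarded to loss_factory(...)
    kl_params: KLParams = KLParams()
```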
--- a/careamics/lightning/train_data_module.py
+++ b/careamics/lightning/train_data_module.py
@@ -2,11 +2,12 @@
 
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from warnings import warn
 
 import numpy as np
 import pytorch_lightning as L
 from numpy.typing import NDArray
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 
 from careamics.config import DataConfig
 from careamics.config.support import SupportedData

@@ -446,6 +447,19 @@ class TrainDataModule(L.LightningDataModule):
         Any
             Training dataloader.
         """
+        # check because iterable dataset cannot be shuffled
+        if not isinstance(self.train_dataset, IterableDataset):
+            if ("shuffle" in self.dataloader_params) and (
+                not self.dataloader_params["shuffle"]
+            ):
+                warn(
+                    "Dataloader parameters include `shuffle=False`, this will be "
+                    "passed to the training dataloader and may result in bad results.",
+                    stacklevel=1,
+                )
+            else:
+                self.dataloader_params["shuffle"] = True
+
         return DataLoader(
             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
         )
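The guard added to `train_dataloader` above defaults map-style datasets to `shuffle=True` while leaving `IterableDataset` untouched (PyTorch's `DataLoader` forbids shuffling iterable datasets). The same logic as a standalone sketch, with a function name of our choosing:

```python
from warnings import warn
from torch.utils.data import IterableDataset

def resolve_shuffle(train_dataset, dataloader_params: dict) -> dict:
    # iterable datasets cannot be shuffled by the DataLoader
    if not isinstance(train_dataset, IterableDataset):
        if "shuffle" in dataloader_params and not dataloader_params["shuffle"]:
            warn("`shuffle=False` will be passed to the training dataloader.")
        else:
            dataloader_params["shuffle"] = True  # new default for training
    return dataloader_params
```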
--- a/careamics/losses/loss_factory.py
+++ b/careamics/losses/loss_factory.py
@@ -7,7 +7,7 @@ This module contains a factory function for creating loss functions.
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+from typing import Callable, Union
 
 from torch import Tensor as tensor
 

@@ -15,18 +15,6 @@ from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
 
-if TYPE_CHECKING:
-    from careamics.models.lvae.likelihoods import (
-        GaussianLikelihood,
-        NoiseModelLikelihood,
-    )
-    from careamics.models.lvae.noise_models import (
-        GaussianMixtureNoiseModel,
-        MultiChannelNoiseModel,
-    )
-
-    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
-
 
 @dataclass
 class FCNLossParameters:

@@ -40,78 +28,6 @@ class FCNLossParameters:
     loss_weight: float
 
 
-@dataclass  # TODO why not pydantic?
-class LVAELossParameters:
-    """Dataclass for LVAE loss."""
-
-    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
-    # e.g., - weights, - kl_params, ...
-
-    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
-    """Noise model likelihood instance."""
-    gaussian_likelihood: Optional[GaussianLikelihood] = None
-    """Gaussian likelihood instance."""
-    current_epoch: int = 0
-    """Current epoch in the training loop."""
-    reconstruction_weight: float = 1.0
-    """Weight for the reconstruction loss in the total net loss
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.1
-    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
-    denoisplit_weight: float = 0.9
-    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
-    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
-    """Type of KL divergence used as KL loss."""
-    kl_weight: float = 1.0
-    """Weight for the KL loss in the total net loss.
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    kl_annealing: bool = False
-    """Whether to apply KL loss annealing."""
-    kl_start: int = -1
-    """Epoch at which KL loss annealing starts."""
-    kl_annealtime: int = 10
-    """Number of epochs for which KL loss annealing is applied."""
-    non_stochastic: bool = False
-    """Whether to sample latents and compute KL."""
-
-
-# TODO: really needed?
-# like it is now, it is difficult to use, we need a way to specify the
-# loss parameters in a more user-friendly way.
-def loss_parameters_factory(
-    type: SupportedLoss,
-) -> Union[FCNLossParameters, LVAELossParameters]:
-    """Return loss parameters.
-
-    Parameters
-    ----------
-    type : SupportedLoss
-        Requested loss.
-
-    Returns
-    -------
-    Union[FCNLossParameters, LVAELossParameters]
-        Loss parameters.
-
-    Raises
-    ------
-    NotImplementedError
-        If the loss is unknown.
-    """
-    if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
-        return FCNLossParameters
-
-    elif type in [
-        SupportedLoss.MUSPLIT,
-        SupportedLoss.DENOISPLIT,
-        SupportedLoss.DENOISPLIT_MUSPLIT,
-    ]:
-        return LVAELossParameters  # it returns the class, not an instance
-
-    else:
-        raise NotImplementedError(f"Loss {type} is not yet supported.")
-
-
 def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     """Return loss function.
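With `LVAELossParameters` and `loss_parameters_factory` gone, `loss_factory` is the module's single public entry point, and per-loss parameters now live in the config layer. Usage of the factory is unchanged; a minimal example using the enum members referenced in the deleted code:

```python
from careamics.config.support import SupportedLoss
from careamics.losses import loss_factory

# returns the corresponding loss callable; a plain string is also accepted
loss_fn = loss_factory(SupportedLoss.MUSPLIT)
```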