careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (59)
  1. careamics/careamist.py +6 -12
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +9 -8
  12. careamics/config/configuration_factories.py +843 -29
  13. careamics/config/data/data_model.py +1 -2
  14. careamics/config/data/ng_data_model.py +1 -2
  15. careamics/config/inference_model.py +1 -2
  16. careamics/config/likelihood_model.py +2 -2
  17. careamics/config/loss_model.py +6 -2
  18. careamics/config/nm_model.py +26 -1
  19. careamics/config/optimizer_models.py +1 -2
  20. careamics/config/support/supported_algorithms.py +5 -3
  21. careamics/config/support/supported_losses.py +5 -2
  22. careamics/config/training_model.py +6 -36
  23. careamics/config/transformations/normalize_model.py +1 -2
  24. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  25. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  26. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  27. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  28. careamics/file_io/read/__init__.py +0 -1
  29. careamics/lightning/__init__.py +16 -2
  30. careamics/lightning/callbacks/__init__.py +2 -0
  31. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  32. careamics/lightning/lightning_module.py +161 -61
  33. careamics/lightning/microsplit_data_module.py +631 -0
  34. careamics/lightning/predict_data_module.py +8 -1
  35. careamics/lightning/train_data_module.py +19 -8
  36. careamics/losses/__init__.py +7 -1
  37. careamics/losses/loss_factory.py +9 -1
  38. careamics/losses/lvae/losses.py +85 -0
  39. careamics/lvae_training/dataset/__init__.py +8 -8
  40. careamics/lvae_training/dataset/config.py +56 -44
  41. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  42. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  43. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  44. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  45. careamics/model_io/bmz_io.py +9 -5
  46. careamics/models/lvae/likelihoods.py +30 -14
  47. careamics/models/lvae/lvae.py +2 -2
  48. careamics/models/lvae/noise_models.py +20 -14
  49. careamics/prediction_utils/__init__.py +8 -2
  50. careamics/prediction_utils/prediction_outputs.py +48 -3
  51. careamics/prediction_utils/stitch_prediction.py +71 -0
  52. careamics/transforms/xy_random_rotate90.py +1 -1
  53. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
  54. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
  55. careamics/dataset/zarr_dataset.py +0 -151
  56. careamics/file_io/read/zarr.py +0 -60
  57. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  58. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  59. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/config/data/data_model.py
@@ -6,7 +6,7 @@ import os
 import sys
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn
 
 import numpy as np
@@ -19,7 +19,6 @@ from pydantic import (
19
19
  field_validator,
20
20
  model_validator,
21
21
  )
22
- from typing_extensions import Self
23
22
 
24
23
  from ..transformations import XYFlipModel, XYRandomRotate90Model
25
24
  from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
careamics/config/data/ng_data_model.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Union
+from typing import Annotated, Any, Literal, Self, Union
 from warnings import warn
 
 import numpy as np
@@ -17,7 +17,6 @@ from pydantic import (
17
17
  field_validator,
18
18
  model_validator,
19
19
  )
20
- from typing_extensions import Self
21
20
 
22
21
  from ..transformations import XYFlipModel, XYRandomRotate90Model
23
22
  from ..validators import check_axes_validity
careamics/config/inference_model.py
@@ -2,10 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal, Union
+from typing import Any, Literal, Self, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from typing_extensions import Self
 
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
 
careamics/config/likelihood_model.py
@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
     # TODO remove and use as parameters to the likelihood functions?
-    data_mean: Tensor = torch.zeros(1)
+    data_mean: Tensor | None = None
     """The mean of the data, used to unnormalize data for noise model evaluation.
     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
 
     # TODO remove and use as parameters to the likelihood functions?
-    data_std: Tensor = torch.ones(1)
+    data_std: Tensor | None = None
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
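Because the defaults are now None, the statistics have to be injected before the noise-model likelihood is evaluated; the new DataStatsCallback added at the end of this diff does this at fit time. A minimal sketch of setting them by hand (assignment is allowed since validate_assignment=True); the instance and the tensor values below are purely illustrative:

import torch

# `likelihood_config` is an existing NMLikelihoodConfig instance, built elsewhere
likelihood_config.data_mean = torch.tensor([480.0])  # per-channel mean, shape (target_ch,)
likelihood_config.data_std = torch.tensor([120.0])   # per-channel std, shape (target_ch,)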
careamics/config/loss_model.py
@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
         validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
     )
 
-    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+    loss_type: Literal[
+        "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
+    ]
     """Type of loss to use for LVAE."""
 
     reconstruction_weight: float = 1.0
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_params: KLLossConfig = KLLossConfig()
     """KL loss configuration."""
-
+    # TODO revisit weights for the losses
     # TODO: remove?
     non_stochastic: bool = False
     """Whether to sample latents and compute KL."""
+
+    # TODO what are the correct parameters for HDN ?
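The widened Literal means configs can now request the new losses directly. A minimal sketch, assuming LVAELossConfig is importable from careamics.config.loss_model (the path given in the file list) and that the remaining fields keep their defaults:

from careamics.config.loss_model import LVAELossConfig

# "hdn" and "microsplit" are now valid alongside the existing VAE loss types
hdn_loss = LVAELossConfig(loss_type="hdn")
microsplit_loss = LVAELossConfig(loss_type="microsplit", reconstruction_weight=1.0)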
careamics/config/nm_model.py
@@ -1,7 +1,7 @@
 """Noise models config."""
 
 from pathlib import Path
-from typing import Annotated, Literal, Union
+from typing import Annotated, Literal, Self, Union
 
 import numpy as np
 import torch
@@ -11,6 +11,7 @@ from pydantic import (
     Field,
     PlainSerializer,
     PlainValidator,
+    model_validator,
 )
 
 from careamics.utils.serializers import _array_to_json, _to_numpy
@@ -86,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
     tol: float = Field(default=1e-10)
     """Tolerance used in the computation of the noise model likelihood."""
 
+    @model_validator(mode="after")
+    def validate_path(self: Self) -> Self:
+        """Validate that the path points to a valid .npz file if provided.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+
+        Raises
+        ------
+        ValueError
+            If the path is provided but does not point to a valid .npz file.
+        """
+        if self.path is not None:
+            path = Path(self.path)
+            if not path.exists():
+                raise ValueError(f"Path {path} does not exist.")
+            if path.suffix != ".npz":
+                raise ValueError(f"Path {path} must point to a .npz file.")
+            if not path.is_file():
+                raise ValueError(f"Path {path} must point to a file.")
+        return self
+
     # @model_validator(mode="after")
     # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
     #     """Validate paths provided in the config.
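The validator runs at construction (and on assignment), so a bad noise-model path is rejected before any training starts. A minimal sketch, assuming the import path from the file list and that the config's other fields have defaults; the file name below is hypothetical:

from careamics.config.nm_model import GaussianMixtureNMConfig

try:
    # "missing_noise_model.npz" does not exist, so the new validator raises
    GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel", path="missing_noise_model.npz"
    )
except ValueError as err:  # pydantic's ValidationError is a ValueError subclass
    print(err)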
careamics/config/optimizer_models.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Literal
+from typing import Literal, Self
 
 from pydantic import (
     BaseModel,
@@ -13,7 +13,6 @@ from pydantic import (
     model_validator,
 )
 from torch import optim
-from typing_extensions import Self
 
 from careamics.utils.torch_utils import filter_parameters
 
careamics/config/support/supported_algorithms.py
@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
     MUSPLIT = "musplit"
     """An image splitting approach based on ladder VAE architectures."""
 
+    MICROSPLIT = "microsplit"
+    """A micro-level image splitting approach based on ladder VAE architectures."""
+
     DENOISPLIT = "denoisplit"
     """An image splitting and denoising approach based on ladder VAE architectures."""
 
-    # PN2V = "pn2v"
-    # HDN = "hdn"
-    # SEG = "segmentation"
+    HDN = "hdn"
+    """Hierarchical Denoising Network, an unsupervised denoising algorithm"""
careamics/config/support/supported_losses.py
@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
     MAE = "mae"
     N2V = "n2v"
     # PN2V = "pn2v"
-    # HDN = "hdn"
+    HDN = "hdn"
     MUSPLIT = "musplit"
+    MICROSPLIT = "microsplit"
     DENOISPLIT = "denoisplit"
-    DENOISPLIT_MUSPLIT = "denoisplit_musplit"
+    DENOISPLIT_MUSPLIT = (
+        "denoisplit_musplit"  # TODO refac losses, leave only microsplit
+    )
     # CE = "ce"
     # DICE = "dice"
careamics/config/training_model.py
@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal, Union
+from typing import Literal
 
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field
 
 from .callback_model import CheckpointModel, EarlyStoppingModel
 
@@ -29,26 +29,15 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(
         validate_assignment=True,
     )
+    lightning_trainer_config: dict | None = None
+    """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
+    Trainer class"""
 
-    num_epochs: int = Field(default=20, ge=1)
-    """Number of epochs, greater than 0."""
-
-    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
-    """Numerical precision"""
-    max_steps: int = Field(default=-1, ge=-1)
-    """Maximum number of steps to train for. -1 means no limit."""
-    check_val_every_n_epoch: int = Field(default=1, ge=1)
-    """Validation step frequency."""
-    accumulate_grad_batches: int = Field(default=1, ge=1)
-    """Number of batches to accumulate gradients over before stepping the optimizer."""
-    gradient_clip_val: Union[int, float] | None = None
-    """The value to which to clip the gradient"""
-    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
-    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
     logger: Literal["wandb", "tensorboard"] | None = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""
 
+    # Only basic callbacks
     checkpoint_callback: CheckpointModel = CheckpointModel()
     """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
     callback."""
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
             Whether the logger is defined or not.
         """
         return self.logger is not None
-
-    @field_validator("max_steps")
-    @classmethod
-    def validate_max_steps(cls, max_steps: int) -> int:
-        """Validate the max_steps parameter.
-
-        Parameters
-        ----------
-        max_steps : int
-            Maximum number of steps to train for. -1 means no limit.
-
-        Returns
-        -------
-        int
-            Validated max_steps.
-        """
-        if max_steps == 0:
-            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
-        return max_steps
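Taken together, the two hunks above replace the individual trainer fields (num_epochs, precision, max_steps, gradient clipping, ...) with a single lightning_trainer_config dict that mirrors pytorch_lightning.Trainer keyword arguments. A minimal sketch of the new usage, assuming the import path from the file list; the dict keys below are standard Lightning Trainer arguments, not fields defined by careamics:

from careamics.config.training_model import TrainingConfig

training_config = TrainingConfig(
    lightning_trainer_config={
        "max_epochs": 20,
        "precision": "32",
        "gradient_clip_val": 1.0,
        "gradient_clip_algorithm": "norm",
    },
    logger="tensorboard",
)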
careamics/config/transformations/normalize_model.py
@@ -1,9 +1,8 @@
 """Pydantic model for the Normalize transform."""
 
-from typing import Literal
+from typing import Literal, Self
 
 from pydantic import ConfigDict, Field, model_validator
-from typing_extensions import Self
 
 from .transform_model import TransformModel
 
careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import zarr
 from numpy.typing import NDArray
-from zarr.storage import FSStore
+from zarr.storage import FsspecStore
 
 from careamics.config import DataConfig
 from careamics.config.support import SupportedData
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
 
 # %%
 def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
-    store = FSStore(url=file_path.resolve())
+    store = FsspecStore.from_url(url=file_path.resolve())
     # create array
     array = zarr.create(
         store=store,
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
 # ### Make sure file exists
 
 # %%
-store = FSStore(url=file_path.resolve(), mode="r")
+store = FsspecStore.from_url(url=file_path.resolve(), mode="r")
 
 # %%
 list(store.keys())
@@ -72,7 +72,7 @@ list(store.keys())
 
 # %%
 class ZarrSource(TypedDict):
-    store: FSStore
+    store: FsspecStore
     data_paths: Sequence[str]
 
 
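These hunks track the zarr v3 API: FSStore is gone and its role is split between FsspecStore (fsspec/URL-backed) and LocalStore (local directories). A small sketch of the two replacements, with hypothetical paths:

from zarr.storage import FsspecStore, LocalStore

remote_store = FsspecStore.from_url("https://example.com/data.zarr")  # remote, fsspec-backed
local_store = LocalStore(root="data.zarr")                            # plain local directory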
careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py
@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Union
+from typing import Any, Literal, Self, Union
 
 from numpy.typing import DTypeLike, NDArray
-from typing_extensions import Self
 
 from careamics.dataset.dataset_utils import reshape_array
 from careamics.file_io.read import ReadFunc, read_tiff
careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py
@@ -1,11 +1,11 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Union
+from typing import Self, Union
 
+import validators
 import zarr
-import zarr.storage
 from numpy.typing import NDArray
-from typing_extensions import Self
+from zarr.storage import FsspecStore, LocalStore
 
 from careamics.dataset.dataset_utils import reshape_array
 
@@ -15,9 +15,10 @@ class ZarrImageStack:
     A class for extracting patches from an image stack that is stored as a zarr array.
     """
 
-    # TODO: keeping store type narrow so that it has the path attribute
-    # base zarr store is zarr.storage.Store, includes MemoryStore
-    def __init__(self, store: zarr.storage.FSStore, data_path: str, axes: str):
+    # TODO: We should keep store type narrow
+    #   - in zarr v3, does zarr.storage.Store exists and has the path attribute?
+    #   - can we declare a narrow type rather than a union?
+    def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
         self._store = store
         self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
         # TODO: validate axes
@@ -46,8 +47,33 @@ class ZarrImageStack:
         Assumes the path only contains 1 image.
 
         Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
+
+        Parameters
+        ----------
+        path : Union[Path, str]
+            Path to the root of the OME-Zarr, local file or url.
+
+        Returns
+        -------
+        ZarrImageStack
+            Initialised ZarrImageStack.
+
+        Raises
+        ------
+        ValueError
+            If the path does not exist or is not a valid URL.
+        ValueError
+            If the OME-Zarr at the path does not contain the attribute 'multiscales'.
         """
-        store = zarr.storage.FSStore(url=path)
+        if Path(path).is_file():
+            store = zarr.storage.LocalStore(root=Path(path).resolve())
+        elif validators.url(path):
+            store = zarr.storage.FsspecStore.from_url(url=path)
+        else:
+            raise ValueError(
+                f"Path '{path}' is neither an existing file nor a valid URL."
+            )
+
         group = zarr.open_group(store=store, mode="r")
         if "multiscales" not in group.attrs:
             raise ValueError(
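Putting the pieces together, a ZarrImageStack can now be built on top of a zarr v3 store; the constructor signature is the one shown above, and the import path matches the doctest in the next hunk. The store root, array path, and axes below are hypothetical:

from zarr.storage import LocalStore

from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack

store = LocalStore(root="image.zarr")                          # hypothetical local zarr
stack = ZarrImageStack(store=store, data_path="0", axes="YX")  # array path and axes assumed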
careamics/dataset_ng/patch_extractor/image_stack_loader.py
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
 
     >>> from typing import TypedDict
 
-    >>> from zarr.storage import FSStore
+    >>> from zarr.storage import FsspecStore
 
     >>> from careamics.config import DataConfig
     >>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
     >>> # Define a zarr source
     >>> # It encompasses multiple arguments that determine what data will be loaded
     >>> class ZarrSource(TypedDict):
-    ...     store: FSStore
+    ...     store: FsspecStore
     ...     data_paths: Sequence[str]
 
     >>> def custom_image_stack_loader(
careamics/file_io/read/__init__.py
@@ -9,4 +9,3 @@ __all__ = [
 
 from .get_func import ReadFunc, get_read_func
 from .tiff import read_tiff
-from .zarr import read_zarr
careamics/lightning/__init__.py
@@ -1,18 +1,32 @@
 """CAREamics PyTorch Lightning modules."""
 
 __all__ = [
+    "DataStatsCallback",
     "FCNModule",
     "HyperParametersCallback",
+    "MicroSplitDataModule",
     "PredictDataModule",
     "ProgressBarCallback",
     "TrainDataModule",
     "VAEModule",
     "create_careamics_module",
+    "create_microsplit_predict_datamodule",
+    "create_microsplit_train_datamodule",
     "create_predict_datamodule",
     "create_train_datamodule",
+    "create_unet_based_module",
+    "create_vae_based_module",
 ]
 
-from .callbacks import HyperParametersCallback, ProgressBarCallback
+from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
 from .lightning_module import FCNModule, VAEModule, create_careamics_module
+from .microsplit_data_module import (
+    MicroSplitDataModule,
+    create_microsplit_predict_datamodule,
+    create_microsplit_train_datamodule,
+)
 from .predict_data_module import PredictDataModule, create_predict_datamodule
-from .train_data_module import TrainDataModule, create_train_datamodule
+from .train_data_module import (
+    TrainDataModule,
+    create_train_datamodule,
+)
careamics/lightning/callbacks/__init__.py
@@ -1,11 +1,13 @@
 """Callbacks module."""
 
 __all__ = [
+    "DataStatsCallback",
     "HyperParametersCallback",
     "PredictionWriterCallback",
     "ProgressBarCallback",
 ]
 
+from .data_stats_callback import DataStatsCallback
 from .hyperparameters_callback import HyperParametersCallback
 from .prediction_writer_callback import PredictionWriterCallback
 from .progress_bar_callback import ProgressBarCallback
careamics/lightning/callbacks/data_stats_callback.py
@@ -0,0 +1,23 @@
+"""Data statistics callback."""
+
+import pytorch_lightning as L
+from pytorch_lightning.callbacks import Callback
+
+
+class DataStatsCallback(Callback):
+    """Callback to update model's data statistics from datamodule.
+
+    This callback ensures that the model has access to the data statistics (mean and std)
+    calculated by the datamodule before training starts.
+    """
+
+    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
+        """Called when trainer is setting up."""
+        if stage == "fit":
+            # Get data statistics from datamodule
+            (data_mean, data_std), _ = trainer.datamodule.get_data_stats()
+
+            # Set data statistics in the model's likelihood module
+            module.noise_model_likelihood.set_data_stats(
+                data_mean=data_mean["target"], data_std=data_std["target"]
+            )
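A sketch of how the new callback is meant to be wired into a Lightning Trainer; it assumes a datamodule exposing get_data_stats() and a module with a noise_model_likelihood supporting set_data_stats(), which is exactly the interface the callback relies on above. The model and datamodule construction is omitted:

import pytorch_lightning as pl

from careamics.lightning import DataStatsCallback  # re-exported by careamics.lightning above

trainer = pl.Trainer(max_epochs=10, callbacks=[DataStatsCallback()])
# trainer.fit(model, datamodule=datamodule)  # model / datamodule built elsewhere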