careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/config/transformations/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """CAREamics transformation Pydantic models."""
+
+ __all__ = [
+     "N2VManipulateModel",
+     "XYFlipModel",
+     "NormalizeModel",
+     "XYRandomRotate90Model",
+ ]
+
+
+ from .n2v_manipulate_model import N2VManipulateModel
+ from .normalize_model import NormalizeModel
+ from .xy_flip_model import XYFlipModel
+ from .xy_random_rotate90_model import XYRandomRotate90Model
careamics/config/transformations/n2v_manipulate_model.py ADDED
@@ -0,0 +1,64 @@
+ """Pydantic model for the N2VManipulate transform."""
+
+ from typing import Literal
+
+ from pydantic import ConfigDict, Field, field_validator
+
+ from .transform_model import TransformModel
+
+
+ class N2VManipulateModel(TransformModel):
+     """
+     Pydantic model used to represent N2V manipulation.
+
+     Attributes
+     ----------
+     name : Literal["N2VManipulate"]
+         Name of the transformation.
+     roi_size : int
+         Size of the masking region, by default 11.
+     masked_pixel_percentage : float
+         Percentage of masked pixels, by default 0.2.
+     strategy : Literal["uniform", "median"]
+         Strategy for pixel value replacement, by default "uniform".
+     struct_mask_axis : Literal["horizontal", "vertical", "none"]
+         Axis of the structN2V mask, by default "none".
+     struct_mask_span : int
+         Span of the structN2V mask, by default 5.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     name: Literal["N2VManipulate"] = "N2VManipulate"
+     roi_size: int = Field(default=11, ge=3, le=21)
+     masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
+     strategy: Literal["uniform", "median"] = Field(default="uniform")
+     struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
+     struct_mask_span: int = Field(default=5, ge=3, le=15)
+
+     @field_validator("roi_size", "struct_mask_span")
+     @classmethod
+     def odd_value(cls, v: int) -> int:
+         """
+         Validate that the value is odd.
+
+         Parameters
+         ----------
+         v : int
+             Value to validate.
+
+         Returns
+         -------
+         int
+             The validated value.
+
+         Raises
+         ------
+         ValueError
+             If the value is even.
+         """
+         if v % 2 == 0:
+             raise ValueError("Size must be an odd number.")
+         return v
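For orientation, a minimal usage sketch of the model above (not taken from the careamics documentation; the import path follows the transformations __init__.py hunk):

from pydantic import ValidationError

from careamics.config.transformations import N2VManipulateModel

# defaults pass validation: roi_size=11 and struct_mask_span=5 are both odd
manipulate = N2VManipulateModel()

# the `odd_value` field validator rejects even sizes for roi_size/struct_mask_span
try:
    N2VManipulateModel(roi_size=8)
except ValidationError as err:
    print(err)  # "Size must be an odd number."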
careamics/config/transformations/normalize_model.py ADDED
@@ -0,0 +1,60 @@
+ """Pydantic model for the Normalize transform."""
+
+ from typing import Literal, Optional
+
+ from pydantic import ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ from .transform_model import TransformModel
+
+
+ class NormalizeModel(TransformModel):
+     """
+     Pydantic model used to represent Normalize transformation.
+
+     The Normalize transform is a zero mean and unit variance transformation.
+
+     Attributes
+     ----------
+     name : Literal["Normalize"]
+         Name of the transformation.
+     image_means : list
+         Mean values for normalization, one per image channel.
+     image_stds : list
+         Standard deviation values for normalization, one per image channel.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     name: Literal["Normalize"] = "Normalize"
+     image_means: list = Field(..., min_length=0, max_length=32)
+     image_stds: list = Field(..., min_length=0, max_length=32)
+     target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
+     target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
+
+     @model_validator(mode="after")
+     def validate_means_stds(self: Self) -> Self:
+         """Validate that the means and stds have the same length.
+
+         Returns
+         -------
+         Self
+             The instance of the model.
+         """
+         if len(self.image_means) != len(self.image_stds):
+             raise ValueError("The number of image means and stds must be the same.")
+
+         if (self.target_means is None) != (self.target_stds is None):
+             raise ValueError(
+                 "Both target means and stds must be provided together, or both None."
+             )
+
+         if self.target_means is not None and self.target_stds is not None:
+             if len(self.target_means) != len(self.target_stds):
+                 raise ValueError(
+                     "The number of target means and stds must be the same."
+                 )
+
+         return self
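Likewise, a short sketch of the model validator above (illustrative, with invented statistics):

from pydantic import ValidationError

from careamics.config.transformations import NormalizeModel

# one mean/std entry per image channel
norm = NormalizeModel(image_means=[0.5], image_stds=[0.2])

# mismatched lengths are rejected by `validate_means_stds`
try:
    NormalizeModel(image_means=[0.5, 0.3], image_stds=[0.2])
except ValidationError as err:
    print(err)  # "The number of image means and stds must be the same."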
careamics/config/transformations/transform_model.py ADDED
@@ -0,0 +1,45 @@
+ """Parent model for the transforms."""
+
+ from typing import Any, Dict
+
+ from pydantic import BaseModel, ConfigDict
+
+
+ class TransformModel(BaseModel):
+     """
+     Pydantic model used to represent a transformation.
+
+     The `model_dump` method is overridden to exclude the name field.
+
+     Attributes
+     ----------
+     name : str
+         Name of the transformation.
+     """
+
+     model_config = ConfigDict(
+         extra="forbid",  # throw errors if the parameters are not properly passed
+     )
+
+     name: str
+
+     def model_dump(self, **kwargs) -> Dict[str, Any]:
+         """
+         Return the model as a dictionary.
+
+         Parameters
+         ----------
+         **kwargs
+             Pydantic BaseModel `model_dump` method keyword arguments.
+
+         Returns
+         -------
+         Dict[str, Any]
+             Dictionary representation of the model.
+         """
+         model_dict = super().model_dump(**kwargs)
+
+         # remove the name field
+         model_dict.pop("name")
+
+         return model_dict
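The overridden `model_dump` strips the `name` discriminator, presumably so the remaining keys can be passed straight to the corresponding transform; a sketch using the NormalizeModel from the previous hunk:

from careamics.config.transformations import NormalizeModel

norm = NormalizeModel(image_means=[0.5], image_stds=[0.2])

# `name` is removed; all other fields are kept
print(norm.model_dump())
# {'image_means': [0.5], 'image_stds': [0.2], 'target_means': None, 'target_stds': None}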
careamics/config/transformations/xy_flip_model.py ADDED
@@ -0,0 +1,43 @@
+ """Pydantic model for the XYFlip transform."""
+
+ from typing import Literal, Optional
+
+ from pydantic import ConfigDict, Field
+
+ from .transform_model import TransformModel
+
+
+ class XYFlipModel(TransformModel):
+     """
+     Pydantic model used to represent XYFlip transformation.
+
+     Attributes
+     ----------
+     name : Literal["XYFlip"]
+         Name of the transformation.
+     p : float
+         Probability of applying the transform, by default 0.5.
+     seed : Optional[int]
+         Seed for the random number generator, by default None.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     name: Literal["XYFlip"] = "XYFlip"
+     flip_x: bool = Field(
+         True,
+         description="Whether to flip along the X axis.",
+     )
+     flip_y: bool = Field(
+         True,
+         description="Whether to flip along the Y axis.",
+     )
+     p: float = Field(
+         0.5,
+         description="Probability of applying the transform.",
+         ge=0,
+         le=1,
+     )
+     seed: Optional[int] = None
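Because `validate_assignment=True` is set, field constraints are re-checked when attributes are reassigned, not only at construction; a sketch:

from pydantic import ValidationError

from careamics.config.transformations import XYFlipModel

flip = XYFlipModel(flip_x=True, flip_y=False, p=0.5)

try:
    flip.p = 1.5  # rejected on assignment: `p` is constrained to [0, 1]
except ValidationError as err:
    print(err)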
careamics/config/transformations/xy_random_rotate90_model.py ADDED
@@ -0,0 +1,35 @@
+ """Pydantic model for the XYRandomRotate90 transform."""
+
+ from typing import Literal, Optional
+
+ from pydantic import ConfigDict, Field
+
+ from .transform_model import TransformModel
+
+
+ class XYRandomRotate90Model(TransformModel):
+     """
+     Pydantic model used to represent the XY random 90 degree rotation transformation.
+
+     Attributes
+     ----------
+     name : Literal["XYRandomRotate90"]
+         Name of the transformation.
+     p : float
+         Probability of applying the transform, by default 0.5.
+     seed : Optional[int]
+         Seed for the random number generator, by default None.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
+     p: float = Field(
+         0.5,
+         description="Probability of applying the transform.",
+         ge=0,
+         le=1,
+     )
+     seed: Optional[int] = None
careamics/config/vae_algorithm_model.py ADDED
@@ -0,0 +1,171 @@
+ """Algorithm configuration."""
+
+ from __future__ import annotations
+
+ from pprint import pformat
+ from typing import Literal, Optional, Union
+
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ from careamics.config.support import SupportedAlgorithm, SupportedLoss
+
+ from .architectures import CustomModel, LVAEModel
+ from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
+ from .nm_model import MultiChannelNMConfig
+ from .optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+ class VAEAlgorithmConfig(BaseModel):
+     """Algorithm configuration.
+
+     This Pydantic model validates the parameters governing the components of the
+     training algorithm: which algorithm, loss function, model architecture, optimizer,
+     and learning rate scheduler to use.
+
+     Currently, we only support the `musplit` and `denoisplit` algorithms, in
+     addition to custom models. The `custom` algorithm allows you to register
+     your own architecture and select it using its name as `name` in the custom
+     pydantic model.
+
+     Attributes
+     ----------
+     algorithm : Literal["musplit", "denoisplit", "custom"]
+         Algorithm to use.
+     loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+         Loss function to use.
+     model : Union[LVAEModel, CustomModel]
+         Model architecture to use.
+     noise_model : Optional[MultiChannelNMConfig]
+         Noise model to use.
+     noise_model_likelihood_model : Optional[NMLikelihoodConfig]
+         Noise model likelihood model to use.
+     gaussian_likelihood_model : Optional[GaussianLikelihoodConfig]
+         Gaussian likelihood model to use.
+     optimizer : OptimizerModel, optional
+         Optimizer to use.
+     lr_scheduler : LrSchedulerModel, optional
+         Learning rate scheduler to use.
+
+     Raises
+     ------
+     ValueError
+         Algorithm parameter type validation errors.
+     ValueError
+         If the algorithm, loss and model are not compatible.
+
+     Examples
+     --------
+     # TODO add once finalized
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         protected_namespaces=(),  # allows to use model_* as a field name
+         validate_assignment=True,
+         extra="allow",
+     )
+
+     # Mandatory fields
+     # defined in SupportedAlgorithm
+     # TODO: Use supported Enum classes for typing?
+     # - values can still be passed as strings and they will be cast to Enum
+     algorithm_type: Literal["vae"]
+     algorithm: Literal["musplit", "denoisplit", "custom"]
+     loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+     model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
+
+     # TODO: these are configs, change naming of attrs
+     noise_model: Optional[MultiChannelNMConfig] = None
+     noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
+     gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
+
+     # Optional fields
+     optimizer: OptimizerModel = OptimizerModel()
+     """Optimizer to use, defined in SupportedOptimizer."""
+
+     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+
+     @model_validator(mode="after")
+     def algorithm_cross_validation(self: Self) -> Self:
+         """Validate the algorithm model based on `algorithm`.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         # musplit
+         if self.algorithm == SupportedAlgorithm.MUSPLIT:
+             if self.loss != SupportedLoss.MUSPLIT:
+                 raise ValueError(
+                     f"Algorithm {self.algorithm} only supports loss `musplit`."
+                 )
+
+         if self.algorithm == SupportedAlgorithm.DENOISPLIT:
+             if self.loss not in [
+                 SupportedLoss.DENOISPLIT,
+                 SupportedLoss.DENOISPLIT_MUSPLIT,
+             ]:
+                 raise ValueError(
+                     f"Algorithm {self.algorithm} only supports loss `denoisplit` "
+                     "or `denoisplit_musplit`."
+                 )
+             if (
+                 self.loss == SupportedLoss.DENOISPLIT
+                 and self.model.predict_logvar is not None
+             ):
+                 raise ValueError(
+                     "Algorithm `denoisplit` with loss `denoisplit` only supports "
+                     "`predict_logvar` as `None`."
+                 )
+             if self.noise_model is None:
+                 raise ValueError("Algorithm `denoisplit` requires a noise model.")
+         # TODO: what if algorithm is not musplit or denoisplit (HDN?)
+         return self
+
+     @model_validator(mode="after")
+     def output_channels_validation(self: Self) -> Self:
+         """Validate the consistency between number of out channels and noise models.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         if self.noise_model is not None:
+             assert self.model.output_channels == len(self.noise_model.noise_models), (
+                 f"Number of output channels ({self.model.output_channels}) must match "
+                 f"the number of noise models ({len(self.noise_model.noise_models)})."
+             )
+         return self
+
+     @model_validator(mode="after")
+     def predict_logvar_validation(self: Self) -> Self:
+         """Validate the consistency of `predict_logvar` throughout the model.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         if self.gaussian_likelihood_model is not None:
+             assert (
+                 self.model.predict_logvar
+                 == self.gaussian_likelihood_model.predict_logvar
+             ), (
+                 f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
+                 "Gaussian likelihood model `predict_logvar` "
+                 f"({self.gaussian_likelihood_model.predict_logvar})."
+             )
+         return self
+
+     def __str__(self) -> str:
+         """Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
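A hedged sketch of the `algorithm_cross_validation` behavior; constructing `LVAEModel(architecture="LVAE")` is an assumption based on the `architecture` discriminator above, since the LVAEModel fields live in a file not shown in this hunk:

from pydantic import ValidationError

from careamics.config.architectures import LVAEModel
from careamics.config.vae_algorithm_model import VAEAlgorithmConfig

model = LVAEModel(architecture="LVAE")  # assumed: remaining fields have defaults

# `musplit` is only compatible with the `musplit` loss
try:
    VAEAlgorithmConfig(
        algorithm_type="vae",
        algorithm="musplit",
        loss="denoisplit",
        model=model,
    )
except ValidationError as err:
    print(err)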
careamics/config/validators/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Validator utilities."""
+
+ __all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"]
+
+ from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2
careamics/config/validators/validator_utils.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Validator functions.
+
+ These functions are used to validate dimensions and axes of inputs.
+ """
+
+ from typing import List, Optional, Tuple, Union
+
+ _AXES = "STCZYX"
+
+
+ def check_axes_validity(axes: str) -> None:
+     """
+     Sanity check on axes.
+
+     The constraints on the axes are the following:
+     - must be a combination of 'STCZYX'
+     - must not contain duplicates
+     - must contain at least 2 contiguous axes: X and Y
+     - must contain at most 6 axes
+     - cannot contain both S and T axes
+
+     Axes do not need to be in the order 'STCZYX', as this depends on the user data.
+
+     Parameters
+     ----------
+     axes : str
+         Axes to validate.
+     """
+     _axes = axes.upper()
+
+     # Minimum is 2 (YX) and maximum is 6 (STCZYX)
+     if len(_axes) < 2 or len(_axes) > 6:
+         raise ValueError(
+             f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
+         )
+
+     if "YX" not in _axes and "XY" not in _axes:
+         raise ValueError(
+             f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
+         )
+
+     # all characters must be in _AXES = 'STCZYX'
+     if not all(s in _AXES for s in _axes):
+         raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")
+
+     # check for repeating characters
+     for i, s in enumerate(_axes):
+         if i != _axes.rfind(s):
+             raise ValueError(
+                 f"Invalid axes {axes}. Cannot contain duplicate axes"
+                 f" (got multiple {axes[i]})."
+             )
+
+
+ def value_ge_than_8_power_of_2(
+     value: int,
+ ) -> None:
+     """
+     Validate that the value is greater than or equal to 8 and a power of 2.
+
+     Parameters
+     ----------
+     value : int
+         Value to validate.
+
+     Raises
+     ------
+     ValueError
+         If the value is smaller than 8.
+     ValueError
+         If the value is not a power of 2.
+     """
+     if value < 8:
+         raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")
+
+     if (value & (value - 1)) != 0:
+         raise ValueError(f"Value must be a power of 2 (got {value}).")
+
+
+ def patch_size_ge_than_8_power_of_2(
+     patch_list: Optional[Union[List[int], Tuple[int, ...]]],
+ ) -> None:
+     """
+     Validate that each entry is greater than or equal to 8 and a power of 2.
+
+     Parameters
+     ----------
+     patch_list : Optional[Union[List[int], Tuple[int, ...]]]
+         Patch size.
+
+     Raises
+     ------
+     ValueError
+         If the patch size is smaller than 8.
+     ValueError
+         If the patch size is not a power of 2.
+     """
+     if patch_list is not None:
+         for dim in patch_list:
+             value_ge_than_8_power_of_2(dim)
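A short usage sketch of the two exported validators (import path per the validators __init__.py hunk above):

from careamics.config.validators import (
    check_axes_validity,
    patch_size_ge_than_8_power_of_2,
)

check_axes_validity("SYX")                 # OK: contiguous YX, no duplicates
patch_size_ge_than_8_power_of_2((64, 64))  # OK: each entry >= 8 and a power of 2

try:
    check_axes_validity("SXYX")            # duplicate X axis
except ValueError as err:
    print(err)

try:
    patch_size_ge_than_8_power_of_2([10, 64])  # 10 is not a power of 2
except ValueError as err:
    print(err)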
careamics/conftest.py ADDED
@@ -0,0 +1,39 @@
+ """File used to discover python modules and run doctest.
+
+ See https://sybil.readthedocs.io/en/latest/use.html#pytest
+ """
+
+ from pathlib import Path
+
+ import pytest
+ from pytest import TempPathFactory
+ from sybil import Sybil
+ from sybil.parsers.codeblock import PythonCodeBlockParser
+ from sybil.parsers.doctest import DocTestParser
+
+
+ @pytest.fixture(scope="module")
+ def my_path(tmp_path_factory: TempPathFactory) -> Path:
+     """Fixture used in doctest to create a temporary directory.
+
+     Parameters
+     ----------
+     tmp_path_factory : TempPathFactory
+         Temporary path factory from pytest.
+
+     Returns
+     -------
+     Path
+         Temporary directory path.
+     """
+     return tmp_path_factory.mktemp("my_path")
+
+
+ pytest_collect_file = Sybil(
+     parsers=[
+         DocTestParser(),
+         PythonCodeBlockParser(future_imports=["print_function"]),
+     ],
+     pattern="*.py",
+     fixtures=["my_path"],
+ ).pytest()
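With this configuration, Sybil collects doctests from docstrings across the package and exposes the `my_path` fixture by name inside them. A hypothetical docstring this setup would pick up (`save_values` is an invented helper for illustration):

def save_values(path, values):
    """Write numbers to a text file.

    >>> file = my_path / "values.txt"
    >>> save_values(file, [1, 2, 3])
    >>> file.read_text(encoding="utf-8")
    '1 2 3'
    """
    path.write_text(" ".join(str(v) for v in values), encoding="utf-8")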
careamics/dataset/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """Dataset module."""
+
+ __all__ = [
+     "InMemoryDataset",
+     "InMemoryPredDataset",
+     "InMemoryTiledPredDataset",
+     "PathIterableDataset",
+     "IterableTiledPredDataset",
+     "IterablePredDataset",
+ ]
+
+ from .in_memory_dataset import InMemoryDataset
+ from .in_memory_pred_dataset import InMemoryPredDataset
+ from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
+ from .iterable_dataset import PathIterableDataset
+ from .iterable_pred_dataset import IterablePredDataset
+ from .iterable_tiled_pred_dataset import IterableTiledPredDataset
careamics/dataset/dataset_utils/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """Files and arrays utils used in the datasets."""
+
+ __all__ = [
+     "reshape_array",
+     "compute_normalization_stats",
+     "get_files_size",
+     "list_files",
+     "validate_source_target_files",
+     "iterate_over_files",
+     "WelfordStatistics",
+ ]
+
+
+ from .dataset_utils import (
+     reshape_array,
+ )
+ from .file_utils import get_files_size, list_files, validate_source_target_files
+ from .iterate_over_files import iterate_over_files
+ from .running_stats import WelfordStatistics, compute_normalization_stats