careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the registry's release page for more details.

Files changed (81)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +80 -44
  4. careamics/config/algorithm_model.py +5 -3
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +8 -1
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -2
  12. careamics/config/configuration_factory.py +4 -16
  13. careamics/config/data_model.py +10 -14
  14. careamics/config/inference_model.py +0 -65
  15. careamics/config/optimizer_models.py +4 -4
  16. careamics/config/support/__init__.py +0 -2
  17. careamics/config/support/supported_activations.py +2 -0
  18. careamics/config/support/supported_algorithms.py +3 -1
  19. careamics/config/support/supported_architectures.py +2 -0
  20. careamics/config/support/supported_data.py +2 -0
  21. careamics/config/support/supported_loggers.py +2 -0
  22. careamics/config/support/supported_losses.py +2 -0
  23. careamics/config/support/supported_optimizers.py +2 -0
  24. careamics/config/support/supported_pixel_manipulations.py +3 -3
  25. careamics/config/support/supported_struct_axis.py +2 -0
  26. careamics/config/support/supported_transforms.py +4 -15
  27. careamics/config/tile_information.py +2 -0
  28. careamics/config/transformations/__init__.py +3 -2
  29. careamics/config/transformations/xy_flip_model.py +43 -0
  30. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  31. careamics/conftest.py +12 -0
  32. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  33. careamics/dataset/dataset_utils/file_utils.py +4 -3
  34. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  35. careamics/dataset/dataset_utils/read_utils.py +2 -0
  36. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  37. careamics/dataset/in_memory_dataset.py +71 -32
  38. careamics/dataset/iterable_dataset.py +155 -68
  39. careamics/dataset/patching/patching.py +56 -15
  40. careamics/dataset/patching/random_patching.py +8 -2
  41. careamics/dataset/patching/sequential_patching.py +14 -8
  42. careamics/dataset/patching/tiled_patching.py +3 -1
  43. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  44. careamics/dataset/zarr_dataset.py +2 -0
  45. careamics/lightning_datamodule.py +45 -19
  46. careamics/lightning_module.py +8 -2
  47. careamics/lightning_prediction_datamodule.py +3 -13
  48. careamics/lightning_prediction_loop.py +8 -6
  49. careamics/losses/__init__.py +2 -3
  50. careamics/losses/loss_factory.py +1 -1
  51. careamics/losses/losses.py +11 -7
  52. careamics/model_io/bmz_io.py +3 -3
  53. careamics/models/activation.py +2 -0
  54. careamics/models/layers.py +121 -25
  55. careamics/models/model_factory.py +1 -1
  56. careamics/models/unet.py +35 -14
  57. careamics/prediction/stitch_prediction.py +2 -6
  58. careamics/transforms/__init__.py +2 -2
  59. careamics/transforms/compose.py +33 -7
  60. careamics/transforms/n2v_manipulate.py +49 -13
  61. careamics/transforms/normalize.py +55 -3
  62. careamics/transforms/pixel_manipulation.py +5 -5
  63. careamics/transforms/struct_mask_parameters.py +3 -1
  64. careamics/transforms/transform.py +10 -19
  65. careamics/transforms/xy_flip.py +123 -0
  66. careamics/transforms/xy_random_rotate90.py +38 -5
  67. careamics/utils/base_enum.py +28 -0
  68. careamics/utils/path_utils.py +2 -0
  69. careamics/utils/ram.py +2 -0
  70. careamics/utils/receptive_field.py +93 -87
  71. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
  72. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  73. careamics/config/noise_models.py +0 -162
  74. careamics/config/support/supported_extraction_strategies.py +0 -25
  75. careamics/config/transformations/nd_flip_model.py +0 -27
  76. careamics/losses/noise_model_factory.py +0 -40
  77. careamics/losses/noise_models.py +0 -524
  78. careamics/transforms/nd_flip.py +0 -67
  79. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  80. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  81. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -7,12 +7,8 @@ from typing import Any, List, Literal, Optional, Union
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
9
9
 
10
- from .support import SupportedTransform
11
- from .transformations.normalize_model import NormalizeModel
12
10
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
13
11
 
14
- TRANSFORMS_UNION = Union[NormalizeModel]
15
-
16
12
 
17
13
  class InferenceConfig(BaseModel):
18
14
  """Configuration class for the prediction model."""
@@ -33,15 +29,6 @@ class InferenceConfig(BaseModel):
33
29
  mean: float
34
30
  std: float = Field(..., ge=0.0)
35
31
 
36
- transforms: List[TRANSFORMS_UNION] = Field(
37
- default=[
38
- {
39
- "name": SupportedTransform.NORMALIZE.value,
40
- },
41
- ],
42
- validate_default=True,
43
- )
44
-
45
32
  # only default TTAs are supported for now
46
33
  tta_transforms: bool = Field(default=True)
47
34
 
@@ -149,39 +136,6 @@ class InferenceConfig(BaseModel):
149
136
 
150
137
  return axes
151
138
 
152
- @field_validator("transforms")
153
- @classmethod
154
- def validate_transforms(
155
- cls, transforms: List[TRANSFORMS_UNION]
156
- ) -> List[TRANSFORMS_UNION]:
157
- """
158
- Validate that transforms do not have N2V pixel manipulate transforms.
159
-
160
- Parameters
161
- ----------
162
- transforms : List[TRANSFORMS_UNION]
163
- Transforms.
164
-
165
- Returns
166
- -------
167
- List[TRANSFORMS_UNION]
168
- Validated transforms.
169
-
170
- Raises
171
- ------
172
- ValueError
173
- If transforms contain N2V pixel manipulate transforms.
174
- """
175
- if transforms is not None:
176
- for transform in transforms:
177
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
178
- raise ValueError(
179
- "N2V_Manipulate transform is not allowed in "
180
- "prediction transforms."
181
- )
182
-
183
- return transforms
184
-
185
139
  @model_validator(mode="after")
186
140
  def validate_dimensions(self: Self) -> Self:
187
141
  """
@@ -235,25 +189,6 @@ class InferenceConfig(BaseModel):
235
189
 
236
190
  return self
237
191
 
238
- @model_validator(mode="after")
239
- def add_std_and_mean_to_normalize(self: Self) -> Self:
240
- """
241
- Add mean and std to the Normalize transform if it is present.
242
-
243
- Returns
244
- -------
245
- Self
246
- Inference model with mean and std added to the Normalize transform.
247
- """
248
- if self.mean is not None or self.std is not None:
249
- # search in the transforms for Normalize and update parameters
250
- for transform in self.transforms:
251
- if transform.name == SupportedTransform.NORMALIZE.value:
252
- transform.mean = self.mean
253
- transform.std = self.std
254
-
255
- return self
256
-
257
192
  def _update(self, **kwargs: Any) -> None:
258
193
  """
259
194
  Update multiple arguments at once.
@@ -1,3 +1,5 @@
1
+ """Optimizers and schedulers Pydantic models."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Dict, Literal
@@ -19,8 +21,7 @@ from .support import SupportedOptimizer
19
21
 
20
22
 
21
23
  class OptimizerModel(BaseModel):
22
- """
23
- Torch optimizer.
24
+ """Torch optimizer Pydantic model.
24
25
 
25
26
  Only parameters supported by the corresponding torch optimizer will be taken
26
27
  into account. For more details, check:
@@ -115,8 +116,7 @@ class OptimizerModel(BaseModel):
115
116
 
116
117
 
117
118
  class LrSchedulerModel(BaseModel):
118
- """
119
- Torch learning rate scheduler.
119
+ """Torch learning rate scheduler Pydantic model.
120
120
 
121
121
  Only parameters supported by the corresponding torch lr scheduler will be taken
122
122
  into account. For more details, check:
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SupportedPixelManipulation",
15
15
  "SupportedTransform",
16
16
  "SupportedData",
17
- "SupportedExtractionStrategy",
18
17
  "SupportedStructAxis",
19
18
  "SupportedLogger",
20
19
  ]
@@ -24,7 +23,6 @@ from .supported_activations import SupportedActivation
24
23
  from .supported_algorithms import SupportedAlgorithm
25
24
  from .supported_architectures import SupportedArchitecture
26
25
  from .supported_data import SupportedData
27
- from .supported_extraction_strategies import SupportedExtractionStrategy
28
26
  from .supported_loggers import SupportedLogger
29
27
  from .supported_losses import SupportedLoss
30
28
  from .supported_optimizers import SupportedOptimizer, SupportedScheduler
@@ -1,3 +1,5 @@
1
+ """Activations supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Algorithms supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from careamics.utils import BaseEnum
@@ -10,9 +12,9 @@ class SupportedAlgorithm(str, BaseEnum):
10
12
  """
11
13
 
12
14
  N2V = "n2v"
13
- CUSTOM = "custom"
14
15
  CARE = "care"
15
16
  N2N = "n2n"
17
+ CUSTOM = "custom"
16
18
  # PN2V = "pn2v"
17
19
  # HDN = "hdn"
18
20
  # SEG = "segmentation"
@@ -1,3 +1,5 @@
1
+ """Architectures supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Data supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Union
@@ -1,3 +1,5 @@
1
+ """Logger supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Losses supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Optimizers and schedulers supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,15 +1,15 @@
1
+ """Pixel manipulation methods supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
4
6
  class SupportedPixelManipulation(str, BaseEnum):
5
- """_summary_.
7
+ """Supported Noise2Void pixel manipulations.
6
8
 
7
9
  - Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor
8
10
  pixel value.
9
11
  - Median: Replace masked pixel value by the mean of the neighborhood.
10
12
  """
11
13
 
12
- # TODO docs
13
-
14
14
  UNIFORM = "uniform"
15
15
  MEDIAN = "median"
@@ -1,3 +1,5 @@
1
+ """StructN2V axes supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,23 +1,12 @@
1
+ """Transforms supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
4
6
  class SupportedTransform(str, BaseEnum):
5
- """Transforms officially supported by CAREamics.
6
-
7
- - Flip: from Albumentations, randomly flip the input horizontally, vertically or
8
- both, parameter `p` can be used to set the probability to apply the transform.
9
- - XYRandomRotate90: #TODO
10
- - Normalize # TODO add details, in particular about the parameters
11
- - ManipulateN2V # TODO add details, in particular about the parameters
12
- - NDFlip
13
-
14
- Note that while any Albumentations (see https://albumentations.ai/) transform can be
15
- used in CAREamics, no check are implemented to verify the compatibility of any other
16
- transforms than the ones officially supported.
17
- """
7
+ """Transforms officially supported by CAREamics."""
18
8
 
19
- NDFLIP = "NDFlip"
9
+ XY_FLIP = "XYFlip"
20
10
  XY_RANDOM_ROTATE90 = "XYRandomRotate90"
21
11
  NORMALIZE = "Normalize"
22
12
  N2V_MANIPULATE = "N2VManipulate"
23
- # CUSTOM = "Custom"
@@ -1,3 +1,5 @@
1
+ """Pydantic model representing the metadata of a prediction tile."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Optional, Tuple
@@ -2,13 +2,14 @@
2
2
 
3
3
  __all__ = [
4
4
  "N2VManipulateModel",
5
- "NDFlipModel",
5
+ "XYFlipModel",
6
6
  "NormalizeModel",
7
7
  "XYRandomRotate90Model",
8
+ "XorYFlipModel",
8
9
  ]
9
10
 
10
11
 
11
12
  from .n2v_manipulate_model import N2VManipulateModel
12
- from .nd_flip_model import NDFlipModel
13
13
  from .normalize_model import NormalizeModel
14
+ from .xy_flip_model import XYFlipModel
14
15
  from .xy_random_rotate90_model import XYRandomRotate90Model
@@ -0,0 +1,43 @@
1
+ """Pydantic model for the XYFlip transform."""
2
+
3
+ from typing import Literal, Optional
4
+
5
+ from pydantic import ConfigDict, Field
6
+
7
+ from .transform_model import TransformModel
8
+
9
+
10
+ class XYFlipModel(TransformModel):
11
+ """
12
+ Pydantic model used to represent XYFlip transformation.
13
+
14
+ Attributes
15
+ ----------
16
+ name : Literal["XYFlip"]
17
+ Name of the transformation.
18
+ p : float
19
+ Probability of applying the transform, by default 0.5.
20
+ seed : Optional[int]
21
+ Seed for the random number generator, by default None.
22
+ """
23
+
24
+ model_config = ConfigDict(
25
+ validate_assignment=True,
26
+ )
27
+
28
+ name: Literal["XYFlip"] = "XYFlip"
29
+ flip_x: bool = Field(
30
+ True,
31
+ description="Whether to flip along the X axis.",
32
+ )
33
+ flip_y: bool = Field(
34
+ True,
35
+ description="Whether to flip along the Y axis.",
36
+ )
37
+ p: float = Field(
38
+ 0.5,
39
+ description="Probability of applying the transform.",
40
+ ge=0,
41
+ le=1,
42
+ )
43
+ seed: Optional[int] = None
@@ -2,21 +2,23 @@
2
2
 
3
3
  from typing import Literal, Optional
4
4
 
5
- from pydantic import ConfigDict
5
+ from pydantic import ConfigDict, Field
6
6
 
7
7
  from .transform_model import TransformModel
8
8
 
9
9
 
10
10
  class XYRandomRotate90Model(TransformModel):
11
11
  """
12
- Pydantic model used to represent NDFlip transformation.
12
+ Pydantic model used to represent the XY random 90 degree rotation transformation.
13
13
 
14
14
  Attributes
15
15
  ----------
16
16
  name : Literal["XYRandomRotate90"]
17
17
  Name of the transformation.
18
+ p : float
19
+ Probability of applying the transform, by default 0.5.
18
20
  seed : Optional[int]
19
- Seed for the random number generator.
21
+ Seed for the random number generator, by default None.
20
22
  """
21
23
 
22
24
  model_config = ConfigDict(
@@ -24,4 +26,10 @@ class XYRandomRotate90Model(TransformModel):
24
26
  )
25
27
 
26
28
  name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
29
+ p: float = Field(
30
+ 0.5,
31
+ description="Probability of applying the transform.",
32
+ ge=0,
33
+ le=1,
34
+ )
27
35
  seed: Optional[int] = None
careamics/conftest.py CHANGED
@@ -14,6 +14,18 @@ from sybil.parsers.doctest import DocTestParser
14
14
 
15
15
  @pytest.fixture(scope="module")
16
16
  def my_path(tmpdir_factory: TempPathFactory) -> Path:
17
+ """Fixture used in doctest to create a temporary directory.
18
+
19
+ Parameters
20
+ ----------
21
+ tmpdir_factory : TempPathFactory
22
+ Temporary path factory from pytest.
23
+
24
+ Returns
25
+ -------
26
+ Path
27
+ Temporary directory path.
28
+ """
17
29
  return tmpdir_factory.mktemp("my_path")
18
30
 
19
31
 
@@ -1,4 +1,4 @@
1
- """Convenience methods for datasets."""
1
+ """Dataset utilities."""
2
2
 
3
3
  from typing import List, Tuple
4
4
 
@@ -17,12 +17,12 @@ def _get_shape_order(
17
17
 
18
18
  Parameters
19
19
  ----------
20
- shape_in : Tuple
20
+ shape_in : Tuple[int, ...]
21
21
  Input shape.
22
- ref_axes : str
23
- Reference axes.
24
22
  axes_in : str
25
23
  Input axes.
24
+ ref_axes : str
25
+ Reference axes.
26
26
 
27
27
  Returns
28
28
  -------
@@ -1,3 +1,5 @@
1
+ """File utilities."""
2
+
1
3
  from fnmatch import fnmatch
2
4
  from pathlib import Path
3
5
  from typing import List, Union
@@ -11,8 +13,7 @@ logger = get_logger(__name__)
11
13
 
12
14
 
13
15
  def get_files_size(files: List[Path]) -> float:
14
- """
15
- Get files size in MB.
16
+ """Get files size in MB.
16
17
 
17
18
  Parameters
18
19
  ----------
@@ -32,7 +33,7 @@ def list_files(
32
33
  data_type: Union[str, SupportedData],
33
34
  extension_filter: str = "",
34
35
  ) -> List[Path]:
35
- """Creates a recursive list of files in `data_path`.
36
+ """Create a recursive list of files in `data_path`.
36
37
 
37
38
  If `data_path` is a file, its name is validated against the `data_type` using
38
39
  `fnmatch`, and the method returns `data_path` itself.
@@ -1,3 +1,5 @@
1
+ """Funtions to read tiff images."""
2
+
1
3
  import logging
2
4
  from fnmatch import fnmatch
3
5
  from pathlib import Path
@@ -19,8 +21,10 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
19
21
  ----------
20
22
  file_path : Path
21
23
  Path to a file.
22
- axes : str
23
- Description of axes in format STCZYX.
24
+ *args : list
25
+ Additional arguments.
26
+ **kwargs : dict
27
+ Additional keyword arguments.
24
28
 
25
29
  Returns
26
30
  -------
@@ -1,3 +1,5 @@
1
+ """Read function utilities."""
2
+
1
3
  from typing import Callable, Union
2
4
 
3
5
  from careamics.config.support import SupportedData
@@ -1,3 +1,5 @@
1
+ """Function to read zarr images."""
2
+
1
3
  from typing import Union
2
4
 
3
5
  from zarr import Group, core, hierarchy, storage
@@ -6,26 +8,28 @@ from zarr import Group, core, hierarchy, storage
6
8
  def read_zarr(
7
9
  zarr_source: Group, axes: str
8
10
  ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
9
- """Reads a file and returns a pointer.
11
+ """Read a file and returns a pointer.
10
12
 
11
13
  Parameters
12
14
  ----------
13
- file_path : Path
14
- pathlib.Path object containing a path to a file
15
+ zarr_source : Group
16
+ Zarr storage.
17
+ axes : str
18
+ Axes of the data.
15
19
 
16
20
  Returns
17
21
  -------
18
22
  np.ndarray
19
- Pointer to zarr storage
23
+ Pointer to zarr storage.
20
24
 
21
25
  Raises
22
26
  ------
23
27
  ValueError, OSError
24
- if a file is not a valid tiff or damaged
28
+ if a file is not a valid tiff or damaged.
25
29
  ValueError
26
- if data dimensions are not 2, 3 or 4
30
+ if data dimensions are not 2, 3 or 4.
27
31
  ValueError
28
- if axes parameter from config is not consistent with data dimensions
32
+ if axes parameter from config is not consistent with data dimensions.
29
33
  """
30
34
  if isinstance(zarr_source, hierarchy.Group):
31
35
  array = zarr_source[0]