careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (133) hide show
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,65 @@
1
+ """Training configuration."""
2
+ from __future__ import annotations
3
+
4
+ from pprint import pformat
5
+ from typing import Literal, Optional
6
+
7
+ from pydantic import (
8
+ BaseModel,
9
+ ConfigDict,
10
+ Field,
11
+ )
12
+
13
+ from .callback_model import CheckpointModel, EarlyStoppingModel
14
+
15
+
16
class TrainingModel(BaseModel):
    """
    Parameters related to the training.

    Every field has a default value, and values are validated again each time a
    field is assigned.

    Attributes
    ----------
    num_epochs : int
        Number of epochs, greater than or equal to 1, by default 20.
    logger : Optional[Literal["wandb", "tensorboard"]]
        Name of the logger, by default None (no logger).
    checkpoint_callback : CheckpointModel
        Checkpoint callback configuration, by default `CheckpointModel()`.
    early_stopping_callback : Optional[EarlyStoppingModel]
        Early stopping callback configuration, by default None.
    """

    # Pydantic class configuration
    model_config = ConfigDict(validate_assignment=True)

    num_epochs: int = Field(default=20, ge=1)

    logger: Optional[Literal["wandb", "tensorboard"]] = None

    checkpoint_callback: CheckpointModel = CheckpointModel()

    early_stopping_callback: Optional[EarlyStoppingModel] = Field(
        validate_default=True,
        default=None,
    )
    # precision: Literal["64", "32", "16", "bf16"] = 32

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty-printed dictionary dump of the model.
        """
        dumped = self.model_dump()
        return pformat(dumped)

    def has_logger(self) -> bool:
        """Check if the logger is defined.

        Returns
        -------
        bool
            True if `logger` is not None, False otherwise.
        """
        if self.logger is None:
            return False
        return True
@@ -0,0 +1,14 @@
1
+ """CAREamics transformation Pydantic models."""
2
+
3
+ __all__ = [
4
+ "N2VManipulateModel",
5
+ "NDFlipModel",
6
+ "NormalizeModel",
7
+ "XYRandomRotate90Model",
8
+ ]
9
+
10
+
11
+ from .n2v_manipulate_model import N2VManipulateModel
12
+ from .nd_flip_model import NDFlipModel
13
+ from .normalize_model import NormalizeModel
14
+ from .xy_random_rotate90_model import XYRandomRotate90Model
@@ -0,0 +1,63 @@
1
+ """Pydantic model for the N2VManipulate transform."""
2
+ from typing import Literal
3
+
4
+ from pydantic import ConfigDict, Field, field_validator
5
+
6
+ from .transform_model import TransformModel
7
+
8
+
9
class N2VManipulateModel(TransformModel):
    """
    Pydantic model describing the N2V pixel manipulation.

    Attributes
    ----------
    name : Literal["N2VManipulate"]
        Name of the transformation.
    roi_size : int
        Size of the masking region (odd, between 3 and 21), by default 11.
    masked_pixel_percentage : float
        Percentage of masked pixels (between 0.05 and 1.0), by default 0.2.
    strategy : Literal["uniform", "median"]
        Pixel value replacement strategy, by default "uniform".
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        Axis of the structN2V mask, by default "none".
    struct_mask_span : int
        Span of the structN2V mask (odd, between 3 and 15), by default 5.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["N2VManipulate"]
    roi_size: int = Field(default=11, ge=3, le=21)
    masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
    strategy: Literal["uniform", "median"] = "uniform"
    struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none"
    struct_mask_span: int = Field(default=5, ge=3, le=15)

    @field_validator("roi_size", "struct_mask_span")
    @classmethod
    def odd_value(cls, v: int) -> int:
        """
        Validate that the value is odd.

        Parameters
        ----------
        v : int
            Value to validate.

        Returns
        -------
        int
            The validated value, unchanged.

        Raises
        ------
        ValueError
            If the value is even.
        """
        if v % 2 == 1:
            return v
        raise ValueError("Size must be an odd number.")
@@ -0,0 +1,32 @@
1
+ """Pydantic model for the NDFlip transform."""
2
+ from typing import Literal
3
+
4
+ from pydantic import ConfigDict, Field
5
+
6
+ from .transform_model import TransformModel
7
+
8
+
9
class NDFlipModel(TransformModel):
    """
    Pydantic model describing the NDFlip transformation.

    Attributes
    ----------
    name : Literal["NDFlip"]
        Name of the transformation.
    p : float
        Probability of applying the transformation (between 0 and 1), by
        default 0.5.
    is_3D : bool
        Whether the transformation should be applied in 3D, by default False.
    flip_z : bool
        Whether to flip the z axis, by default True.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["NDFlip"]
    p: float = Field(default=0.5, le=1.0, ge=0.0)
    is_3D: bool = False
    flip_z: bool = True
@@ -0,0 +1,31 @@
1
+ """Pydantic model for the Normalize transform."""
2
+ from typing import Literal
3
+
4
+ from pydantic import ConfigDict, Field
5
+
6
+ from .transform_model import TransformModel
7
+
8
+
9
class NormalizeModel(TransformModel):
    """
    Pydantic model describing the Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    mean : float
        Mean value for normalization, by default 0.485.
    std : float
        Standard deviation value for normalization, by default 0.229.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["Normalize"]
    # the defaults below are the albumentations defaults
    mean: float = 0.485
    std: float = 0.229
@@ -0,0 +1,44 @@
1
+ """Parent model for the transforms."""
2
+ from typing import Any, Dict
3
+
4
+ from pydantic import BaseModel, ConfigDict
5
+
6
+
7
class TransformModel(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overridden to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Return the model as a dictionary without the `name` entry.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel model_dump method keyword arguments.

        Returns
        -------
        Dict[str, Any]
            Dictionary representation of the model, without the name field.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it may already be absent when the caller
        # filtered it out (e.g. `exclude={"name"}`), which must not raise
        model_dict.pop("name", None)

        return model_dict
@@ -0,0 +1,29 @@
1
+ """Pydantic model for the XYRandomRotate90 transform."""
2
+ from typing import Literal
3
+
4
+ from pydantic import ConfigDict, Field
5
+
6
+ from .transform_model import TransformModel
7
+
8
+
9
class XYRandomRotate90Model(TransformModel):
    """
    Pydantic model describing the XYRandomRotate90 transformation.

    Attributes
    ----------
    name : Literal["XYRandomRotate90"]
        Name of the transformation.
    p : float
        Probability of applying the transformation (between 0 and 1), by
        default 0.5.
    is_3D : bool
        Whether the transformation should be applied in 3D, by default False.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYRandomRotate90"]
    p: float = Field(default=0.5, le=1.0, ge=0.0)
    is_3D: bool = False
@@ -0,0 +1,5 @@
1
+ """Validator utilities."""
2
+
3
+ __all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"]
4
+
5
+ from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -0,0 +1,100 @@
1
+ """
2
+ Validator functions.
3
+
4
+ These functions are used to validate dimensions and axes of inputs.
5
+ """
6
+ from typing import List, Optional, Tuple, Union
7
+
8
_AXES = "STCZYX"


def check_axes_validity(axes: str) -> None:
    """
    Sanity check on axes.

    The axes are checked case-insensitively and must:
    - contain at least 2 and at most 6 axes
    - contain X and Y next to each other ("YX" or "XY")
    - be a combination of 'STCZYX'
    - not contain duplicates

    Axes do not need to be in the order 'STCZYX', as this depends on the user data.

    Parameters
    ----------
    axes : str
        Axes to validate.

    Raises
    ------
    ValueError
        If one of the constraints listed above is not respected.
    """
    upper_axes = axes.upper()

    # from 2 axes (e.g. YX) up to 6 axes (e.g. STCZYX)
    if not 2 <= len(upper_axes) <= 6:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
        )

    if not ("YX" in upper_axes or "XY" in upper_axes):
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
        )

    # every character must be one of the reference axes
    if any(char not in _AXES for char in upper_axes):
        raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")

    # report the first axis that appears more than once
    for i, char in enumerate(upper_axes):
        if upper_axes.count(char) > 1:
            raise ValueError(
                f"Invalid axes {axes}. Cannot contain duplicate axes"
                f" (got multiple {axes[i]})."
            )
53
+
54
+
55
def value_ge_than_8_power_of_2(
    value: int,
) -> None:
    """
    Validate that the value is greater or equal than 8 and a power of 2.

    Parameters
    ----------
    value : int
        Value to validate.

    Raises
    ------
    ValueError
        If the value is smaller than 8.
    ValueError
        If the value is not a power of 2.
    """
    if value < 8:
        # the message states the actual threshold: values 1 to 7 are positive
        # and non-zero, yet still rejected
        raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")

    # a power of 2 has a single bit set, hence shares no bit with `value - 1`
    if (value & (value - 1)) != 0:
        raise ValueError(f"Value must be a power of 2 (got {value}).")
78
+
79
+
80
def patch_size_ge_than_8_power_of_2(
    patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]],
) -> None:
    """
    Validate that each entry is greater or equal than 8 and a power of 2.

    Nothing is checked when `patch_list` is None.

    Parameters
    ----------
    patch_list : Optional[Union[List[int], Tuple[int, ...]]]
        Patch size.

    Raises
    ------
    ValueError
        If the patch size is smaller than 8.
    ValueError
        If the patch size is not a power of 2.
    """
    if patch_list is None:
        return

    for patch_dim in patch_list:
        value_ge_than_8_power_of_2(patch_dim)
careamics/conftest.py ADDED
@@ -0,0 +1,26 @@
1
+ """File used to discover python modules and run doctest.
2
+
3
+ See https://sybil.readthedocs.io/en/latest/use.html#pytest
4
+ """
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+ from pytest import TempPathFactory
9
+ from sybil import Sybil
10
+ from sybil.parsers.codeblock import PythonCodeBlockParser
11
+ from sybil.parsers.doctest import DocTestParser
12
+
13
+
14
@pytest.fixture(scope="module")
def my_path(tmpdir_factory: TempPathFactory) -> Path:
    """Create a module-scoped temporary directory with base name "my_path".

    NOTE(review): the requested fixture is `tmpdir_factory`, while the annotations
    refer to `TempPathFactory` and `Path`. Presumably `tmp_path_factory` was
    intended; confirm which path type the doctests rely on before changing it.

    Parameters
    ----------
    tmpdir_factory : TempPathFactory
        Temporary directory factory injected by pytest.

    Returns
    -------
    Path
        The temporary directory created by the factory.
    """
    temporary_directory = tmpdir_factory.mktemp("my_path")
    return temporary_directory
17
+
18
+
19
# Sybil collection hook: doctests and Python code blocks are parsed from every
# `*.py` file, and the `my_path` fixture is made available to them.
pytest_collect_file = Sybil(
    pattern="*.py",
    fixtures=["my_path"],
    parsers=[
        DocTestParser(),
        PythonCodeBlockParser(future_imports=["print_function"]),
    ],
).pytest()
@@ -1 +1,6 @@
1
1
  """Dataset module."""
2
+
3
+ __all__ = ["InMemoryDataset", "PathIterableDataset"]
4
+
5
+ from .in_memory_dataset import InMemoryDataset
6
+ from .iterable_dataset import PathIterableDataset
@@ -0,0 +1,19 @@
1
+ """Files and arrays utils used in the datasets."""
2
+
3
+
4
+ __all__ = [
5
+ "reshape_array",
6
+ "get_files_size",
7
+ "list_files",
8
+ "validate_source_target_files",
9
+ "read_tiff",
10
+ "get_read_func",
11
+ "read_zarr",
12
+ ]
13
+
14
+
15
+ from .dataset_utils import reshape_array
16
+ from .file_utils import get_files_size, list_files, validate_source_target_files
17
+ from .read_tiff import read_tiff
18
+ from .read_utils import get_read_func
19
+ from .read_zarr import read_zarr
@@ -0,0 +1,100 @@
1
+ """Convenience methods for datasets."""
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from careamics.utils.logging import get_logger
7
+
8
+ logger = get_logger(__name__)
9
+
10
+
11
+ def _get_shape_order(
12
+ shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
13
+ ) -> Tuple[Tuple[int, ...], str, List[int]]:
14
+ """
15
+ Compute a new shape for the array based on the reference axes.
16
+
17
+ Parameters
18
+ ----------
19
+ shape_in : Tuple
20
+ Input shape.
21
+ ref_axes : str
22
+ Reference axes.
23
+ axes_in : str
24
+ Input axes.
25
+
26
+ Returns
27
+ -------
28
+ Tuple[Tuple[int, ...], str, List[int]]
29
+ New shape, new axes, indices of axes in the new axes order.
30
+ """
31
+ indices = [axes_in.find(k) for k in ref_axes]
32
+
33
+ # remove all non-existing axes (index == -1)
34
+ new_indices = list(filter(lambda k: k != -1, indices))
35
+
36
+ # find axes order and get new shape
37
+ new_axes = [axes_in[ind] for ind in new_indices]
38
+ new_shape = tuple([shape_in[ind] for ind in new_indices])
39
+
40
+ return new_shape, "".join(new_axes), new_indices
41
+
42
+
43
def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
    """Reshape the data to (S, C, (Z), Y, X) by moving axes.

    A singleton S dimension is added if there is no S axis. If the data has a T
    axis, it is merged with the S axis. A singleton dimension is added if there is
    no C axis.

    Parameters
    ----------
    x : np.ndarray
        Input array.
    axes : str
        Description of axes in format `STCZYX`.

    Returns
    -------
    np.ndarray
        Reshaped array with shape (S, C, (Z), Y, X).

    Raises
    ------
    ValueError
        If the number of axes differs from the number of array dimensions.
    """
    # sanity check
    if len(x.shape) != len(axes):
        raise ValueError(
            f"Incompatible data shape ({x.shape}) and axes ({axes}). Are the axes "
            f"correct?"
        )

    array = x
    ordered_shape, ordered_axes, source_positions = _get_shape_order(x.shape, axes)

    # without an S axis, prepend a singleton one and shift the other positions
    if "S" not in ordered_axes:
        ordered_axes = "S" + ordered_axes
        array = np.expand_dims(array, axis=0)
        ordered_shape = (1, *ordered_shape)
        source_positions = [0, *(position + 1 for position in source_positions)]

    # reorder the dimensions by moving axes
    target_positions = list(range(len(source_positions)))
    array = np.moveaxis(array, source_positions, target_positions)

    # merge S and T into a single leading dimension
    if "T" in ordered_axes:
        ordered_shape = (-1, *ordered_shape[2:])  # drop the S and T sizes
        ordered_axes = ordered_axes.replace("T", "")
        array = array.reshape(ordered_shape)

    # add a channel axis right after S
    if "C" not in ordered_axes:
        array = np.expand_dims(array, ordered_axes.index("S") + 1)

    return array
@@ -0,0 +1,140 @@
1
+ from fnmatch import fnmatch
2
+ from pathlib import Path
3
+ from typing import List, Union
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.support import SupportedData
8
+ from careamics.utils.logging import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
def get_files_size(files: List[Path]) -> float:
    """
    Get files size in MB.

    One MB is counted as 1024**2 bytes.

    Parameters
    ----------
    files : List[Path]
        List of files.

    Returns
    -------
    float
        Total size of the files in MB.
    """
    bytes_per_mb = 1024**2
    sizes_in_mb = [file.stat().st_size / bytes_per_mb for file in files]
    return np.sum(sizes_in_mb)
28
+
29
+
30
def list_files(
    data_path: Union[str, Path],
    data_type: Union[str, SupportedData],
    extension_filter: str = "",
) -> List[Path]:
    """Create a sorted, recursive list of the paths matching in `data_path`.

    If `data_path` is a file, its absolute path is validated against the extension
    pattern using `fnmatch`, and the method returns a list holding `data_path`
    itself.

    The extension pattern is obtained from `data_type`. If `data_type` is equal to
    `custom` and `extension_filter` is not empty, `extension_filter` is used as
    pattern instead.

    `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
    or "*.czi".

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the folder containing the data, or to a single file.
    data_type : Union[str, SupportedData]
        One of the supported data type (e.g. tif, custom).
    extension_filter : str, optional
        Extension filter, by default "".

    Returns
    -------
    List[Path]
        List of pathlib.Path objects.

    Raises
    ------
    FileNotFoundError
        If the data path does not exist.
    ValueError
        If the data path is empty or no files with the extension were found.
    ValueError
        If the file does not match the requested extension.
    """
    data_path = Path(data_path)

    if not data_path.exists():
        raise FileNotFoundError(f"Data path {data_path} does not exist.")

    # pattern compatible with both fnmatch and rglob
    extension = SupportedData.get_extension(data_type)

    use_custom_filter = data_type == SupportedData.CUSTOM and extension_filter != ""
    if use_custom_filter:
        extension = extension_filter

    # single path: it only needs to match the pattern
    if not data_path.is_dir():
        if fnmatch(str(data_path.absolute()), extension):
            return [data_path]

        raise ValueError(
            f"File {data_path} does not match the requested extension "
            f'"{extension}".'
        )

    # directory: search recursively for paths matching the pattern
    files = sorted(data_path.rglob(extension))

    if not files:
        raise ValueError(
            f'Data path {data_path} is empty or files with extension "{extension}" '
            f"were not found."
        )

    return files
105
+
106
+
107
def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
    """
    Validate source and target path lists.

    The two lists should have the same number of files, and the same set of
    filenames (parent folders are ignored).

    Parameters
    ----------
    src_files : List[Path]
        List of source files.
    tar_files : List[Path]
        List of target files.

    Raises
    ------
    ValueError
        If the number of files in source and target folders is not the same.
    ValueError
        If some filenames in source and target folders are not the same.
    """
    # both lists must have the same length
    if len(src_files) != len(tar_files):
        raise ValueError(
            f"The number of source files ({len(src_files)}) is not equal to the number "
            f"of target files ({len(tar_files)})."
        )

    # names present in only one of the two lists
    difference = {src.name for src in src_files} ^ {tar.name for tar in tar_files}

    if difference:
        raise ValueError(f"Source and target files have different names: {difference}.")