careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Training configuration."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from pprint import pformat
|
|
5
|
+
from typing import Literal, Optional
|
|
6
|
+
|
|
7
|
+
from pydantic import (
|
|
8
|
+
BaseModel,
|
|
9
|
+
ConfigDict,
|
|
10
|
+
Field,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .callback_model import CheckpointModel, EarlyStoppingModel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TrainingConfig(BaseModel):
    """
    Parameters related to the training.

    All fields have a default value, so the model can be instantiated without
    arguments. Assignments are validated (`validate_assignment=True`).

    Attributes
    ----------
    num_epochs : int
        Number of epochs, greater than or equal to 1, by default 20.
    logger : Optional[Literal["wandb", "tensorboard"]]
        Name of the logger to use, by default None (no logger).
    checkpoint_callback : CheckpointModel
        Checkpoint callback configuration, by default `CheckpointModel()`.
    early_stopping_callback : Optional[EarlyStoppingModel]
        Early stopping callback configuration, by default None.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    num_epochs: int = Field(default=20, ge=1)

    logger: Optional[Literal["wandb", "tensorboard"]] = None

    checkpoint_callback: CheckpointModel = CheckpointModel()

    early_stopping_callback: Optional[EarlyStoppingModel] = Field(
        default=None, validate_default=True
    )
    # precision: Literal["64", "32", "16", "bf16"] = 32

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def has_logger(self) -> bool:
        """Check if the logger is defined.

        Returns
        -------
        bool
            Whether the logger is defined or not.
        """
        return self.logger is not None
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""CAREamics transformation Pydantic models."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"N2VManipulateModel",
|
|
5
|
+
"NDFlipModel",
|
|
6
|
+
"NormalizeModel",
|
|
7
|
+
"XYRandomRotate90Model",
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from .n2v_manipulate_model import N2VManipulateModel
|
|
12
|
+
from .nd_flip_model import NDFlipModel
|
|
13
|
+
from .normalize_model import NormalizeModel
|
|
14
|
+
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Pydantic model for the N2VManipulate transform."""
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
5
|
+
|
|
6
|
+
from .transform_model import TransformModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class N2VManipulateModel(TransformModel):
    """
    Pydantic model used to represent N2V manipulation.

    Attributes
    ----------
    name : Literal["N2VManipulate"]
        Name of the transformation.
    roi_size : int
        Size of the masking region, an odd value between 3 and 21, by default 11.
    masked_pixel_percentage : float
        Percentage of masked pixels, between 0.05 and 1.0, by default 0.2.
    strategy : Literal["uniform", "median"]
        Strategy pixel value replacement, by default "uniform".
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        Axis of the structN2V mask, by default "none".
    struct_mask_span : int
        Span of the structN2V mask, an odd value between 3 and 15, by default 5.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["N2VManipulate"] = "N2VManipulate"
    roi_size: int = Field(default=11, ge=3, le=21)
    masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
    strategy: Literal["uniform", "median"] = Field(default="uniform")
    struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
    struct_mask_span: int = Field(default=5, ge=3, le=15)

    @field_validator("roi_size", "struct_mask_span")
    @classmethod
    def odd_value(cls, v: int) -> int:
        """
        Validate that the value is odd.

        Applied to both `roi_size` and `struct_mask_span`.

        Parameters
        ----------
        v : int
            Value to validate.

        Returns
        -------
        int
            The validated value.

        Raises
        ------
        ValueError
            If the value is even.
        """
        if v % 2 == 0:
            raise ValueError("Size must be an odd number.")
        return v
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Pydantic model for the NDFlip transform."""
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import ConfigDict, Field
|
|
5
|
+
|
|
6
|
+
from .transform_model import TransformModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NDFlipModel(TransformModel):
    """
    Pydantic model used to represent NDFlip transformation.

    Attributes
    ----------
    name : Literal["NDFlip"]
        Name of the transformation.
    p : float
        Probability of applying the transformation, between 0 and 1, by default
        0.5.
    is_3D : bool
        Whether the transformation should be applied in 3D, by default False.
    flip_z : bool
        Whether to flip the z axis, by default True.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["NDFlip"] = "NDFlip"
    p: float = Field(default=0.5, ge=0.0, le=1.0)
    is_3D: bool = Field(default=False)
    flip_z: bool = Field(default=True)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Pydantic model for the Normalize transform."""
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import ConfigDict, Field
|
|
5
|
+
|
|
6
|
+
from .transform_model import TransformModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NormalizeModel(TransformModel):
    """
    Pydantic model used to represent Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    mean : float
        Mean value for normalization, by default 0.485.
    std : float
        Standard deviation value for normalization, by default 0.229.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["Normalize"] = "Normalize"
    mean: float = Field(default=0.485)  # albumentations defaults
    std: float = Field(default=0.229)
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Parent model for the transforms."""
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, ConfigDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TransformModel(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overwritten to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Return the model as a dictionary, without the `name` field.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel model_dump method keyword arguments.

        Returns
        -------
        Dict[str, Any]
            Dictionary representation of the model.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it may already be absent if the caller filtered
        # the dump (e.g. via `include` or `exclude`), so do not fail in that case
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import ConfigDict, Field
|
|
5
|
+
|
|
6
|
+
from .transform_model import TransformModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class XYRandomRotate90Model(TransformModel):
    """
    Pydantic model used to represent XYRandomRotate90 transformation.

    Attributes
    ----------
    name : Literal["XYRandomRotate90"]
        Name of the transformation.
    p : float
        Probability of applying the transformation, between 0 and 1, by default
        0.5.
    is_3D : bool
        Whether the transformation should be applied in 3D, by default False.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
    p: float = Field(default=0.5, ge=0.0, le=1.0)
    is_3D: bool = Field(default=False)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validator functions.
|
|
3
|
+
|
|
4
|
+
These functions are used to validate dimensions and axes of inputs.
|
|
5
|
+
"""
|
|
6
|
+
from typing import List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
_AXES = "STCZYX"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def check_axes_validity(axes: str) -> None:
    """
    Sanity check on axes.

    The constraints on the axes are the following:
    - must be a combination of 'STCZYX'
    - must not contain duplicates
    - must contain at least 2 contiguous axes: X and Y
    - must contain at most 6 axes

    Axes do not need to be in the order 'STCZYX', as this depends on the user data.
    The check is case-insensitive: the axes are converted to upper case first.

    Parameters
    ----------
    axes : str
        Axes to validate.

    Raises
    ------
    ValueError
        If the number of axes is smaller than 2 or greater than 6.
    ValueError
        If X and Y are not both present and next to each other.
    ValueError
        If an axis is not one of 'STCZYX'.
    ValueError
        If an axis appears more than once.
    """
    _axes = axes.upper()

    # Minimum is 2 (XY) and maximum is 6 (STCZYX)
    if len(_axes) < 2 or len(_axes) > 6:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
        )

    if "YX" not in _axes and "XY" not in _axes:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
        )

    # all characters must be in REF_AXES = 'STCZYX'
    if not all(s in _AXES for s in _axes):
        raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")

    # check for repeating characters: an axis is duplicated if its first and last
    # occurrences differ
    for i, s in enumerate(_axes):
        if i != _axes.rfind(s):
            raise ValueError(
                f"Invalid axes {axes}. Cannot contain duplicate axes"
                f" (got multiple {axes[i]})."
            )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def value_ge_than_8_power_of_2(
    value: int,
) -> None:
    """
    Validate that the value is greater or equal than 8 and a power of 2.

    Parameters
    ----------
    value : int
        Value to validate.

    Raises
    ------
    ValueError
        If the value is smaller than 8.
    ValueError
        If the value is not a power of 2.
    """
    if value < 8:
        raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")

    # a power of 2 has a single bit set, so clearing the lowest set bit yields 0
    if (value & (value - 1)) != 0:
        raise ValueError(f"Value must be a power of 2 (got {value}).")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def patch_size_ge_than_8_power_of_2(
    patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]],
) -> None:
    """
    Check that every patch dimension is at least 8 and a power of 2.

    Nothing is validated when `patch_list` is None.

    Parameters
    ----------
    patch_list : Optional[Union[List[int], Tuple[int, ...]]]
        Patch size, one entry per dimension.

    Raises
    ------
    ValueError
        If a patch dimension is smaller than 8.
    ValueError
        If a patch dimension is not a power of 2.
    """
    if patch_list is None:
        return

    for patch_dim in patch_list:
        value_ge_than_8_power_of_2(patch_dim)
|
careamics/conftest.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""File used to discover python modules and run doctest.
|
|
2
|
+
|
|
3
|
+
See https://sybil.readthedocs.io/en/latest/use.html#pytest
|
|
4
|
+
"""
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from pytest import TempPathFactory
|
|
9
|
+
from sybil import Sybil
|
|
10
|
+
from sybil.parsers.codeblock import PythonCodeBlockParser
|
|
11
|
+
from sybil.parsers.doctest import DocTestParser
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pytest.fixture(scope="module")
def my_path(tmpdir_factory: TempPathFactory) -> Path:
    """Create a module-scoped temporary directory with base name "my_path".

    The fixture is made available to the doctests collected by Sybil.

    NOTE(review): the `tmpdir_factory` fixture presumably returns a legacy
    `py.path.local` object rather than a `pathlib.Path` - confirm that the
    annotations match what pytest actually injects.
    """
    return tmpdir_factory.mktemp("my_path")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Pytest collection hook built by Sybil: files matching "*.py" are parsed for
# doctest examples and Python code blocks, with the `my_path` fixture listed as
# available to them.
pytest_collect_file = Sybil(
    parsers=[
        DocTestParser(),
        PythonCodeBlockParser(future_imports=["print_function"]),
    ],
    pattern="*.py",
    fixtures=["my_path"],
).pytest()
|
careamics/dataset/__init__.py
CHANGED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Files and arrays utils used in the datasets."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"reshape_array",
|
|
6
|
+
"get_files_size",
|
|
7
|
+
"list_files",
|
|
8
|
+
"validate_source_target_files",
|
|
9
|
+
"read_tiff",
|
|
10
|
+
"get_read_func",
|
|
11
|
+
"read_zarr",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from .dataset_utils import reshape_array
|
|
16
|
+
from .file_utils import get_files_size, list_files, validate_source_target_files
|
|
17
|
+
from .read_tiff import read_tiff
|
|
18
|
+
from .read_utils import get_read_func
|
|
19
|
+
from .read_zarr import read_zarr
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Convenience methods for datasets."""
|
|
2
|
+
from typing import List, Tuple
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from careamics.utils.logging import get_logger
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_shape_order(
|
|
12
|
+
shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
|
|
13
|
+
) -> Tuple[Tuple[int, ...], str, List[int]]:
|
|
14
|
+
"""
|
|
15
|
+
Compute a new shape for the array based on the reference axes.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
shape_in : Tuple
|
|
20
|
+
Input shape.
|
|
21
|
+
ref_axes : str
|
|
22
|
+
Reference axes.
|
|
23
|
+
axes_in : str
|
|
24
|
+
Input axes.
|
|
25
|
+
|
|
26
|
+
Returns
|
|
27
|
+
-------
|
|
28
|
+
Tuple[Tuple[int, ...], str, List[int]]
|
|
29
|
+
New shape, new axes, indices of axes in the new axes order.
|
|
30
|
+
"""
|
|
31
|
+
indices = [axes_in.find(k) for k in ref_axes]
|
|
32
|
+
|
|
33
|
+
# remove all non-existing axes (index == -1)
|
|
34
|
+
new_indices = list(filter(lambda k: k != -1, indices))
|
|
35
|
+
|
|
36
|
+
# find axes order and get new shape
|
|
37
|
+
new_axes = [axes_in[ind] for ind in new_indices]
|
|
38
|
+
new_shape = tuple([shape_in[ind] for ind in new_indices])
|
|
39
|
+
|
|
40
|
+
return new_shape, "".join(new_axes), new_indices
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
    """Rearrange the array axes into the (S, C, (Z), Y, X) order.

    A singleton S axis is added when S is absent. If a T axis is present, it is
    merged with S into a single leading dimension. A singleton C axis is inserted
    after S when there is no C axis.

    Parameters
    ----------
    x : np.ndarray
        Input array.
    axes : str
        Description of axes in format `STCZYX`.

    Returns
    -------
    np.ndarray
        Reshaped array with shape (S, C, (Z), Y, X).

    Raises
    ------
    ValueError
        If the number of axes does not match the number of array dimensions.
    """
    # sanity check
    if len(axes) != len(x.shape):
        raise ValueError(
            f"Incompatible data shape ({x.shape}) and axes ({axes}). Are the axes "
            f"correct?"
        )

    # shape and axes in reference order, and where each axis currently sits
    target_shape, ordered_axes, source_positions = _get_shape_order(x.shape, axes)

    array = x
    if "S" not in ordered_axes:
        # prepend a singleton S; existing axes are all shifted by one position
        ordered_axes = "S" + ordered_axes
        array = array[np.newaxis, ...]
        target_shape = (1,) + target_shape
        source_positions = [0] + [pos + 1 for pos in source_positions]

    # bring every axis to its position in the reference order
    array = np.moveaxis(array, source_positions, list(range(len(source_positions))))

    if "T" in ordered_axes:
        # S and T are the two leading axes at this point: collapse them into one
        target_shape = (-1,) + target_shape[2:]
        ordered_axes = ordered_axes.replace("T", "")
        array = array.reshape(target_shape)

    if "C" not in ordered_axes:
        # insert the singleton channel axis right after S
        array = np.expand_dims(array, ordered_axes.index("S") + 1)

    return array
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from fnmatch import fnmatch
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.support import SupportedData
|
|
8
|
+
from careamics.utils.logging import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_files_size(files: List[Path]) -> float:
    """
    Compute the cumulated size of the files, in MB.

    Parameters
    ----------
    files : List[Path]
        List of files.

    Returns
    -------
    float
        Total size of the files in MB.
    """
    bytes_per_mb = 1024**2
    sizes_in_mb = [path.stat().st_size / bytes_per_mb for path in files]

    return np.sum(sizes_in_mb)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def list_files(
    data_path: Union[str, Path],
    data_type: Union[str, SupportedData],
    extension_filter: str = "",
) -> List[Path]:
    """List the files found in `data_path`, searching folders recursively.

    If `data_path` is a file, its name is validated against the `data_type` using
    `fnmatch`, and the method returns a list containing `data_path` itself.

    The search pattern is obtained from `SupportedData.get_extension(data_type)`.
    If `data_type` is equal to `custom` and `extension_filter` is not empty,
    `extension_filter` is used as the pattern instead.

    `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
    or "*.czi".

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the folder containing the data, or to a single file.
    data_type : Union[str, SupportedData]
        One of the supported data type (e.g. tif, custom).
    extension_filter : str, optional
        Extension filter, by default "".

    Returns
    -------
    List[Path]
        List of pathlib.Path objects, sorted when `data_path` is a folder.

    Raises
    ------
    FileNotFoundError
        If the data path does not exist.
    ValueError
        If the data path is empty or no files with the extension were found.
    ValueError
        If the file does not match the requested extension.
    """
    root = Path(data_path)

    if not root.exists():
        raise FileNotFoundError(f"Data path {root} does not exist.")

    # pattern compatible with both fnmatch and rglob
    pattern = SupportedData.get_extension(data_type)
    if data_type == SupportedData.CUSTOM and extension_filter != "":
        pattern = extension_filter

    # single file: only validate its name against the pattern
    if not root.is_dir():
        if not fnmatch(str(root.absolute()), pattern):
            raise ValueError(
                f"File {root} does not match the requested extension "
                f'"{pattern}".'
            )
        return [root]

    # folder: recursive search of the files matching the pattern
    found = sorted(root.rglob(pattern))
    if len(found) == 0:
        raise ValueError(
            f'Data path {root} is empty or files with extension "{pattern}" '
            f"were not found."
        )

    return found
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
    """
    Check that source and target file lists are consistent with each other.

    Both lists must have the same length, and the set of file names (without the
    parent folders) must be the same in both lists.

    Parameters
    ----------
    src_files : List[Path]
        List of source files.
    tar_files : List[Path]
        List of target files.

    Raises
    ------
    ValueError
        If the number of files in source and target folders is not the same.
    ValueError
        If some filenames in source and target folders are not the same.
    """
    n_sources, n_targets = len(src_files), len(tar_files)
    if n_sources != n_targets:
        raise ValueError(
            f"The number of source files ({n_sources}) is not equal to the number "
            f"of target files ({n_targets})."
        )

    # names present in only one of the two lists
    source_names = {path.name for path in src_files}
    target_names = {path.name for path in tar_files}
    difference = source_names ^ target_names

    if difference:
        raise ValueError(f"Source and target files have different names: {difference}.")
|