careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/config/support/supported_transforms.py CHANGED
@@ -1,23 +1,11 @@
+"""Transforms supported by CAREamics."""
+
 from careamics.utils import BaseEnum
 
 
 class SupportedTransform(str, BaseEnum):
-    """Transforms officially supported by CAREamics.
-
-    - Flip: from Albumentations, randomly flip the input horizontally, vertically or
-    both, parameter `p` can be used to set the probability to apply the transform.
-    - XYRandomRotate90: #TODO
-    - Normalize # TODO add details, in particular about the parameters
-    - ManipulateN2V # TODO add details, in particular about the parameters
-    - NDFlip
-
-    Note that while any Albumentations (see https://albumentations.ai/) transform can be
-    used in CAREamics, no check are implemented to verify the compatibility of any other
-    transforms than the ones officially supported.
-    """
+    """Transforms officially supported by CAREamics."""
 
-    NDFLIP = "NDFlip"
+    XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
-    NORMALIZE = "Normalize"
     N2V_MANIPULATE = "N2VManipulate"
-    # CUSTOM = "Custom"
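The net effect is that `NDFlip` and `Normalize` disappear from the enum and `XYFlip` takes NDFlip's place. A minimal sketch of the rename (assuming `SupportedTransform` is re-exported from `careamics.config.support`, as `SupportedData` is elsewhere in this diff):

from careamics.config.support import SupportedTransform

# the former NDFlip entry is now XYFlip; Normalize is no longer an enum member
assert SupportedTransform.XY_FLIP.value == "XYFlip"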
careamics/config/tile_information.py CHANGED
@@ -1,8 +1,8 @@
-from __future__ import annotations
+"""Pydantic model representing the metadata of a prediction tile."""
 
-from typing import Optional, Tuple
+from __future__ import annotations
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator
 
 
 class TileInformation(BaseModel):
@@ -11,30 +11,33 @@ class TileInformation(BaseModel):
 
     This model is used to represent the information required to stitch back a tile into
     a larger image. It is used throughout the prediction pipeline of CAREamics.
+
+    Array shape should be (C)(Z)YX, where C and Z are optional dimensions, and must not
+    contain singleton dimensions.
     """
 
     model_config = ConfigDict(validate_default=True)
 
-    array_shape: Tuple[int, ...]
-    tiled: bool = False
+    array_shape: tuple[int, ...]
     last_tile: bool = False
-    overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
-    stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None)
+    overlap_crop_coords: tuple[tuple[int, ...], ...]
+    stitch_coords: tuple[tuple[int, ...], ...]
+    sample_id: int
 
     @field_validator("array_shape")
     @classmethod
-    def no_singleton_dimensions(cls, v: Tuple[int, ...]):
+    def no_singleton_dimensions(cls, v: tuple[int, ...]):
         """
         Check that the array shape does not have any singleton dimensions.
 
         Parameters
         ----------
-        v : Tuple[int, ...]
+        v : tuple of int
            Array shape to check.
 
        Returns
        -------
-        Tuple[int, ...]
+        tuple of int
            The array shape if it does not contain singleton dimensions.
 
        Raises
@@ -46,59 +49,26 @@ class TileInformation(BaseModel):
            raise ValueError("Array shape must not contain singleton dimensions.")
        return v
 
-    @field_validator("last_tile")
-    @classmethod
-    def only_if_tiled(cls, v: bool, values: ValidationInfo):
-        """
-        Check that the last tile flag is only set if tiling is enabled.
+    def __eq__(self, other_tile: object):
+        """Check if two tile information objects are equal.
 
        Parameters
        ----------
-        v : bool
-            Last tile flag.
-        values : ValidationInfo
-            Validation information.
+        other_tile : object
+            Tile information object to compare with.
 
        Returns
        -------
        bool
-            The last tile flag.
-        """
-        if not values.data["tiled"]:
-            return False
-        return v
-
-    @field_validator("overlap_crop_coords", "stitch_coords")
-    @classmethod
-    def mandatory_if_tiled(
-        cls, v: Optional[Tuple[int, ...]], values: ValidationInfo
-    ) -> Optional[Tuple[int, ...]]:
-        """
-        Check that the coordinates are not `None` if tiling is enabled.
-
-        The method also return `None` if tiling is not enabled.
-
-        Parameters
-        ----------
-        v : Optional[Tuple[int, ...]]
-            Coordinates to check.
-        values : ValidationInfo
-            Validation information.
-
-        Returns
-        -------
-        Optional[Tuple[int, ...]]
-            The coordinates if tiling is enabled, otherwise `None`.
-
-        Raises
-        ------
-        ValueError
-            If the coordinates are `None` and tiling is enabled.
+            Whether the two tile information objects are equal.
        """
-        if values.data["tiled"]:
-            if v is None:
-                raise ValueError("Value must be specified if tiling is enabled.")
-
-            return v
-        else:
-            return None
+        if not isinstance(other_tile, TileInformation):
+            return NotImplemented
+
+        return (
+            self.array_shape == other_tile.array_shape
+            and self.last_tile == other_tile.last_tile
+            and self.overlap_crop_coords == other_tile.overlap_crop_coords
+            and self.stitch_coords == other_tile.stitch_coords
+            and self.sample_id == other_tile.sample_id
+        )
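Taken together, the `tiled` flag and its conditional validators are gone: `overlap_crop_coords` and `stitch_coords` are now unconditionally required, a `sample_id` field is added, and equality is defined field by field. A minimal sketch with made-up values for a 2D tile:

from careamics.config.tile_information import TileInformation

tile = TileInformation(
    array_shape=(8, 8),  # (C)(Z)YX, no singleton dimensions allowed
    last_tile=False,
    overlap_crop_coords=((0, 6), (0, 6)),
    stitch_coords=((0, 6), (0, 6)),
    sample_id=0,  # new mandatory field
)

# __eq__ now compares all five fields
assert tile == TileInformation(**tile.model_dump())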
careamics/config/transformations/__init__.py CHANGED
@@ -2,13 +2,14 @@
 
 __all__ = [
     "N2VManipulateModel",
-    "NDFlipModel",
+    "XYFlipModel",
     "NormalizeModel",
     "XYRandomRotate90Model",
+    "XorYFlipModel",
 ]
 
 
 from .n2v_manipulate_model import N2VManipulateModel
-from .nd_flip_model import NDFlipModel
 from .normalize_model import NormalizeModel
+from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model
careamics/config/transformations/normalize_model.py CHANGED
@@ -1,8 +1,9 @@
 """Pydantic model for the Normalize transform."""
 
-from typing import Literal
+from typing import Literal, Optional
 
-from pydantic import ConfigDict, Field
+from pydantic import ConfigDict, Field, model_validator
+from typing_extensions import Self
 
 from .transform_model import TransformModel
 
@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
     )
 
     name: Literal["Normalize"] = "Normalize"
-    mean: float = Field(default=0.485)  # albumentations defaults
-    std: float = Field(default=0.229)
+    image_means: list = Field(..., min_length=0, max_length=32)
+    image_stds: list = Field(..., min_length=0, max_length=32)
+    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
+    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
+
+    @model_validator(mode="after")
+    def validate_means_stds(self: Self) -> Self:
+        """Validate that the means and stds have the same length.
+
+        Returns
+        -------
+        Self
+            The instance of the model.
+        """
+        if len(self.image_means) != len(self.image_stds):
+            raise ValueError("The number of image means and stds must be the same.")
+
+        if (self.target_means is None) != (self.target_stds is None):
+            raise ValueError(
+                "Both target means and stds must be provided together, or bot None."
+            )
+
+        if self.target_means is not None and self.target_stds is not None:
+            if len(self.target_means) != len(self.target_stds):
+                raise ValueError(
+                    "The number of target means and stds must be the same."
+                )
+
+        return self
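The transform thus moves from single scalar `mean`/`std` values (the old Albumentations defaults) to per-channel lists with cross-field validation. A minimal sketch with made-up statistics:

from careamics.config.transformations import NormalizeModel

# one mean/std entry per channel; lengths must match
norm = NormalizeModel(image_means=[0.5], image_stds=[0.2])

try:
    NormalizeModel(image_means=[0.5, 0.3], image_stds=[0.2])
except ValueError as err:
    print(err)  # rejected by validate_means_stds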
careamics/config/transformations/xy_flip_model.py ADDED
@@ -0,0 +1,43 @@
+"""Pydantic model for the XYFlip transform."""
+
+from typing import Literal, Optional
+
+from pydantic import ConfigDict, Field
+
+from .transform_model import TransformModel
+
+
+class XYFlipModel(TransformModel):
+    """
+    Pydantic model used to represent XYFlip transformation.
+
+    Attributes
+    ----------
+    name : Literal["XYFlip"]
+        Name of the transformation.
+    p : float
+        Probability of applying the transform, by default 0.5.
+    seed : Optional[int]
+        Seed for the random number generator, by default None.
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    name: Literal["XYFlip"] = "XYFlip"
+    flip_x: bool = Field(
+        True,
+        description="Whether to flip along the X axis.",
+    )
+    flip_y: bool = Field(
+        True,
+        description="Whether to flip along the Y axis.",
+    )
+    p: float = Field(
+        0.5,
+        description="Probability of applying the transform.",
+        ge=0,
+        le=1,
+    )
+    seed: Optional[int] = None
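A minimal usage sketch of the new model (values are arbitrary); flipping can be restricted to a single axis via `flip_x`/`flip_y`:

from careamics.config.transformations import XYFlipModel

flip = XYFlipModel(flip_x=False, flip_y=True, p=0.5, seed=42)
assert flip.name == "XYFlip"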
careamics/config/transformations/xy_random_rotate90_model.py CHANGED
@@ -2,21 +2,23 @@
 
 from typing import Literal, Optional
 
-from pydantic import ConfigDict
+from pydantic import ConfigDict, Field
 
 from .transform_model import TransformModel
 
 
 class XYRandomRotate90Model(TransformModel):
     """
-    Pydantic model used to represent NDFlip transformation.
+    Pydantic model used to represent the XY random 90 degree rotation transformation.
 
     Attributes
     ----------
     name : Literal["XYRandomRotate90"]
         Name of the transformation.
+    p : float
+        Probability of applying the transform, by default 0.5.
     seed : Optional[int]
-        Seed for the random number generator.
+        Seed for the random number generator, by default None.
     """
 
     model_config = ConfigDict(
@@ -24,4 +26,10 @@ class XYRandomRotate90Model(TransformModel):
     )
 
     name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
+    p: float = Field(
+        0.5,
+        description="Probability of applying the transform.",
+        ge=0,
+        le=1,
+    )
     seed: Optional[int] = None
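With `p` now declared via `Field(..., ge=0, le=1)`, out-of-range probabilities are rejected at construction time. A short sketch:

from careamics.config.transformations import XYRandomRotate90Model

rot = XYRandomRotate90Model(p=0.8, seed=0)

try:
    XYRandomRotate90Model(p=1.5)  # fails: p must lie in [0, 1]
except ValueError as err:
    print(err)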
careamics/config/validators/validator_utils.py CHANGED
@@ -72,7 +72,7 @@ def value_ge_than_8_power_of_2(
         If the value is not a power of 2.
     """
     if value < 8:
-        raise ValueError(f"Value must be non-zero positive (got {value}).")
+        raise ValueError(f"Value must be greater than 8 (got {value}).")
 
     if (value & (value - 1)) != 0:
         raise ValueError(f"Value must be a power of 2 (got {value}).")
careamics/conftest.py CHANGED
@@ -14,6 +14,18 @@ from sybil.parsers.doctest import DocTestParser
 
 @pytest.fixture(scope="module")
 def my_path(tmpdir_factory: TempPathFactory) -> Path:
+    """Fixture used in doctest to create a temporary directory.
+
+    Parameters
+    ----------
+    tmpdir_factory : TempPathFactory
+        Temporary path factory from pytest.
+
+    Returns
+    -------
+    Path
+        Temporary directory path.
+    """
     return tmpdir_factory.mktemp("my_path")
 
 
careamics/dataset/__init__.py CHANGED
@@ -1,6 +1,17 @@
 """Dataset module."""
 
-__all__ = ["InMemoryDataset", "PathIterableDataset"]
+__all__ = [
+    "InMemoryDataset",
+    "InMemoryPredDataset",
+    "InMemoryTiledPredDataset",
+    "PathIterableDataset",
+    "IterableTiledPredDataset",
+    "IterablePredDataset",
+]
 
 from .in_memory_dataset import InMemoryDataset
+from .in_memory_pred_dataset import InMemoryPredDataset
+from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
 from .iterable_dataset import PathIterableDataset
+from .iterable_pred_dataset import IterablePredDataset
+from .iterable_tiled_pred_dataset import IterableTiledPredDataset
careamics/dataset/dataset_utils/__init__.py CHANGED
@@ -2,17 +2,24 @@
 
 __all__ = [
     "reshape_array",
+    "compute_normalization_stats",
     "get_files_size",
     "list_files",
     "validate_source_target_files",
     "read_tiff",
     "get_read_func",
     "read_zarr",
+    "iterate_over_files",
+    "WelfordStatistics",
 ]
 
 
-from .dataset_utils import reshape_array
+from .dataset_utils import (
+    reshape_array,
+)
 from .file_utils import get_files_size, list_files, validate_source_target_files
+from .iterate_over_files import iterate_over_files
 from .read_tiff import read_tiff
 from .read_utils import get_read_func
 from .read_zarr import read_zarr
+from .running_stats import WelfordStatistics, compute_normalization_stats
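`WelfordStatistics` itself is not shown in this diff; for context, Welford's algorithm computes mean and variance in a single pass with O(1) state, which is what makes it suitable for streaming normalization statistics. A generic sketch of the technique (not the CAREamics implementation):

import numpy as np

class RunningStats:
    # single-pass mean/std via Welford's algorithm (illustrative only)
    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations

    def update(self, batch: np.ndarray) -> None:
        for x in batch.ravel():
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        return float(np.sqrt(self.m2 / self.n)) if self.n else 0.0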
careamics/dataset/dataset_utils/dataset_utils.py CHANGED
@@ -1,4 +1,4 @@
-"""Convenience methods for datasets."""
+"""Dataset utilities."""
 
 from typing import List, Tuple
 
@@ -17,12 +17,12 @@ def _get_shape_order(
 
     Parameters
     ----------
-    shape_in : Tuple
+    shape_in : Tuple[int, ...]
         Input shape.
-    ref_axes : str
-        Reference axes.
     axes_in : str
         Input axes.
+    ref_axes : str
+        Reference axes.
 
     Returns
     -------
careamics/dataset/dataset_utils/file_utils.py CHANGED
@@ -1,3 +1,5 @@
+"""File utilities."""
+
 from fnmatch import fnmatch
 from pathlib import Path
 from typing import List, Union
@@ -11,8 +13,7 @@ logger = get_logger(__name__)
 
 
 def get_files_size(files: List[Path]) -> float:
-    """
-    Get files size in MB.
+    """Get files size in MB.
 
     Parameters
     ----------
@@ -32,7 +33,7 @@ def list_files(
     data_type: Union[str, SupportedData],
     extension_filter: str = "",
 ) -> List[Path]:
-    """Creates a recursive list of files in `data_path`.
+    """List recursively files in `data_path` and return a sorted list.
 
     If `data_path` is a file, its name is validated against the `data_type` using
     `fnmatch`, and the method returns `data_path` itself.
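A brief usage sketch (hypothetical directory; assuming `"tiff"` is a valid `SupportedData` value, as the tiff reader elsewhere in this diff suggests):

from pathlib import Path
from careamics.dataset.dataset_utils import list_files

# recursively collect tiff files under data/, returned as a sorted list
files = list_files(Path("data"), "tiff")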
careamics/dataset/dataset_utils/iterate_over_files.py ADDED
@@ -0,0 +1,83 @@
+"""Function to iterate over files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable, Generator, Optional, Union
+
+from numpy.typing import NDArray
+from torch.utils.data import get_worker_info
+
+from careamics.config import DataConfig, InferenceConfig
+from careamics.utils.logging import get_logger
+
+from .dataset_utils import reshape_array
+from .read_tiff import read_tiff
+
+logger = get_logger(__name__)
+
+
+def iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: list[Path],
+    target_files: Optional[list[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+    """Iterate over data source and yield whole reshaped images.
+
+    Parameters
+    ----------
+    data_config : CAREamics DataConfig or InferenceConfig
+        Configuration.
+    data_files : list of pathlib.Path
+        List of data files.
+    target_files : list of pathlib.Path, optional
+        List of target files, by default None.
+    read_source_func : Callable, optional
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    NDArray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # reshape array
+                reshaped_sample = reshape_array(sample, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    # reshape target
+                    reshaped_target = reshape_array(target, data_config.axes)
+
+                    yield reshaped_sample, reshaped_target
+                else:
+                    yield reshaped_sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
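The `i % num_workers == worker_id` check shards files round-robin across DataLoader workers, so each worker's copy of an iterable dataset yields a disjoint subset of the files. A self-contained sketch of the same pattern (illustrative, not CAREamics code):

from torch.utils.data import IterableDataset, get_worker_info

class FileStream(IterableDataset):
    # yields one item per file, sharded round-robin across workers
    def __init__(self, files: list) -> None:
        self.files = files

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i, f in enumerate(self.files):
            if i % num_workers == worker_id:
                yield f

# with two workers, worker 0 sees files 0, 2, ... and worker 1 sees files 1, 3, ...
print(list(FileStream(["a.tif", "b.tif", "c.tif"])))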
careamics/dataset/dataset_utils/read_tiff.py CHANGED
@@ -1,3 +1,5 @@
+"""Funtions to read tiff images."""
+
 import logging
 from fnmatch import fnmatch
 from pathlib import Path
@@ -19,8 +21,10 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ----------
     file_path : Path
         Path to a file.
-    axes : str
-        Description of axes in format STCZYX.
+    *args : list
+        Additional arguments.
+    **kwargs : dict
+        Additional keyword arguments.
 
     Returns
     -------
@@ -49,13 +53,4 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     else:
         raise ValueError(f"File {file_path} is not a valid tiff.")
 
-    # check dimensions
-    # TODO or should this really be done here? probably in the LightningDataModule
-    # TODO this should also be centralized somewhere else (validate_dimensions)
-    if len(array.shape) < 2 or len(array.shape) > 6:
-        raise ValueError(
-            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape} for"
-            f"file {file_path})."
-        )
-
     return array
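The removed TODO notes suggest dimension validation belongs elsewhere (a centralized validator or the data module), so `read_tiff` now only loads and returns the array. A minimal call sketch (hypothetical path):

from pathlib import Path
from careamics.dataset.dataset_utils import read_tiff

image = read_tiff(Path("sample.tiff"))  # extra *args/**kwargs are accepted but unused
print(image.shape)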
careamics/dataset/dataset_utils/read_utils.py CHANGED
@@ -1,3 +1,5 @@
+"""Read function utilities."""
+
 from typing import Callable, Union
 
 from careamics.config.support import SupportedData
careamics/dataset/dataset_utils/read_zarr.py CHANGED
@@ -1,3 +1,5 @@
+"""Function to read zarr images."""
+
 from typing import Union
 
 from zarr import Group, core, hierarchy, storage
@@ -6,26 +8,28 @@ from zarr import Group, core, hierarchy, storage
 def read_zarr(
     zarr_source: Group, axes: str
 ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
-    """Reads a file and returns a pointer.
+    """Read a file and returns a pointer.
 
     Parameters
     ----------
-    file_path : Path
-        pathlib.Path object containing a path to a file
+    zarr_source : Group
+        Zarr storage.
+    axes : str
+        Axes of the data.
 
     Returns
     -------
     np.ndarray
-        Pointer to zarr storage
+        Pointer to zarr storage.
 
     Raises
     ------
     ValueError, OSError
-        if a file is not a valid tiff or damaged
+        if a file is not a valid tiff or damaged.
     ValueError
-        if data dimensions are not 2, 3 or 4
+        if data dimensions are not 2, 3 or 4.
     ValueError
-        if axes parameter from config is not consistent with data dimensions
+        if axes parameter from config is not consistent with data dimensions.
     """
     if isinstance(zarr_source, hierarchy.Group):
         array = zarr_source[0]