careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/config/transformations/xy_random_rotate90_model.py ADDED
@@ -0,0 +1,35 @@
+ """Pydantic model for the XYRandomRotate90 transform."""
+
+ from typing import Literal, Optional
+
+ from pydantic import ConfigDict, Field
+
+ from .transform_model import TransformModel
+
+
+ class XYRandomRotate90Model(TransformModel):
+     """
+     Pydantic model used to represent the XY random 90 degree rotation transformation.
+
+     Attributes
+     ----------
+     name : Literal["XYRandomRotate90"]
+         Name of the transformation.
+     p : float
+         Probability of applying the transform, by default 0.5.
+     seed : Optional[int]
+         Seed for the random number generator, by default None.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
+     p: float = Field(
+         0.5,
+         description="Probability of applying the transform.",
+         ge=0,
+         le=1,
+     )
+     seed: Optional[int] = None
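
For reference, a minimal usage sketch of this model; the direct module import path follows the file above (whether careamics.config.transformations re-exports the class is not shown in this diff):

from careamics.config.transformations.xy_random_rotate90_model import (
    XYRandomRotate90Model,
)

# default: apply the rotation with probability 0.5 and no fixed seed
transform_config = XYRandomRotate90Model()

# validate_assignment=True re-validates on assignment, so an out-of-range
# probability raises a pydantic ValidationError
transform_config.p = 0.8    # accepted, 0 <= p <= 1
transform_config.seed = 42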
careamics/config/validators/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Validator utilities."""
+
+ __all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"]
+
+ from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2
careamics/config/validators/validator_utils.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Validator functions.
+
+ These functions are used to validate dimensions and axes of inputs.
+ """
+
+ from typing import List, Optional, Tuple, Union
+
+ _AXES = "STCZYX"
+
+
+ def check_axes_validity(axes: str) -> None:
+     """
+     Sanity check on axes.
+
+     The constraints on the axes are the following:
+     - must be a combination of 'STCZYX'
+     - must not contain duplicates
+     - must contain at least 2 contiguous axes: X and Y
+     - must contain at most 4 axes
+     - cannot contain both S and T axes
+
+     Axes do not need to be in the order 'STCZYX', as this depends on the user data.
+
+     Parameters
+     ----------
+     axes : str
+         Axes to validate.
+     """
+     _axes = axes.upper()
+
+     # Minimum is 2 (XY) and maximum is 4 (TZYX)
+     if len(_axes) < 2 or len(_axes) > 6:
+         raise ValueError(
+             f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
+         )
+
+     if "YX" not in _axes and "XY" not in _axes:
+         raise ValueError(
+             f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
+         )
+
+     # all characters must be in REF_AXES = 'STCZYX'
+     if not all(s in _AXES for s in _axes):
+         raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")
+
+     # check for repeating characters
+     for i, s in enumerate(_axes):
+         if i != _axes.rfind(s):
+             raise ValueError(
+                 f"Invalid axes {axes}. Cannot contain duplicate axes"
+                 f" (got multiple {axes[i]})."
+             )
+
+
+ def value_ge_than_8_power_of_2(
+     value: int,
+ ) -> None:
+     """
+     Validate that the value is greater or equal than 8 and a power of 2.
+
+     Parameters
+     ----------
+     value : int
+         Value to validate.
+
+     Raises
+     ------
+     ValueError
+         If the value is smaller than 8.
+     ValueError
+         If the value is not a power of 2.
+     """
+     if value < 8:
+         raise ValueError(f"Value must be greater than 8 (got {value}).")
+
+     if (value & (value - 1)) != 0:
+         raise ValueError(f"Value must be a power of 2 (got {value}).")
+
+
+ def patch_size_ge_than_8_power_of_2(
+     patch_list: Optional[Union[List[int], Union[Tuple[int, ...]]]],
+ ) -> None:
+     """
+     Validate that each entry is greater or equal than 8 and a power of 2.
+
+     Parameters
+     ----------
+     patch_list : Optional[Union[List[int]]]
+         Patch size.
+
+     Raises
+     ------
+     ValueError
+         If the patch size if smaller than 8.
+     ValueError
+         If the patch size is not a power of 2.
+     """
+     if patch_list is not None:
+         for dim in patch_list:
+             value_ge_than_8_power_of_2(dim)
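
A brief sketch of these validators in action, using the exports declared in careamics/config/validators/__init__.py above (the argument values are purely illustrative):

from careamics.config.validators import (
    check_axes_validity,
    patch_size_ge_than_8_power_of_2,
)

check_axes_validity("SZYX")                 # passes: subset of STCZYX with contiguous YX
patch_size_ge_than_8_power_of_2((64, 64))   # passes: entries are >= 8 and powers of 2
patch_size_ge_than_8_power_of_2(None)       # passes: None is explicitly allowed

# check_axes_validity("QYX")                # would raise ValueError: 'Q' is not in STCZYX
# patch_size_ge_than_8_power_of_2([48, 48]) # would raise ValueError: 48 is not a power of 2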
careamics/conftest.py ADDED
@@ -0,0 +1,39 @@
+ """File used to discover python modules and run doctest.
+
+ See https://sybil.readthedocs.io/en/latest/use.html#pytest
+ """
+
+ from pathlib import Path
+
+ import pytest
+ from pytest import TempPathFactory
+ from sybil import Sybil
+ from sybil.parsers.codeblock import PythonCodeBlockParser
+ from sybil.parsers.doctest import DocTestParser
+
+
+ @pytest.fixture(scope="module")
+ def my_path(tmpdir_factory: TempPathFactory) -> Path:
+     """Fixture used in doctest to create a temporary directory.
+
+     Parameters
+     ----------
+     tmpdir_factory : TempPathFactory
+         Temporary path factory from pytest.
+
+     Returns
+     -------
+     Path
+         Temporary directory path.
+     """
+     return tmpdir_factory.mktemp("my_path")
+
+
+ pytest_collect_file = Sybil(
+     parsers=[
+         DocTestParser(),
+         PythonCodeBlockParser(future_imports=["print_function"]),
+     ],
+     pattern="*.py",
+     fixtures=["my_path"],
+ ).pytest()
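
For context, this Sybil configuration turns docstring examples across the package into tests. A docstring such as the following (purely illustrative, not part of the release) would be collected and executed by the DocTestParser, and examples that need a writable directory can request the my_path fixture:

def scale(value: float, factor: float = 2.0) -> float:
    """Scale a value by a factor.

    Examples
    --------
    >>> scale(3.0)
    6.0
    """
    return value * factor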
careamics/dataset/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """Dataset module."""
+
+ __all__ = [
+     "InMemoryDataset",
+     "InMemoryPredDataset",
+     "InMemoryTiledPredDataset",
+     "PathIterableDataset",
+     "IterableTiledPredDataset",
+     "IterablePredDataset",
+ ]
+
+ from .in_memory_dataset import InMemoryDataset
+ from .in_memory_pred_dataset import InMemoryPredDataset
+ from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
+ from .iterable_dataset import PathIterableDataset
+ from .iterable_pred_dataset import IterablePredDataset
+ from .iterable_tiled_pred_dataset import IterableTiledPredDataset
careamics/dataset/dataset_utils/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """Files and arrays utils used in the datasets."""
+
+ __all__ = [
+     "reshape_array",
+     "compute_normalization_stats",
+     "get_files_size",
+     "list_files",
+     "validate_source_target_files",
+     "iterate_over_files",
+     "WelfordStatistics",
+ ]
+
+
+ from .dataset_utils import (
+     reshape_array,
+ )
+ from .file_utils import get_files_size, list_files, validate_source_target_files
+ from .iterate_over_files import iterate_over_files
+ from .running_stats import WelfordStatistics, compute_normalization_stats
careamics/dataset/dataset_utils/dataset_utils.py ADDED
@@ -0,0 +1,101 @@
+ """Dataset utilities."""
+
+ from typing import List, Tuple
+
+ import numpy as np
+
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def _get_shape_order(
+     shape_in: Tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
+ ) -> Tuple[Tuple[int, ...], str, List[int]]:
+     """
+     Compute a new shape for the array based on the reference axes.
+
+     Parameters
+     ----------
+     shape_in : Tuple[int, ...]
+         Input shape.
+     axes_in : str
+         Input axes.
+     ref_axes : str
+         Reference axes.
+
+     Returns
+     -------
+     Tuple[Tuple[int, ...], str, List[int]]
+         New shape, new axes, indices of axes in the new axes order.
+     """
+     indices = [axes_in.find(k) for k in ref_axes]
+
+     # remove all non-existing axes (index == -1)
+     new_indices = list(filter(lambda k: k != -1, indices))
+
+     # find axes order and get new shape
+     new_axes = [axes_in[ind] for ind in new_indices]
+     new_shape = tuple([shape_in[ind] for ind in new_indices])
+
+     return new_shape, "".join(new_axes), new_indices
+
+
+ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
+     """Reshape the data to (S, C, (Z), Y, X) by moving axes.
+
+     If the data has both S and T axes, the two axes will be merged. A singleton
+     dimension is added if there are no C axis.
+
+     Parameters
+     ----------
+     x : np.ndarray
+         Input array.
+     axes : str
+         Description of axes in format `STCZYX`.
+
+     Returns
+     -------
+     np.ndarray
+         Reshaped array with shape (S, C, (Z), Y, X).
+     """
+     _x = x
+     _axes = axes
+
+     # sanity checks
+     if len(_axes) != len(_x.shape):
+         raise ValueError(
+             f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes "
+             f"correct?"
+         )
+
+     # get new x shape
+     new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes)
+
+     # if S is not in the list of axes, then add a singleton S
+     if "S" not in new_axes:
+         new_axes = "S" + new_axes
+         _x = _x[np.newaxis, ...]
+         new_x_shape = (1,) + new_x_shape
+
+         # need to change the array of indices
+         indices = [0] + [1 + i for i in indices]
+
+     # reshape by moving axes
+     destination = list(range(len(indices)))
+     _x = np.moveaxis(_x, indices, destination)
+
+     # remove T if necessary
+     if "T" in new_axes:
+         new_x_shape = (-1,) + new_x_shape[2:]  # remove T and S
+         new_axes = new_axes.replace("T", "")
+
+     # reshape S and T together
+     _x = _x.reshape(new_x_shape)
+
+     # add channel
+     if "C" not in new_axes:
+         # Add channel axis after S
+         _x = np.expand_dims(_x, new_axes.index("S") + 1)
+
+     return _x
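
To illustrate the layout enforced by reshape_array, a short sketch (array sizes are arbitrary):

import numpy as np

from careamics.dataset.dataset_utils import reshape_array

# a 3D stack without sample or channel axes: singleton S and C are added
image = np.zeros((16, 128, 128))
print(reshape_array(image, "ZYX").shape)   # (1, 1, 16, 128, 128)

# S and T axes are merged into a single sample dimension (2 * 3 = 6)
stack = np.zeros((2, 3, 128, 128))
print(reshape_array(stack, "TSYX").shape)  # (6, 1, 128, 128)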
careamics/dataset/dataset_utils/file_utils.py ADDED
@@ -0,0 +1,141 @@
+ """File utilities."""
+
+ from fnmatch import fnmatch
+ from pathlib import Path
+ from typing import List, Union
+
+ import numpy as np
+
+ from careamics.config.support import SupportedData
+ from careamics.utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def get_files_size(files: List[Path]) -> float:
+     """Get files size in MB.
+
+     Parameters
+     ----------
+     files : List[Path]
+         List of files.
+
+     Returns
+     -------
+     float
+         Total size of the files in MB.
+     """
+     return np.sum([f.stat().st_size / 1024**2 for f in files])
+
+
+ def list_files(
+     data_path: Union[str, Path],
+     data_type: Union[str, SupportedData],
+     extension_filter: str = "",
+ ) -> List[Path]:
+     """List recursively files in `data_path` and return a sorted list.
+
+     If `data_path` is a file, its name is validated against the `data_type` using
+     `fnmatch`, and the method returns `data_path` itself.
+
+     By default, if `data_type` is equal to `custom`, all files will be listed. To
+     further filter the files, use `extension_filter`.
+
+     `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g. "*.npy"
+     or "*.czi".
+
+     Parameters
+     ----------
+     data_path : Union[str, Path]
+         Path to the folder containing the data.
+     data_type : Union[str, SupportedData]
+         One of the supported data type (e.g. tif, custom).
+     extension_filter : str, optional
+         Extension filter, by default "".
+
+     Returns
+     -------
+     List[Path]
+         List of pathlib.Path objects.
+
+     Raises
+     ------
+     FileNotFoundError
+         If the data path does not exist.
+     ValueError
+         If the data path is empty or no files with the extension were found.
+     ValueError
+         If the file does not match the requested extension.
+     """
+     # convert to Path
+     data_path = Path(data_path)
+
+     # raise error if does not exists
+     if not data_path.exists():
+         raise FileNotFoundError(f"Data path {data_path} does not exist.")
+
+     # get extension compatible with fnmatch and rglob search
+     extension = SupportedData.get_extension_pattern(data_type)
+
+     if data_type == SupportedData.CUSTOM and extension_filter != "":
+         extension = extension_filter
+
+     # search recurively
+     if data_path.is_dir():
+         # search recursively the path for files with the extension
+         files = sorted(data_path.rglob(extension))
+     else:
+         # raise error if it has the wrong extension
+         if not fnmatch(str(data_path.absolute()), extension):
+             raise ValueError(
+                 f"File {data_path} does not match the requested extension "
+                 f'"{extension}".'
+             )
+
+         # save in list
+         files = [data_path]
+
+     # raise error if no files were found
+     if len(files) == 0:
+         raise ValueError(
+             f'Data path {data_path} is empty or files with extension "{extension}" '
+             f"were not found."
+         )
+
+     return files
+
+
+ def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
+     """
+     Validate source and target path lists.
+
+     The two lists should have the same number of files, and the filenames should match.
+
+     Parameters
+     ----------
+     src_files : List[Path]
+         List of source files.
+     tar_files : List[Path]
+         List of target files.
+
+     Raises
+     ------
+     ValueError
+         If the number of files in source and target folders is not the same.
+     ValueError
+         If some filenames in Train and target folders are not the same.
+     """
+     # check equal length
+     if len(src_files) != len(tar_files):
+         raise ValueError(
+             f"The number of source files ({len(src_files)}) is not equal to the number "
+             f"of target files ({len(tar_files)})."
+         )
+
+     # check identical names
+     src_names = {f.name for f in src_files}
+     tar_names = {f.name for f in tar_files}
+     difference = src_names.symmetric_difference(tar_names)
+
+     if len(difference) > 0:
+         raise ValueError(f"Source and target files have different names: {difference}.")
careamics/dataset/dataset_utils/iterate_over_files.py ADDED
@@ -0,0 +1,83 @@
+ """Function to iterate over files."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Callable, Generator, Optional, Union
+
+ from numpy.typing import NDArray
+ from torch.utils.data import get_worker_info
+
+ from careamics.config import DataConfig, InferenceConfig
+ from careamics.file_io.read import read_tiff
+ from careamics.utils.logging import get_logger
+
+ from .dataset_utils import reshape_array
+
+ logger = get_logger(__name__)
+
+
+ def iterate_over_files(
+     data_config: Union[DataConfig, InferenceConfig],
+     data_files: list[Path],
+     target_files: Optional[list[Path]] = None,
+     read_source_func: Callable = read_tiff,
+ ) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+     """Iterate over data source and yield whole reshaped images.
+
+     Parameters
+     ----------
+     data_config : CAREamics DataConfig or InferenceConfig
+         Configuration.
+     data_files : list of pathlib.Path
+         List of data files.
+     target_files : list of pathlib.Path, optional
+         List of target files, by default None.
+     read_source_func : Callable, optional
+         Function to read the source, by default read_tiff.
+
+     Yields
+     ------
+     NDArray
+         Image.
+     """
+     # When num_workers > 0, each worker process will have a different copy of the
+     # dataset object
+     # Configuring each copy independently to avoid having duplicate data returned
+     # from the workers
+     worker_info = get_worker_info()
+     worker_id = worker_info.id if worker_info is not None else 0
+     num_workers = worker_info.num_workers if worker_info is not None else 1
+
+     # iterate over the files
+     for i, filename in enumerate(data_files):
+         # retrieve file corresponding to the worker id
+         if i % num_workers == worker_id:
+             try:
+                 # read data
+                 sample = read_source_func(filename, data_config.axes)
+
+                 # reshape array
+                 reshaped_sample = reshape_array(sample, data_config.axes)
+
+                 # read target, if available
+                 if target_files is not None:
+                     if filename.name != target_files[i].name:
+                         raise ValueError(
+                             f"File {filename} does not match target file "
+                             f"{target_files[i]}. Have you passed sorted "
+                             f"arrays?"
+                         )
+
+                     # read target
+                     target = read_source_func(target_files[i], data_config.axes)
+
+                     # reshape target
+                     reshaped_target = reshape_array(target, data_config.axes)
+
+                     yield reshaped_sample, reshaped_target
+                 else:
+                     yield reshaped_sample, None
+
+             except Exception as e:
+                 logger.error(f"Error reading file {filename}: {e}")
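
Finally, a sketch of how iterate_over_files might be consumed. DataConfig is defined in careamics/config/data_model.py, which is not shown in this diff, so the constructor fields used below (other than `axes`) are assumptions:

from careamics.config import DataConfig
from careamics.dataset.dataset_utils import iterate_over_files, list_files

# hypothetical configuration and folder; field names are assumptions
config = DataConfig(data_type="tiff", axes="ZYX", patch_size=[64, 64, 64], batch_size=1)
files = list_files("data/train", "tiff")

# yields whole images reshaped to (S, C, (Z), Y, X); the second element is None
# when no target files are passed
for sample, target in iterate_over_files(config, files):
    print(sample.shape)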