careamics 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/in_memory_dataset.py +3 -10
  18. careamics/dataset/in_memory_pred_dataset.py +3 -5
  19. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  20. careamics/dataset/iterable_dataset.py +2 -2
  21. careamics/dataset/iterable_pred_dataset.py +3 -5
  22. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  23. careamics/dataset_ng/dataset/__init__.py +3 -0
  24. careamics/dataset_ng/dataset/dataset.py +184 -0
  25. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  26. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  27. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  28. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  29. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  30. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  34. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  35. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  37. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  38. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  39. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  40. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  41. careamics/lightning/lightning_module.py +78 -27
  42. careamics/lightning/train_data_module.py +8 -39
  43. careamics/losses/fcn/losses.py +17 -10
  44. careamics/lvae_training/eval_utils.py +21 -8
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
  56. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.8.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,163 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import zarr
6
+ import zarr.storage
7
+ from numpy.typing import NDArray
8
+ from typing_extensions import Self
9
+
10
+ from careamics.dataset.dataset_utils import reshape_array
11
+
12
+
13
class ZarrImageStack:
    """
    A class for extracting patches from an image stack that is stored as a zarr array.

    The array is opened read-only. `data_shape` is the shape the data would have
    after being reshaped to SC(Z)YX, where a T axis (if present) is merged into the
    sample dimension, giving it size S*T.
    """

    # TODO: keeping store type narrow so that it has the path attribute
    # base zarr store is zarr.storage.Store, includes MemoryStore
    def __init__(self, store: zarr.storage.FSStore, data_path: str, axes: str):
        """
        Parameters
        ----------
        store : zarr.storage.FSStore
            The zarr store containing the array.
        data_path : str
            Path of the array within `store`.
        axes : str
            Axes of the array as it is stored.
        """
        self._store = store
        self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
        # TODO: validate axes
        # - must contain XY
        # - must be subset of STCZYX
        self._original_axes = axes
        self._original_data_shape: tuple[int, ...] = self._array.shape
        self.data_shape = _reshaped_array_shape(axes, self._original_data_shape)
        self.data_dtype = self._array.dtype

    # TODO: not sure if this is useful
    # TODO: potential solution using different metadata class for each ImageStack type
    # - see #399
    @property
    def source(self) -> Path:
        """The store's path joined with the array's path inside the store."""
        return Path(self._store.path) / self._array.path

    # automatically finds axes from metadata
    # based on implementation in ome-zarr python package
    # https://github.com/ome/ome-zarr-py/blob/f7096b0f2c1fc8edf4d7304e33caf8d279d99dbb/ome_zarr/reader.py#L294-L316
    @classmethod
    def from_ome_zarr(cls, path: Union[Path, str]) -> Self:
        """
        Create a `ZarrImageStack` from an OME-Zarr.

        Will only use the first resolution in the hierarchy.

        Assumes the path only contains 1 image.

        Path can be to a local file, or it can be a URL to a zarr stored in the cloud.

        Parameters
        ----------
        path : pathlib.Path or str
            Local path or URL of the OME-Zarr.

        Returns
        -------
        ZarrImageStack
            Image stack for the first resolution level, with axes read from the
            "multiscales" metadata.

        Raises
        ------
        ValueError
            If the zarr group does not have a "multiscales" attribute.
        """
        # FSStore expects a string URL; str() is a no-op for URLs given as str
        store = zarr.storage.FSStore(url=str(path))
        group = zarr.open_group(store=store, mode="r")
        if "multiscales" not in group.attrs:
            raise ValueError(
                f"Zarr at path '{path}' cannot be loaded as an OME-Zarr because it "
                "does not contain the attribute 'multiscales'."
            )
        # TODO: why is this a list of length 1? 0 index also in ome-zarr-python
        # https://github.com/ome/ome-zarr-py/blob/f7096b0f2c1fc8edf4d7304e33caf8d279d99dbb/ome_zarr/reader.py#L286
        multiscales_metadata = group.attrs["multiscales"][0]

        # get axes: OME-NGFF v0.4+ stores axis objects with a "name" key, while
        # v0.3 stores the axis names directly as strings
        axes_list = [
            axes_data["name"] if isinstance(axes_data, dict) else axes_data
            for axes_data in multiscales_metadata["axes"]
        ]
        axes = "".join(axes_list).upper()

        first_multiscale_path = multiscales_metadata["datasets"][0]["path"]

        return cls(store=store, data_path=first_multiscale_path, axes=axes)

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """
        Extract a patch from the sample `sample_idx`.

        Parameters
        ----------
        sample_idx : int
            Index along the merged S*T sample dimension (`data_shape[0]`).
            Negative indices count from the end.
        coords : sequence of int
            Start of the patch, in (Z)YX order.
        patch_size : sequence of int
            Size of the patch, in (Z)YX order.

        Returns
        -------
        numpy.ndarray
            The patch, reshaped with `reshape_array` and with its leading sample
            dimension removed (presumably C(Z)YX - confirm against `reshape_array`).

        Raises
        ------
        IndexError
            If `sample_idx` is out of bounds for the merged sample dimension.
        ValueError
            If the original axes contain an axis outside STCZYX.
        """
        # original axes assumed to be any subset of STCZYX (containing YX), in any order
        # arguments must be transformed to index data in original axes order
        # to do this: loop through original axes and append correct index/slice
        # for each case: STCZYX
        # Note: if any axis is not present in original_axes it is skipped.

        # Bounds check on the merged S*T dimension. Without it, data with a T axis
        # but no S axis would silently wrap around (sample_idx % T). With neither S
        # nor T, n_samples is 1, so only 0 and -1 are accepted.
        n_samples = self.data_shape[0]
        if not -n_samples <= sample_idx < n_samples:
            raise IndexError(
                f"Sample index {sample_idx} out of bounds for S axes with size "
                f"{n_samples}"
            )

        patch_slice: list[Union[int, slice]] = []
        for d in self._original_axes:
            if d == "S":
                patch_slice.append(self._get_S_index(sample_idx))
            elif d == "T":
                patch_slice.append(self._get_T_index(sample_idx))
            elif d == "C":
                patch_slice.append(slice(None, None))
            elif d == "Z":
                patch_slice.append(slice(coords[0], coords[0] + patch_size[0]))
            elif d == "Y":
                y_idx = 0 if "Z" not in self._original_axes else 1
                patch_slice.append(
                    slice(coords[y_idx], coords[y_idx] + patch_size[y_idx])
                )
            elif d == "X":
                x_idx = 1 if "Z" not in self._original_axes else 2
                patch_slice.append(
                    slice(coords[x_idx], coords[x_idx] + patch_size[x_idx])
                )
            else:
                raise ValueError(f"Unrecognised axis '{d}', axes should be in STCZYX.")

        patch = self._array[tuple(patch_slice)]
        # S and T were indexed with integers, so they are no longer in the patch
        patch_axes = self._original_axes.replace("S", "").replace("T", "")
        return reshape_array(patch, patch_axes)[0]  # remove first sample dim

    def _get_T_index(self, sample_idx: int) -> int:
        """
        Get T index given `sample_idx`.

        Raises
        ------
        ValueError
            If the original axes have no T axis.
        """
        if "T" not in self._original_axes:
            raise ValueError("No 'T' axis specified in original data axes.")
        axis_idx = self._original_axes.index("T")
        dim = self._original_data_shape[axis_idx]

        # S and T are merged as S' = S*T (S-major), so for a merged index S_idx':
        # S_idx = S_idx' // T_size  (floor divide finds the row)
        # T_idx = S_idx' % T_size   (modulus finds how far along the row)
        return sample_idx % dim

    def _get_S_index(self, sample_idx: int) -> int:
        """
        Get S index given `sample_idx`.

        Raises
        ------
        ValueError
            If the original axes have no S axis.
        """
        if "S" not in self._original_axes:
            raise ValueError("No 'S' axis specified in original data axes.")
        if "T" in self._original_axes:
            T_axis_idx = self._original_axes.index("T")
            T_dim = self._original_data_shape[T_axis_idx]

            # S and T are merged as S' = S*T (S-major), so for a merged index S_idx':
            # S_idx = S_idx' // T_size  (floor divide finds the row)
            # T_idx = S_idx' % T_size   (modulus finds how far along the row)
            return sample_idx // T_dim
        else:
            return sample_idx
143
+
144
+
145
+ # TODO: move to dataset_utils, better name?
146
+ def _reshaped_array_shape(axes: str, shape: Sequence[int]) -> tuple[int, ...]:
147
+ """Find resulting shape if reshaping array with given `axes` and `shape`."""
148
+ target_axes = "SCZYX"
149
+ target_shape = []
150
+ for d in target_axes:
151
+ if d in axes:
152
+ idx = axes.index(d)
153
+ target_shape.append(shape[idx])
154
+ elif (d != axes) and (d != "Z"):
155
+ target_shape.append(1)
156
+ else:
157
+ pass
158
+
159
+ if "T" in axes:
160
+ idx = axes.index("T")
161
+ target_shape[0] = target_shape[0] * shape[idx]
162
+
163
+ return tuple(target_shape)
@@ -0,0 +1,140 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import (
4
+ Any,
5
+ Optional,
6
+ Protocol,
7
+ Union,
8
+ )
9
+
10
+ from numpy.typing import NDArray
11
+ from typing_extensions import ParamSpec
12
+
13
+ from careamics.config.support import SupportedData
14
+ from careamics.file_io.read import ReadFunc
15
+ from careamics.utils import BaseEnum
16
+
17
+ from .image_stack import ImageStack, InMemoryImageStack, ZarrImageStack
18
+
19
# ParamSpec capturing the extra `*args`/`**kwargs` accepted by an `ImageStackLoader`.
P = ParamSpec("P")


class SupportedDataDev(str, BaseEnum):
    """
    Data types that are supported in development but not yet part of `SupportedData`.

    Currently only ZARR, which is handled as a temporary case until zarr is added
    to `SupportedData`.
    """

    ZARR = "zarr"
24
+
25
+
26
class ImageStackLoader(Protocol[P]):
    """
    Protocol to define how `ImageStacks` should be loaded.

    An `ImageStackLoader` is a callable that must take the `source` of the data as the
    first argument, and the data `axes` as the second argument.

    Additional `*args` and `**kwargs` are allowed, but they should only be used to
    determine _how_ the data is loaded, not _what_ data is loaded. The `source`
    argument has to wholly determine _what_ data is loaded. This is because,
    downstream, both an input-source and a target-source have to be specified, but
    they will share `*args` and `**kwargs`.

    An `ImageStackLoader` must return a sequence of `ImageStack` objects. This could
    be a sequence of one of the existing concrete implementations, such as
    `ZarrImageStack`, or a custom user defined `ImageStack`.

    Example
    -------
    The following example demonstrates how an `ImageStackLoader` could be defined
    for loading non-OME Zarr images, returning a list of `ZarrImageStack` instances.

    >>> from collections.abc import Sequence
    >>> from typing import TypedDict

    >>> from zarr.storage import FSStore

    >>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack

    >>> # Define a zarr source
    >>> # It encompasses multiple arguments that determine what data will be loaded
    >>> class ZarrSource(TypedDict):
    ...     store: FSStore
    ...     data_paths: Sequence[str]

    >>> def custom_image_stack_loader(
    ...     source: ZarrSource, axes: str, *args, **kwargs
    ... ) -> list[ZarrImageStack]:
    ...     image_stacks = [
    ...         ZarrImageStack(store=source["store"], data_path=data_path, axes=axes)
    ...         for data_path in source["data_paths"]
    ...     ]
    ...     return image_stacks

    TODO: show example use in the `CAREamicsDataset`

    The example above defines a `ZarrSource` dict because, to determine _which_ ZARR
    images will be loaded, both a ZARR store and the internal data paths need to be
    specified.
    """

    def __call__(
        self, source: Any, axes: str, *args: P.args, **kwargs: P.kwargs
    ) -> Sequence[ImageStack]:
        """Load the image stacks described by `source`, with the given `axes`."""
        ...
80
+
81
+
82
def from_arrays(
    source: Sequence[NDArray], axes: str, *args, **kwargs
) -> list[InMemoryImageStack]:
    """
    Wrap each array of `source` in an `InMemoryImageStack`.

    Parameters
    ----------
    source : sequence of numpy.ndarray
        The arrays to wrap.
    axes : str
        Axes shared by every array in `source`.
    *args, **kwargs
        Accepted for `ImageStackLoader` compatibility; unused.

    Returns
    -------
    list of InMemoryImageStack
        One image stack per array, in the same order as `source`.
    """
    image_stacks: list[InMemoryImageStack] = []
    for array in source:
        image_stacks.append(InMemoryImageStack.from_array(data=array, axes=axes))
    return image_stacks
86
+
87
+
88
# TODO: change source to directory path? Like in current implementation
# Advantage of having a list is the user can match input and target order themselves
def from_tiff_files(
    source: Sequence[Path], axes: str, *args, **kwargs
) -> list[InMemoryImageStack]:
    """
    Load each TIFF file of `source` into an `InMemoryImageStack`.

    Parameters
    ----------
    source : sequence of pathlib.Path
        The TIFF files to load.
    axes : str
        Axes shared by every file in `source`.
    *args, **kwargs
        Accepted for `ImageStackLoader` compatibility; unused.

    Returns
    -------
    list of InMemoryImageStack
        One image stack per file, in the same order as `source`.
    """
    image_stacks: list[InMemoryImageStack] = []
    for tiff_path in source:
        image_stacks.append(InMemoryImageStack.from_tiff(path=tiff_path, axes=axes))
    return image_stacks
94
+
95
+
96
# TODO: change source to directory path? Like in current implementation
# Advantage of having a list is the user can match input and target order themselves
def from_custom_file_type(
    source: Sequence[Path],
    axes: str,
    read_func: ReadFunc,
    read_kwargs: dict[str, Any],
    *args,
    **kwargs,
) -> list[InMemoryImageStack]:
    """
    Load files of a custom type into memory, one `InMemoryImageStack` per file.

    Parameters
    ----------
    source : sequence of pathlib.Path
        The files to load.
    axes : str
        Axes shared by every file in `source`.
    read_func : ReadFunc
        Function used to read each file.
    read_kwargs : dict of {str: Any}
        Extra keyword arguments forwarded to
        `InMemoryImageStack.from_custom_file_type` for every file.
    *args, **kwargs
        Accepted for `ImageStackLoader` compatibility; unused.

    Returns
    -------
    list of InMemoryImageStack
        One image stack per file, in the same order as `source`.
    """
    image_stacks: list[InMemoryImageStack] = []
    for file_path in source:
        image_stack = InMemoryImageStack.from_custom_file_type(
            path=file_path, axes=axes, read_func=read_func, **read_kwargs
        )
        image_stacks.append(image_stack)
    return image_stacks
115
+
116
+
117
def from_ome_zarr_files(
    source: Sequence[Path], axes: str, *args, **kwargs
) -> list[ZarrImageStack]:
    """
    Open each OME-Zarr of `source` as a `ZarrImageStack`.

    Parameters
    ----------
    source : sequence of pathlib.Path
        Paths (or URLs) of the OME-Zarrs.
    axes : str
        Ignored: the axes are read from each OME-Zarr's metadata.
    *args, **kwargs
        Accepted for `ImageStackLoader` compatibility; unused.

    Returns
    -------
    list of ZarrImageStack
        One image stack per OME-Zarr, in the same order as `source`.
    """
    # NOTE: `axes` is deliberately unused; `ZarrImageStack.from_ome_zarr` retrieves
    # the axes automatically from the metadata.
    return list(map(ZarrImageStack.from_ome_zarr, source))
122
+
123
+
124
def get_image_stack_loader(
    data_type: Union[SupportedData, SupportedDataDev],
    image_stack_loader: Optional[ImageStackLoader] = None,
) -> ImageStackLoader:
    """
    Select the `ImageStackLoader` to use for a given data type.

    Parameters
    ----------
    data_type : SupportedData or SupportedDataDev
        The type of the data.
    image_stack_loader : ImageStackLoader, optional
        A custom loader. It is only used when `data_type` is `SupportedData.CUSTOM`
        and is ignored for every other data type.

    Returns
    -------
    ImageStackLoader
        `from_arrays` for ARRAY, `from_tiff_files` for TIFF and
        `from_ome_zarr_files` for ZARR. For CUSTOM, `image_stack_loader` if given,
        otherwise `from_custom_file_type`.

    Raises
    ------
    ValueError
        If `data_type` is not one of the data types listed above.
    """
    if data_type == SupportedData.ARRAY:
        return from_arrays
    elif data_type == SupportedData.TIFF:
        return from_tiff_files
    elif data_type == SupportedDataDev.ZARR:
        # temp for testing until zarr is added to SupportedData
        return from_ome_zarr_files
    elif data_type == SupportedData.CUSTOM:
        if image_stack_loader is None:
            return from_custom_file_type
        else:
            return image_stack_loader
    else:
        raise ValueError(
            f"Data type {data_type!r} is not supported; no image stack loader is "
            "available for it."
        )
@@ -0,0 +1,29 @@
1
+ from collections.abc import Sequence
2
+
3
+ from numpy.typing import NDArray
4
+
5
+ from .image_stack import ImageStack
6
+
7
+
8
class PatchExtractor:
    """
    Extract patches from a collection of image stacks.

    A patch is taken from the image stack at position `data_idx` in the
    collection; the extraction itself is delegated to that stack's
    `extract_patch` method.

    Parameters
    ----------
    image_stacks : sequence of ImageStack
        The image stacks to extract patches from; they are copied into a list.
    """

    def __init__(self, image_stacks: Sequence[ImageStack]):
        self.image_stacks: list[ImageStack] = [*image_stacks]

    def extract_patch(
        self,
        data_idx: int,
        sample_idx: int,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a single patch.

        Parameters
        ----------
        data_idx : int
            Index of the image stack to extract from.
        sample_idx : int
            Sample index within the selected image stack.
        coords : sequence of int
            Start coordinates of the patch, forwarded to the image stack.
        patch_size : sequence of int
            Size of the patch, forwarded to the image stack.

        Returns
        -------
        numpy.ndarray
            The patch returned by the selected image stack.
        """
        stack = self.image_stacks[data_idx]
        return stack.extract_patch(
            sample_idx=sample_idx, coords=coords, patch_size=patch_size
        )

    @property
    def shape(self):
        """List of the `data_shape` of every image stack, in order."""
        shapes = []
        for stack in self.image_stacks:
            shapes.append(stack.data_shape)
        return shapes
@@ -0,0 +1,208 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Any, Literal, Optional, Union, overload
4
+
5
+ from numpy.typing import NDArray
6
+ from typing_extensions import ParamSpec
7
+
8
+ from careamics.config.support import SupportedData
9
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
10
+ from careamics.file_io.read import ReadFunc
11
+
12
+ from .image_stack_loader import (
13
+ ImageStackLoader,
14
+ SupportedDataDev,
15
+ get_image_stack_loader,
16
+ )
17
+
18
# ParamSpec capturing the extra `*args`/`**kwargs` forwarded to an `ImageStackLoader`.
P = ParamSpec("P")
19
+
20
+
21
# Define overloads for each implemented ImageStackLoader case
# Array case
@overload
def create_patch_extractor(
    source: Sequence[NDArray], axes: str, data_type: Literal[SupportedData.ARRAY]
) -> PatchExtractor:
    """
    Create a patch extractor from a sequence of numpy arrays.

    Parameters
    ----------
    source : sequence of numpy.ndarray
        The source arrays of the data.
    axes : str
        The axes of every array in `source`.
    data_type : SupportedData.ARRAY
        The data type, must be `SupportedData.ARRAY`.

    Returns
    -------
    PatchExtractor
    """


# TIFF and ZARR case
@overload
def create_patch_extractor(
    source: Sequence[Path],
    axes: str,
    data_type: Literal[SupportedData.TIFF, SupportedDataDev.ZARR],
) -> PatchExtractor:
    """
    Create a patch extractor from a sequence of files that match our supported types.

    Supported file types include TIFF and ZARR.

    If the files are ZARR files they must follow the OME standard. If you have ZARR
    files that do not follow the OME standard, see documentation on how to create
    a custom `image_stack_loader`. (TODO: Add link).

    Parameters
    ----------
    source : sequence of Path
        The source files for the data.
    axes : str
        The axes of every image in `source`. For OME-Zarr files it is not used, the
        axes are read from the metadata.
    data_type : SupportedData.TIFF or SupportedDataDev.ZARR
        The data type, must be `SupportedData.TIFF` or `SupportedDataDev.ZARR`.

    Returns
    -------
    PatchExtractor
    """


# Custom file type case (loaded into memory)
@overload
def create_patch_extractor(
    source: Any,
    axes: str,
    data_type: Literal[SupportedData.CUSTOM],
    *,
    read_func: ReadFunc,
    read_kwargs: dict[str, Any],
) -> PatchExtractor:
    """
    Create a patch extractor from a sequence of files of a custom type.

    Parameters
    ----------
    source : sequence of Path
        The source files for the data.
    axes : str
        The axes of every image in `source`.
    data_type : SupportedData.CUSTOM
        The data type, must be `SupportedData.CUSTOM`.
    read_func : ReadFunc
        A function to read the custom file type, see the `ReadFunc` protocol.
    read_kwargs : dict of {str: Any}
        Kwargs that will be passed to the custom `read_func`.

    Returns
    -------
    PatchExtractor
    """


# Custom ImageStackLoader case
@overload
def create_patch_extractor(
    source: Any,
    axes: str,
    data_type: Literal[SupportedData.CUSTOM],
    image_stack_loader: ImageStackLoader[P],
    *args: P.args,
    **kwargs: P.kwargs,
) -> PatchExtractor:
    """
    Create a patch extractor using a custom `ImageStackLoader`.

    The custom image stack loader must follow the `ImageStackLoader` protocol, i.e.
    it must have the following function signature:
    ```
    def image_loader_example(
        source: Any, axes: str, *args, **kwargs
    ) -> Sequence[ImageStack]:
    ```

    Parameters
    ----------
    source : Any
        The source of the data, in whatever form `image_stack_loader` expects.
    axes : str
        The axes of the data, passed on to `image_stack_loader`.
    data_type : SupportedData.CUSTOM
        The data type, must be `SupportedData.CUSTOM`.
    image_stack_loader : ImageStackLoader
        A custom image stack loader callable.
    *args : Any
        Positional arguments that will be passed to the custom image stack loader.
    **kwargs : Any
        Keyword arguments that will be passed to the custom image stack loader.

    Returns
    -------
    PatchExtractor
    """


# final overload to match the implementation function signature
# Need this so it works later in the code
# (because there aren't created overloads for create_patch_extractors below)
@overload
def create_patch_extractor(
    source: Any,
    axes: str,
    data_type: Union[SupportedData, SupportedDataDev],
    image_stack_loader: Optional[ImageStackLoader[P]] = None,
    *args: P.args,
    **kwargs: P.kwargs,
) -> PatchExtractor: ...
157
+
158
+
159
def create_patch_extractor(
    source: Any,
    axes: str,
    data_type: Union[SupportedData, SupportedDataDev],
    image_stack_loader: Optional[ImageStackLoader[P]] = None,
    *args: P.args,
    **kwargs: P.kwargs,
) -> PatchExtractor:
    """
    Load `source` into image stacks and wrap them in a `PatchExtractor`.

    The loader is chosen by `get_image_stack_loader` from `data_type` (and
    `image_stack_loader` for custom data). It is then called with `source`, `axes`
    and the remaining `*args`/`**kwargs`.
    """
    # TODO: should mismatches between `data_type` and the type of `source` be
    # caught here? e.g. data_type is "array" but source is not Sequence[NDArray]
    loader: ImageStackLoader[P] = get_image_stack_loader(data_type, image_stack_loader)
    return PatchExtractor(loader(source, axes, *args, **kwargs))
172
+
173
+
174
# TODO: Remove this and just call `create_patch_extractor` within the Dataset class
# Keeping for consistency for now
def create_patch_extractors(
    source: Any,
    target_source: Optional[Any],
    axes: str,
    data_type: Union[SupportedData, SupportedDataDev],
    image_stack_loader: Optional[ImageStackLoader] = None,
    *args,
    **kwargs,
) -> tuple[PatchExtractor, Optional[PatchExtractor]]:
    """
    Create the input patch extractor and, optionally, a target patch extractor.

    Both extractors are built with `create_patch_extractor`, sharing `axes`,
    `data_type`, `image_stack_loader`, `*args` and `**kwargs`. The input
    extractor is created first.

    Returns
    -------
    tuple of (PatchExtractor, PatchExtractor or None)
        The input extractor, and the target extractor, or `None` when
        `target_source` is `None`.
    """
    # --- data extractor
    patch_extractor: PatchExtractor = create_patch_extractor(
        source, axes, data_type, image_stack_loader, *args, **kwargs
    )

    # --- optional target extractor
    target_patch_extractor: Optional[PatchExtractor] = None
    if target_source is not None:
        target_patch_extractor = create_patch_extractor(
            target_source, axes, data_type, image_stack_loader, *args, **kwargs
        )

    return patch_extractor, target_patch_extractor
@@ -0,0 +1,11 @@
1
"""
Public API of the patching strategies subpackage.

Exports the `PatchingStrategy` protocol, the `PatchSpecs` type and the concrete
patching strategies.
"""

__all__ = [
    "FixedRandomPatchingStrategy",
    "PatchSpecs",
    "PatchingStrategy",
    "RandomPatchingStrategy",
    "SequentialPatchingStrategy",
]
8
+
9
+ from .patching_strategy_protocol import PatchingStrategy, PatchSpecs
10
+ from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
11
+ from .sequential_patching import SequentialPatchingStrategy
@@ -0,0 +1,82 @@
1
+ """A module to contain type definitions relating to patching strategies."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Protocol, TypedDict
5
+
6
+
7
class PatchSpecs(TypedDict):
    """A dictionary that specifies a single patch in a series of `ImageStacks`.

    Attributes
    ----------
    data_idx : int
        Determines which `ImageStack` a patch belongs to, within a series of
        `ImageStack`s.
    sample_idx : int
        Determines which sample a patch belongs to, within an `ImageStack`.
    coords : sequence of int
        The coordinates of the top-left corner of the patch (and, for 3D data, of
        its first z-slice). The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    """

    data_idx: int
    sample_idx: int
    coords: Sequence[int]
    patch_size: Sequence[int]
29
+
30
+
31
class PatchingStrategy(Protocol):
    """
    Interface that every patching strategy must implement.

    A patching strategy is used by the `CAREamicsDataset` to decide which patches
    are extracted from the underlying data.

    Attributes
    ----------
    n_patches : int
        How many patches the strategy returns.

    Methods
    -------
    get_patch_spec(index: int) -> PatchSpecs
        Return the specification of the patch at `index`.
    """

    @property
    def n_patches(self) -> int:
        """
        How many patches the strategy returns.

        This value is also the upper bound (exclusive) on the indices accepted by
        `get_patch_spec`, and the length of the `CAREamicsDataset`.

        Returns
        -------
        int
            The number of patches.
        """
        ...

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the specification of the patch at `index`.

        Meant to be called from `CAREamicsDataset.__getitem__`, which forwards its
        own index to this method.

        Parameters
        ----------
        index : int
            Index of the patch.

        Returns
        -------
        PatchSpecs
            Dictionary describing one patch within a series of `ImageStacks`.
        """
        ...
+ ...