careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98):
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,360 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections.abc import Sequence
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Literal
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ try:
12
+ from pylibCZIrw.czi import CziReader, Rectangle, open_czi
13
+
14
+ pyczi_available = True
15
+ except ImportError:
16
+ pyczi_available = False
17
+
18
+ if TYPE_CHECKING:
19
+ try:
20
+ from pylibCZIrw.czi import CziReader, Rectangle, open_czi
21
+ except ImportError:
22
+ CziReader = Rectangle = open_czi = None # type: ignore
23
+
24
+
25
+ class CziImageStack:
26
+ """
27
+ A class for extracting patches from an image stack that is stored as a CZI file.
28
+
29
+ Parameters
30
+ ----------
31
+ data_path : str or Path
32
+ Path to the CZI file.
33
+
34
+ scene : int, optional
35
+ Index of the scene to extract.
36
+
37
+ A single CZI file can contain multiple "scenes", which are stored alongside each
38
+ other at different coordinates in the image plane, often separated by empty
39
+ space. Specifying this argument will read only the single scene with that index
40
+ from the file. Think of it as cropping the CZI file to the region where that
41
+ scene is located.
42
+
43
+ If no scene index is specified, the entire image will be read. In case it
44
+ contains multiple scenes, they will all be present in the resulting image.
45
+ This is usually not desirable due to the empty space between them.
46
+ In general, only omit this argument or set it to `None` if you know that
47
+ your CZI file does not contain any scenes.
48
+
49
+ The static function :py:meth:`get_bounding_rectangles` can be used to find out
50
+ how many scenes a given file contains and what their bounding rectangles are.
51
+
52
+ The scene can also be provided as part of `data_path` by appending an `"@"`
53
+ followed by the scene index to the filename.
54
+
55
+ depth_axis : {"none", "Z", "T"}, default: "none"
56
+ Which axis to use as depth-axis for providing 3-D patches.
57
+
58
+ - `"none"`: Only provide 2-D patches. If a Z or T dimension is present in the
59
+ data, they will be combined into the sample dimension `S`.
60
+ - `"Z"`: Use the Z-axis as depth-axis. If a T axis is present as well, it will
61
+ be merged into the sample dimensions `S`.
62
+ - `"T"`: Use the T-axis as depth-axis. If a Z axis is present as well, it will
63
+ be merged into the sample dimensions `S`.
64
+
65
+ Attributes
66
+ ----------
67
+ source : Path
68
+ Path to the CZI file, including the scene index if specified.
69
+ data_path : Path
70
+ Path to the CZI file without scene index.
71
+ scene : int or None
72
+ Index of the scene to extract, or None if not specified.
73
+ data_shape : Sequence[int]
74
+ The shape of the data in the order `(SC(Z)YX)`.
75
+ axes : str
76
+ The axes in the CZI file corresponding to the dimensions in `data_shape`.
77
+ The following values can occur:
78
+
79
+ - "SCZYX" for 3-D volumes if `depth_axis` is `"Z"`.
80
+ - "SCTYX" for time-series if `depth_axis` is `"T"`.
81
+ - "SCYX" if `depth_axis` is `"none"`.
82
+
83
+ The axis `S` (sample) is the only one not mapping one-to-one to an axis in the
84
+ CZI file but combines all remaining axes present in the file into one.
85
+
86
+ Examples
87
+ --------
88
+ Create an image stack for the first scene in a CZI file:
89
+ >>> stack = CziImageStack("path/to/file.czi", scene=0) # doctest: +SKIP
90
+
91
+ Alternatively, the scene index can also be provided as part of the filename.
92
+ This is mainly intended for re-creating an image stack from the `source` property:
93
+ >>> stack = CziImageStack("path/to/file.czi@0") # doctest: +SKIP
94
+ >>> stack2 = CziImageStack(stack.source) # doctest: +SKIP
95
+
96
+ If the CZI file contains a third dimension (Z or T) and you want to perform 3-D
97
+ denoising, you need to explicitly set `depth_axis` to `"Z"` or `"T"`:
98
+ >>> stack_2d = CziImageStack("path/to/file.czi", scene=0) # doctest: +SKIP
99
+ >>> stack_2d.axes, stack_2d.data_shape # doctest: +SKIP
100
+ ('SCYX', [40, 1, 512, 512])
101
+ >>> stack_3d = CziImageStack( # doctest: +SKIP
102
+ ... "path/to/file.czi", scene=0, depth_axis="Z"
103
+ ... )
104
+ >>> stack_3d.axes, stack_3d.data_shape # doctest: +SKIP
105
+ ('SCZYX', [4, 1, 10, 512, 512])
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ data_path: str | Path,
111
+ scene: int | None = None,
112
+ depth_axis: Literal["none", "Z", "T"] = "none",
113
+ ) -> None:
114
+ if not pyczi_available:
115
+ raise ImportError(
116
+ "The CZI image stack requires the `pylibCZIrw` package to be installed."
117
+ " Please install it with `pip install careamics[czi]`."
118
+ )
119
+
120
+ _data_path = Path(data_path)
121
+
122
+ # Check for scene encoded in filename.
123
+ # Normally, file path and scene should be provided as separate arguments but
124
+ # we would also like to support using the `source` property to re-create the
125
+ # CZI image stack. In this case, the scene index is encoded in the file path.
126
+ scene_matches = re.match(r"^(.*)@(\d+)$", _data_path.name)
127
+ if scene_matches:
128
+ if scene is not None:
129
+ raise ValueError(
130
+ f"Scene index is specified in the filename ({_data_path.name}) and "
131
+ f"as an argument ({scene}). Please specify only one."
132
+ )
133
+ _data_path = _data_path.parent / scene_matches.group(1)
134
+ scene = int(scene_matches.group(2))
135
+
136
+ # Set variables
137
+ self.data_path = _data_path
138
+ self.scene = scene
139
+ self._depth_axis = depth_axis
140
+
141
+ # Open CZI file
142
+ self._czi = CziReader(str(self.data_path))
143
+
144
+ # Determine metadata
145
+ self.axes, self.data_shape, self._bounding_rectangle, self._sample_axes = (
146
+ self._get_shape()
147
+ )
148
+ self.data_dtype = np.float32
149
+
150
+ def __del__(self):
151
+ if hasattr(self, "_czi"):
152
+ # Close CZI file
153
+ self._czi.close()
154
+
155
+ def __getstate__(self) -> dict[str, Any]:
156
+ # Remove CziReader object from state to avoid pickling issues
157
+ state = self.__dict__.copy()
158
+ del state["_czi"]
159
+ return state
160
+
161
+ def __setstate__(self, state: dict[str, Any]) -> None:
162
+ # Reopen CZI file after unpickling
163
+ self.__dict__.update(state)
164
+ self._czi = CziReader(str(self.data_path))
165
+
166
+ # TODO: we append the scene index to the file name
167
+ # - not sure if this is a good approach
168
+ @property
169
+ def source(self) -> Path:
170
+ filename = self.data_path.name
171
+ if self.scene is not None:
172
+ filename = f"{filename}@{self.scene}"
173
+ return self.data_path.parent / filename
174
+
175
+ def extract_patch(
176
+ self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
177
+ ) -> NDArray:
178
+ # Determine 3rd dimension (T, Z or none)
179
+ if len(coords) == 3:
180
+ if len(self.axes) != 5:
181
+ raise ValueError(
182
+ f"Requested a 3D patch from a 2D image stack with axes {self.axes}."
183
+ )
184
+ third_dim = self.axes[2]
185
+ third_dim_offset, third_dim_size = coords[0], patch_size[0]
186
+ else:
187
+ if len(self.axes) != 4:
188
+ raise ValueError(
189
+ f"Requested a 2D patch from a 3D image stack with axes {self.axes}."
190
+ )
191
+ third_dim = None
192
+ third_dim_offset, third_dim_size = 0, 1
193
+
194
+ # Set up ROI to extract from each plane as (x, y, w, h)
195
+ roi = (
196
+ self._bounding_rectangle.x + coords[-1],
197
+ self._bounding_rectangle.y + coords[-2],
198
+ patch_size[-1],
199
+ patch_size[-2],
200
+ )
201
+
202
+ # Create output array of shape (C, Z, Y, X)
203
+ patch = np.empty(
204
+ (self.data_shape[1], third_dim_size, *patch_size[-2:]), dtype=np.float32
205
+ )
206
+
207
+ # Set up plane to index `sample_idx`
208
+ sample_shape = list(self._sample_axes.values())
209
+ sample_indices = np.unravel_index(sample_idx, sample_shape)
210
+ plane = {
211
+ dimension: int(index)
212
+ for dimension, index in zip(
213
+ self._sample_axes.keys(), sample_indices, strict=False
214
+ )
215
+ }
216
+
217
+ # Read XY planes sequentially
218
+ for channel in range(self.data_shape[1]):
219
+ for third_dim_index in range(third_dim_size):
220
+ plane["C"] = channel
221
+ if third_dim is not None:
222
+ plane[third_dim] = third_dim_offset + third_dim_index
223
+ extracted_roi = self._czi.read(roi=roi, plane=plane, scene=self.scene)
224
+ if extracted_roi.ndim == 3:
225
+ if extracted_roi.shape[-1] > 1:
226
+ raise ValueError(
227
+ "CZI files with RGB channels are currently not supported."
228
+ )
229
+ extracted_roi = extracted_roi.squeeze(-1)
230
+ patch[channel, third_dim_index] = extracted_roi
231
+
232
+ # Remove dummy 3rd dimension for 2-D data
233
+ if third_dim is None:
234
+ patch = patch.squeeze(1)
235
+
236
+ return patch
237
+
238
+ def _get_shape(self) -> tuple[str, list[int], Rectangle, dict[str, int]]:
239
+ """Determines the shape of the selected scene.
240
+
241
+ Returns
242
+ -------
243
+ axes : str
244
+ String specifying the axis order. Examples:
245
+
246
+ - "SCZYX" for 3-D volumes if `depth_axis` is `"Z"`.
247
+ - "SCTYX" for time-series if `depth_axis` is `"T"`.
248
+ - "SCYX" if `depth_axis` is `"none"`.
249
+
250
+ The axis `S` is the sample dimension and combines all remaining axes
251
+ present in the data.
252
+
253
+ shape : list[int]
254
+ The size of each axis, in the order listed in `axes`.
255
+
256
+ bounding_rectangle : Rectangle
257
+ The bounding rectangle of the scene in pixels. The rectangle is
258
+ defined by its top-left corner (x, y) and its width and height (w, h).
259
+
260
+ sample_axes : dict[str, int]
261
+ A dictionary with information about the remaining axes used for the
262
+ sample dimension.
263
+ The keys are the axis names (e.g., "T", "Z") and the values are their
264
+ respective sizes.
265
+ """
266
+ # Get CZI dimensions
267
+ total_bbox = self._czi.total_bounding_box_no_pyramid
268
+ if self.scene is None:
269
+ bounding_rectangle = self._czi.total_bounding_rectangle_no_pyramid
270
+ else:
271
+ bounding_rectangle = self._czi.scenes_bounding_rectangle_no_pyramid[
272
+ self.scene
273
+ ]
274
+
275
+ # Determine if T and Z axis are present
276
+ # Note: An axis of size 1 is as good as no axis since we cannot use it for 3-D
277
+ # denoising.
278
+ has_time = "T" in total_bbox and (total_bbox["T"][1] - total_bbox["T"][0]) > 1
279
+ has_depth = "Z" in total_bbox and (total_bbox["Z"][1] - total_bbox["Z"][0]) > 1
280
+
281
+ # Determine axis order depending on `depth_axis`
282
+ if self._depth_axis == "Z":
283
+ axes = "SCZYX"
284
+ if not has_depth:
285
+ raise RuntimeError(
286
+ f"The CZI file {self.data_path} does not contain a Z axis to use "
287
+ 'for 3-D denoising. Consider setting `axes="YX"` or '
288
+ '`depth_axis="none"` to perform 2-D denoising instead.'
289
+ )
290
+ elif self._depth_axis == "T":
291
+ axes = "SCTYX"
292
+ if not has_time:
293
+ raise RuntimeError(
294
+ f"The CZI file {self.data_path} does not contain a T axis to use "
295
+ 'for 3-D denoising. Consider setting `axes="YX"` or '
296
+ '`depth_axis="none"` to perform 2-D denoising instead.'
297
+ )
298
+ else:
299
+ axes = "SCYX"
300
+
301
+ # Calculcate size of sample dimension S, combining all axes not used elsewhere.
302
+ # This could, for example, be a time axis. If we only perform 2-D denoising, a
303
+ # potentially present Z axis would also be used as sample dimension. If both,
304
+ # T and Z, are present, both need to be combined into the sample dimension.
305
+ # The same needs to be done to any other potentially present axis in the CZI
306
+ # file which is not a spatial or channel axis.
307
+ # The following code calculates the size of the combined sample axis.
308
+ sample_axes = {}
309
+ sample_size = 1
310
+ for dimension, (start, end) in total_bbox.items():
311
+ if dimension not in axes:
312
+ sample_axes[dimension] = end - start
313
+ sample_size *= end - start
314
+
315
+ # Determine data shape
316
+ shape = []
317
+ for dimension in axes:
318
+ if dimension == "S":
319
+ shape.append(sample_size)
320
+ elif dimension == "Y":
321
+ shape.append(bounding_rectangle.h)
322
+ elif dimension == "X":
323
+ shape.append(bounding_rectangle.w)
324
+ elif dimension in total_bbox:
325
+ shape.append(total_bbox[dimension][1] - total_bbox[dimension][0])
326
+ else:
327
+ shape.append(1)
328
+
329
+ return axes, shape, bounding_rectangle, sample_axes
330
+
331
+ @classmethod
332
+ def get_bounding_rectangles(
333
+ cls, czi: Path | str | CziReader
334
+ ) -> dict[int | None, Rectangle]:
335
+ """Gets the bounding rectangles of all scenes in a CZI file.
336
+
337
+ Parameters
338
+ ----------
339
+ czi : Path or str or pyczi.CziReader
340
+ Path to the CZI file or an already opened file as CziReader object.
341
+
342
+ Returns
343
+ -------
344
+ dict[int | None, Rectangle]
345
+ A dictionary mapping scene indices to their bounding rectangles in the
346
+ format `(x, y, w, h)`.
347
+ If no scenes are present in the CZI file, the returned dictionary will
348
+ have only one entry with key `None`, whose bounding rectangle covers the
349
+ entire image.
350
+ """
351
+ if not isinstance(czi, CziReader):
352
+ with open_czi(str(czi)) as czi_reader:
353
+ return cls.get_bounding_rectangles(czi_reader)
354
+
355
+ scenes_bounding_rectangle = czi.scenes_bounding_rectangle_no_pyramid
356
+ if len(scenes_bounding_rectangle) >= 1:
357
+ # Ensure keys are int | None for type compatibility
358
+ return {int(k): v for k, v in scenes_bounding_rectangle.items()}
359
+ else:
360
+ return {None: czi.total_bounding_rectangle_no_pyramid}
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Literal, Protocol, Union
3
+ from typing import Literal, Protocol, TypeVar, Union
4
4
 
5
5
  from numpy.typing import DTypeLike, NDArray
6
6
 
@@ -51,3 +51,7 @@ class ImageStack(Protocol):
51
51
  A patch of the image data from a particlular sample. It will have the
52
52
  dimensions C(Z)YX.
53
53
  """
54
+ ...
55
+
56
+
57
# Type variable standing for any concrete `ImageStack` implementation. It is
# covariant, so a generic parametrised with a more specific image stack is
# accepted where one parametrised with a more general image stack is expected.
GenericImageStack = TypeVar("GenericImageStack", covariant=True, bound=ImageStack)
@@ -31,7 +31,7 @@ class InMemoryImageStack:
31
31
  (
32
32
  sample_idx, # type: ignore
33
33
  ..., # type: ignore
34
- *[slice(c, c + e) for c, e in zip(coords, patch_size)], # type: ignore
34
+ *[slice(c, c + e) for c, e in zip(coords, patch_size, strict=False)], # type: ignore
35
35
  )
36
36
  ]
37
37
 
@@ -1,20 +1,11 @@
1
1
  from collections.abc import Sequence
2
- from pathlib import Path
3
- from typing import (
4
- Any,
5
- Optional,
6
- Protocol,
7
- Union,
8
- )
9
-
10
- from numpy.typing import NDArray
2
+ from typing import Any, Protocol
3
+
11
4
  from typing_extensions import ParamSpec
12
5
 
13
- from careamics.config.support import SupportedData
14
- from careamics.file_io.read import ReadFunc
15
6
  from careamics.utils import BaseEnum
16
7
 
17
- from .image_stack import ImageStack, InMemoryImageStack, ZarrImageStack
8
+ from .image_stack import GenericImageStack
18
9
 
19
10
  P = ParamSpec("P")
20
11
 
@@ -23,7 +14,7 @@ class SupportedDataDev(str, BaseEnum):
23
14
  ZARR = "zarr"
24
15
 
25
16
 
26
- class ImageStackLoader(Protocol[P]):
17
+ class ImageStackLoader(Protocol[P, GenericImageStack]):
27
18
  """
28
19
  Protocol to define how `ImageStacks` should be loaded.
29
20
 
@@ -76,65 +67,4 @@ class ImageStackLoader(Protocol[P]):
76
67
 
77
68
  def __call__(
78
69
  self, source: Any, axes: str, *args: P.args, **kwargs: P.kwargs
79
- ) -> Sequence[ImageStack]: ...
80
-
81
-
82
- def from_arrays(
83
- source: Sequence[NDArray], axes: str, *args, **kwargs
84
- ) -> list[InMemoryImageStack]:
85
- return [InMemoryImageStack.from_array(data=array, axes=axes) for array in source]
86
-
87
-
88
- # TODO: change source to directory path? Like in current implementation
89
- # Advantage of having a list is the user can match input and target order themselves
90
- def from_tiff_files(
91
- source: Sequence[Path], axes: str, *args, **kwargs
92
- ) -> list[InMemoryImageStack]:
93
- return [InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source]
94
-
95
-
96
- # TODO: change source to directory path? Like in current implementation
97
- # Advantage of having a list is the user can match input and target order themselves
98
- def from_custom_file_type(
99
- source: Sequence[Path],
100
- axes: str,
101
- read_func: ReadFunc,
102
- read_kwargs: dict[str, Any],
103
- *args,
104
- **kwargs,
105
- ) -> list[InMemoryImageStack]:
106
- return [
107
- InMemoryImageStack.from_custom_file_type(
108
- path=path,
109
- axes=axes,
110
- read_func=read_func,
111
- **read_kwargs,
112
- )
113
- for path in source
114
- ]
115
-
116
-
117
- def from_ome_zarr_files(
118
- source: Sequence[Path], axes: str, *args, **kwargs
119
- ) -> list[ZarrImageStack]:
120
- # NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
121
- return [ZarrImageStack.from_ome_zarr(path) for path in source]
122
-
123
-
124
- def get_image_stack_loader(
125
- data_type: Union[SupportedData, SupportedDataDev],
126
- image_stack_loader: Optional[ImageStackLoader] = None,
127
- ) -> ImageStackLoader:
128
- if data_type == SupportedData.ARRAY:
129
- return from_arrays
130
- elif data_type == SupportedData.TIFF:
131
- return from_tiff_files
132
- elif data_type == "zarr": # temp for testing until zarr is added to SupportedData
133
- return from_ome_zarr_files
134
- elif data_type == SupportedData.CUSTOM:
135
- if image_stack_loader is None:
136
- return from_custom_file_type
137
- else:
138
- return image_stack_loader
139
- else:
140
- raise ValueError
70
+ ) -> Sequence[GenericImageStack]: ...
@@ -1,17 +1,18 @@
1
1
  from collections.abc import Sequence
2
+ from typing import Generic
2
3
 
3
4
  from numpy.typing import NDArray
4
5
 
5
- from .image_stack import ImageStack
6
+ from .image_stack import GenericImageStack
6
7
 
7
8
 
8
- class PatchExtractor:
9
+ class PatchExtractor(Generic[GenericImageStack]):
9
10
  """
10
11
  A class for extracting patches from multiple image stacks.
11
12
  """
12
13
 
13
- def __init__(self, image_stacks: Sequence[ImageStack]):
14
- self.image_stacks: list[ImageStack] = list(image_stacks)
14
+ def __init__(self, image_stacks: Sequence[GenericImageStack]):
15
+ self.image_stacks: list[GenericImageStack] = list(image_stacks)
15
16
 
16
17
  def extract_patch(
17
18
  self,