careamics 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (53) hide show
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset_ng/README.md +212 -0
  7. careamics/dataset_ng/dataset.py +233 -0
  8. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  9. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  10. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  11. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  12. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  13. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  14. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  15. careamics/dataset_ng/factory.py +408 -0
  16. careamics/dataset_ng/legacy_interoperability.py +168 -0
  17. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  18. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  19. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  20. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  21. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  22. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  23. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  24. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  25. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  26. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  27. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  28. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  29. careamics/lightning/dataset_ng/data_module.py +488 -0
  30. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  31. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  32. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  33. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  34. careamics/lightning/lightning_module.py +3 -0
  35. careamics/lvae_training/dataset/__init__.py +8 -3
  36. careamics/lvae_training/dataset/config.py +3 -3
  37. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  38. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  39. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  40. careamics/lvae_training/dataset/types.py +3 -3
  41. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  42. careamics/lvae_training/eval_utils.py +93 -3
  43. careamics/transforms/compose.py +1 -0
  44. careamics/transforms/normalize.py +18 -7
  45. careamics/utils/lightning_utils.py +25 -11
  46. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  47. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/RECORD +50 -35
  48. careamics/dataset_ng/dataset/__init__.py +0 -3
  49. careamics/dataset_ng/dataset/dataset.py +0 -184
  50. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  51. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  52. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  53. {careamics-0.0.11.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,168 @@
1
+ """
2
+ A module for utility functions that adapts the new dataset outputs to work with previous
3
+ code until it is updated.
4
+ """
5
+
6
+ from collections.abc import Sequence
7
+ from typing import cast
8
+
9
+ import numpy as np
10
+ from numpy.typing import NDArray
11
+
12
+ from careamics.config.tile_information import TileInformation
13
+
14
+ from .dataset import ImageRegionData
15
+ from .patching_strategies import TileSpecs
16
+
17
+
18
def imageregions_to_tileinfos(
    image_regions: Sequence[ImageRegionData],
) -> list[tuple[NDArray, list[TileInformation]]]:
    """
    Convert `ImageRegionData` tiles to the legacy `(data, [TileInformation])` format.

    The regions are ordered by image stack (`data_idx`) and then by sample
    (`sample_idx`). The ordering is stable, so tiles belonging to the same sample
    keep their relative input order. The final tile of every sample is flagged as
    `last_tile`, and each tile's data is always paired with the tile information
    computed from that same region.

    Parameters
    ----------
    image_regions : sequence of ImageRegionData
        The tiles to convert. Each one must have a `TileSpecs` dictionary as its
        `region_spec` field.

    Returns
    -------
    list of (numpy.ndarray, list of TileInformation)
        One tuple per input region, containing the region's data and a
        single-element list holding its tile information. The tuples are ordered by
        `data_idx` and then `sample_idx`.
    """

    def _group_key(image_region: ImageRegionData) -> tuple[int, int]:
        # (image stack index, sample index) identifies which sample a tile is from
        spec = image_region.region_spec
        return int(spec["data_idx"]), int(spec["sample_idx"])

    # `sorted` is stable: tiles of one sample end up adjacent, in input order
    ordered_regions = sorted(image_regions, key=_group_key)
    group_keys = [_group_key(image_region) for image_region in ordered_regions]
    final_position = len(ordered_regions) - 1

    converted: list[tuple[NDArray, list[TileInformation]]] = []
    for position, image_region in enumerate(ordered_regions):
        # a tile is the last of its sample when the next tile belongs to another
        # sample (or image stack), or when there is no next tile
        last_tile = (
            position == final_position
            or group_keys[position + 1] != group_keys[position]
        )
        tile_info = _imageregion_to_tileinfo(image_region, last_tile)
        # keep the data alongside the tile info derived from the same region, so
        # the pairing cannot be broken by the reordering above
        converted.append((image_region.data, [tile_info]))

    return converted
78
+
79
+
80
def _imageregion_to_tileinfo(
    image_region: ImageRegionData, last_tile: bool
) -> TileInformation:
    """
    Build a `TileInformation` from one `ImageRegionData` instance.

    The caller decides whether the tile is the last one of its sequence; this
    function cannot infer it from a single region.

    Parameters
    ----------
    image_region : ImageRegionData
        The region to convert. Its `region_spec` field must be a `TileSpecs`
        dictionary.
    last_tile : bool
        Whether the tile is the last tile in a sequence, for stitching.

    Returns
    -------
    TileInformation
        A tile information object.

    Raises
    ------
    KeyError
        If `image_region.region_spec` does not contain the keys: {'crop_coords',
        'crop_size', 'stitch_coords'}.
    """
    region_spec = image_region.region_spec
    original_shape = image_region.data_shape

    # TODO: In python 3.11 and greater, NamedTuples can inherit from Generic
    #   so the parameter could be typed as ImageRegionData[TileSpecs], making
    #   this runtime check and the cast below unnecessary.
    tile_only_keys = ("crop_coords", "crop_size", "stitch_coords")
    has_all_keys = all(key in region_spec for key in tile_only_keys)
    if not has_all_keys:
        raise KeyError(
            "Could not find all keys: {'crop_coords', 'crop_size', 'stitch_coords'} in "
            "`image_region.region_spec`."
        )

    # the keys were verified above; the cast only informs the type checker
    return _tilespec_to_tileinfo(
        cast(TileSpecs, region_spec), original_shape, last_tile
    )
124
+
125
+
126
def _tilespec_to_tileinfo(
    tile_spec: TileSpecs, data_shape: Sequence[int], last_tile: bool
) -> TileInformation:
    """
    Build a `TileInformation` from one `TileSpecs` dictionary.

    The crop and stitch start coordinates are turned into `(start, stop)` pairs,
    one pair per spatial dimension, where `stop` is `start` plus the crop size
    along that dimension.

    Parameters
    ----------
    tile_spec : TileSpecs
        A tile spec dictionary.
    data_shape : sequence of int
        The original shape of the data the tile came from, labeling the dimensions
        of axes SC(Z)YX.
    last_tile : bool
        Whether the tile is the last tile in a sequence, for stitching.

    Returns
    -------
    TileInformation
        A tile information object.
    """

    def _span(start_key: str, axis: int) -> tuple[int, int]:
        # (start, stop) along one axis; the extent is always the crop size
        start = tile_spec[start_key][axis]
        return (start, start + tile_spec["crop_size"][axis])

    crop_spans = tuple(
        _span("crop_coords", axis) for axis in range(len(tile_spec["crop_coords"]))
    )
    stitch_spans = tuple(
        _span("stitch_coords", axis) for axis in range(len(tile_spec["crop_coords"]))
    )

    # drop the leading sample dimension from the shape
    shape_without_sample = tuple(data_shape[1:])
    return TileInformation(
        array_shape=shape_without_sample,
        last_tile=last_tile,
        overlap_crop_coords=crop_spans,
        stitch_coords=stitch_spans,
        sample_id=tile_spec["sample_idx"],
    )
@@ -1,10 +1,5 @@
1
- __all__ = [
2
- "ImageStackLoader",
3
- "PatchExtractor",
4
- "create_patch_extractor",
5
- "get_image_stack_loader",
6
- ]
1
+ __all__ = ["GenericImageStack", "ImageStackLoader", "PatchExtractor"]
7
2
 
8
- from .image_stack_loader import ImageStackLoader, get_image_stack_loader
3
+ from .image_stack import GenericImageStack
4
+ from .image_stack_loader import ImageStackLoader
9
5
  from .patch_extractor import PatchExtractor
10
- from .patch_extractor_factory import create_patch_extractor
@@ -11,9 +11,11 @@ from zarr.storage import FSStore
11
11
 
12
12
  from careamics.config import DataConfig
13
13
  from careamics.config.support import SupportedData
14
- from careamics.dataset_ng.patch_extractor import create_patch_extractor
15
14
  from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
16
15
  from careamics.dataset_ng.patch_extractor.image_stack_loader import ImageStackLoader
16
+ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
17
+ create_custom_image_stack_extractor,
18
+ )
17
19
 
18
20
 
19
21
  # %%
@@ -94,12 +96,12 @@ image_stack_loader: ImageStackLoader = custom_image_stack_loader
94
96
 
95
97
  # %%
96
98
  # So pylance knows that datatype is custom to match function overloads
97
- assert data_config.data_type is SupportedData.CUSTOM
99
+ assert SupportedData(data_config.data_type) is SupportedData.CUSTOM
98
100
 
99
- patch_extractor = create_patch_extractor(
101
+ patch_extractor = create_custom_image_stack_extractor(
100
102
  source={"store": store, "data_paths": data_paths},
101
103
  axes=data_config.axes,
102
- data_type=data_config.data_type,
104
+ data_type=SupportedData(data_config.data_type),
103
105
  image_stack_loader=custom_image_stack_loader,
104
106
  )
105
107
 
@@ -1,9 +1,10 @@
1
1
  __all__ = [
2
+ "GenericImageStack",
2
3
  "ImageStack",
3
4
  "InMemoryImageStack",
4
5
  "ZarrImageStack",
5
6
  ]
6
7
 
7
- from .image_stack_protocol import ImageStack
8
+ from .image_stack_protocol import GenericImageStack, ImageStack
8
9
  from .in_memory_image_stack import InMemoryImageStack
9
10
  from .zarr_image_stack import ZarrImageStack
@@ -1,6 +1,6 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Literal, Protocol, Union
3
+ from typing import Literal, Protocol, TypeVar, Union
4
4
 
5
5
  from numpy.typing import DTypeLike, NDArray
6
6
 
@@ -51,3 +51,7 @@ class ImageStack(Protocol):
51
51
  A patch of the image data from a particular sample. It will have the
52
52
  dimensions C(Z)YX.
53
53
  """
54
+ ...
55
+
56
+
57
+ GenericImageStack = TypeVar("GenericImageStack", bound=ImageStack, covariant=True)
@@ -1,20 +1,11 @@
1
1
  from collections.abc import Sequence
2
- from pathlib import Path
3
- from typing import (
4
- Any,
5
- Optional,
6
- Protocol,
7
- Union,
8
- )
9
-
10
- from numpy.typing import NDArray
2
+ from typing import Any, Protocol
3
+
11
4
  from typing_extensions import ParamSpec
12
5
 
13
- from careamics.config.support import SupportedData
14
- from careamics.file_io.read import ReadFunc
15
6
  from careamics.utils import BaseEnum
16
7
 
17
- from .image_stack import ImageStack, InMemoryImageStack, ZarrImageStack
8
+ from .image_stack import GenericImageStack
18
9
 
19
10
  P = ParamSpec("P")
20
11
 
@@ -23,7 +14,7 @@ class SupportedDataDev(str, BaseEnum):
23
14
  ZARR = "zarr"
24
15
 
25
16
 
26
- class ImageStackLoader(Protocol[P]):
17
+ class ImageStackLoader(Protocol[P, GenericImageStack]):
27
18
  """
28
19
  Protocol to define how `ImageStacks` should be loaded.
29
20
 
@@ -76,65 +67,4 @@ class ImageStackLoader(Protocol[P]):
76
67
 
77
68
  def __call__(
78
69
  self, source: Any, axes: str, *args: P.args, **kwargs: P.kwargs
79
- ) -> Sequence[ImageStack]: ...
80
-
81
-
82
- def from_arrays(
83
- source: Sequence[NDArray], axes: str, *args, **kwargs
84
- ) -> list[InMemoryImageStack]:
85
- return [InMemoryImageStack.from_array(data=array, axes=axes) for array in source]
86
-
87
-
88
- # TODO: change source to directory path? Like in current implementation
89
- # Advantage of having a list is the user can match input and target order themselves
90
- def from_tiff_files(
91
- source: Sequence[Path], axes: str, *args, **kwargs
92
- ) -> list[InMemoryImageStack]:
93
- return [InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source]
94
-
95
-
96
- # TODO: change source to directory path? Like in current implementation
97
- # Advantage of having a list is the user can match input and target order themselves
98
- def from_custom_file_type(
99
- source: Sequence[Path],
100
- axes: str,
101
- read_func: ReadFunc,
102
- read_kwargs: dict[str, Any],
103
- *args,
104
- **kwargs,
105
- ) -> list[InMemoryImageStack]:
106
- return [
107
- InMemoryImageStack.from_custom_file_type(
108
- path=path,
109
- axes=axes,
110
- read_func=read_func,
111
- **read_kwargs,
112
- )
113
- for path in source
114
- ]
115
-
116
-
117
- def from_ome_zarr_files(
118
- source: Sequence[Path], axes: str, *args, **kwargs
119
- ) -> list[ZarrImageStack]:
120
- # NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
121
- return [ZarrImageStack.from_ome_zarr(path) for path in source]
122
-
123
-
124
- def get_image_stack_loader(
125
- data_type: Union[SupportedData, SupportedDataDev],
126
- image_stack_loader: Optional[ImageStackLoader] = None,
127
- ) -> ImageStackLoader:
128
- if data_type == SupportedData.ARRAY:
129
- return from_arrays
130
- elif data_type == SupportedData.TIFF:
131
- return from_tiff_files
132
- elif data_type == "zarr": # temp for testing until zarr is added to SupportedData
133
- return from_ome_zarr_files
134
- elif data_type == SupportedData.CUSTOM:
135
- if image_stack_loader is None:
136
- return from_custom_file_type
137
- else:
138
- return image_stack_loader
139
- else:
140
- raise ValueError
70
+ ) -> Sequence[GenericImageStack]: ...
@@ -1,17 +1,18 @@
1
1
  from collections.abc import Sequence
2
+ from typing import Generic
2
3
 
3
4
  from numpy.typing import NDArray
4
5
 
5
- from .image_stack import ImageStack
6
+ from .image_stack import GenericImageStack
6
7
 
7
8
 
8
- class PatchExtractor:
9
+ class PatchExtractor(Generic[GenericImageStack]):
9
10
  """
10
11
  A class for extracting patches from multiple image stacks.
11
12
  """
12
13
 
13
- def __init__(self, image_stacks: Sequence[ImageStack]):
14
- self.image_stacks: list[ImageStack] = list(image_stacks)
14
+ def __init__(self, image_stacks: Sequence[GenericImageStack]):
15
+ self.image_stacks: list[GenericImageStack] = list(image_stacks)
15
16
 
16
17
  def extract_patch(
17
18
  self,
@@ -1,29 +1,29 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Literal, Optional, Union, overload
3
+ from typing import Any
4
4
 
5
5
  from numpy.typing import NDArray
6
6
  from typing_extensions import ParamSpec
7
7
 
8
- from careamics.config.support import SupportedData
9
8
  from careamics.dataset_ng.patch_extractor import PatchExtractor
10
9
  from careamics.file_io.read import ReadFunc
11
10
 
11
+ from .image_stack import (
12
+ GenericImageStack,
13
+ InMemoryImageStack,
14
+ ZarrImageStack,
15
+ )
12
16
  from .image_stack_loader import (
13
17
  ImageStackLoader,
14
- SupportedDataDev,
15
- get_image_stack_loader,
16
18
  )
17
19
 
18
20
  P = ParamSpec("P")
19
21
 
20
22
 
21
- # Define overloads for each implemented ImageStackLoader case
22
23
  # Array case
23
- @overload
24
- def create_patch_extractor(
25
- source: Sequence[NDArray], axes: str, data_type: Literal[SupportedData.ARRAY]
26
- ) -> PatchExtractor:
24
+ def create_array_extractor(
25
+ source: Sequence[NDArray[Any]], axes: str
26
+ ) -> PatchExtractor[InMemoryImageStack]:
27
27
  """
28
28
  Create a patch extractor from a sequence of numpy arrays.
29
29
 
@@ -31,57 +31,78 @@ def create_patch_extractor(
31
31
  ----------
32
32
  source: sequence of numpy.ndarray
33
33
  The source arrays of the data.
34
- data_config: DataConfig
35
- The data configuration, `data_config.data_type` should have the value "array",
36
- and `data_config.axes` should describe the axes of every array in the `source`.
34
+ axes: str
35
+ The original axes of the data, must be a subset of "STCZYX".
37
36
 
38
37
  Returns
39
38
  -------
40
39
  PatchExtractor
41
40
  """
41
+ image_stacks = [
42
+ InMemoryImageStack.from_array(data=array, axes=axes) for array in source
43
+ ]
44
+ return PatchExtractor(image_stacks)
45
+
42
46
 
47
+ # TIFF case
48
+ def create_tiff_extractor(
49
+ source: Sequence[Path], axes: str
50
+ ) -> PatchExtractor[InMemoryImageStack]:
51
+ """
52
+ Create a patch extractor from a sequence of TIFF files.
43
53
 
44
- # TIFF and ZARR case
45
- @overload
46
- def create_patch_extractor(
54
+ Parameters
55
+ ----------
56
+ source: sequence of Path
57
+ The source files for the data.
58
+ axes: str
59
+ The original axes of the data, must be a subset of "STCZYX".
60
+
61
+ Returns
62
+ -------
63
+ PatchExtractor
64
+ """
65
+ image_stacks = [
66
+ InMemoryImageStack.from_tiff(path=path, axes=axes) for path in source
67
+ ]
68
+ return PatchExtractor(image_stacks)
69
+
70
+
71
+ # ZARR case
72
+ def create_ome_zarr_extractor(
47
73
  source: Sequence[Path],
48
74
  axes: str,
49
- data_type: Literal[SupportedData.TIFF, SupportedDataDev.ZARR],
50
- ) -> PatchExtractor:
75
+ ) -> PatchExtractor[ZarrImageStack]:
51
76
  """
52
- Create a patch extractor from a sequence of files that match our supported types.
77
+ Create a patch extractor from a sequence of OME ZARR files.
53
78
 
54
- Supported file types include TIFF and ZARR.
55
-
56
- If the files are ZARR files they must follow the OME standard. If you have ZARR
57
- files that do not follow the OME standard, see documentation on how to create
58
- a custom `image_stack_loader`. (TODO: Add link).
79
+ If you have ZARR files that do not follow the OME standard, see documentation on
80
+ how to create a custom `image_stack_loader`. (TODO: Add link).
59
81
 
60
82
  Parameters
61
83
  ----------
62
84
  source: sequence of Path
63
85
  The source files for the data.
64
- data_config: DataConfig
65
- The data configuration, `data_config.data_type` should have the value "tiff" or
66
- "zarr", and `data_config.axes` should describe the axes of every image in the
67
- `source`.
86
+ axes: str
87
+ The original axes of the data, must be a subset of "STCZYX".
68
88
 
69
89
  Returns
70
90
  -------
71
91
  PatchExtractor
72
92
  """
93
+ # NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
94
+ image_stacks = [ZarrImageStack.from_ome_zarr(path) for path in source]
95
+ return PatchExtractor(image_stacks)
73
96
 
74
97
 
75
98
  # Custom file type case (loaded into memory)
76
- @overload
77
- def create_patch_extractor(
78
- source: Any,
99
+ def create_custom_file_extractor(
100
+ source: Sequence[Path],
79
101
  axes: str,
80
- data_type: Literal[SupportedData.CUSTOM],
81
102
  *,
82
103
  read_func: ReadFunc,
83
104
  read_kwargs: dict[str, Any],
84
- ) -> PatchExtractor:
105
+ ) -> PatchExtractor[InMemoryImageStack]:
85
106
  """
86
107
  Create a patch extractor from a sequence of files of a custom type.
87
108
 
@@ -89,8 +110,8 @@ def create_patch_extractor(
89
110
  ----------
90
111
  source: sequence of Path
91
112
  The source files for the data.
92
- data_config: DataConfig
93
- The data configuration, `data_config.data_type` should have the value "custom".
113
+ axes: str
114
+ The original axes of the data, must be a subset of "STCZYX".
94
115
  read_func : ReadFunc
95
116
  A function to read the custom file type, see the `ReadFunc` protocol.
96
117
  read_kwargs : dict of {str: Any}
@@ -100,18 +121,28 @@ def create_patch_extractor(
100
121
  -------
101
122
  PatchExtractor
102
123
  """
124
+ # TODO: lazy loading custom files
125
+ image_stacks = [
126
+ InMemoryImageStack.from_custom_file_type(
127
+ path=path,
128
+ axes=axes,
129
+ read_func=read_func,
130
+ **read_kwargs,
131
+ )
132
+ for path in source
133
+ ]
134
+
135
+ return PatchExtractor(image_stacks)
103
136
 
104
137
 
105
138
  # Custom ImageStackLoader case
106
- @overload
107
- def create_patch_extractor(
139
+ def create_custom_image_stack_extractor(
108
140
  source: Any,
109
141
  axes: str,
110
- data_type: Literal[SupportedData.CUSTOM],
111
- image_stack_loader: ImageStackLoader[P],
142
+ image_stack_loader: ImageStackLoader[P, GenericImageStack],
112
143
  *args: P.args,
113
144
  **kwargs: P.kwargs,
114
- ) -> PatchExtractor:
145
+ ) -> PatchExtractor[GenericImageStack]:
115
146
  """
116
147
  Create a patch extractor using a custom `ImageStackLoader`.
117
148
 
@@ -127,8 +158,8 @@ def create_patch_extractor(
127
158
  ----------
128
159
  source: sequence of Path
129
160
  The source files for the data.
130
- data_config: DataConfig
131
- The data configuration, `data_config.data_type` should have the value "custom".
161
+ axes: str
162
+ The original axes of the data, must be a subset of "STCZYX".
132
163
  image_stack_loader: ImageStackLoader
133
164
  A custom image stack loader callable.
134
165
  *args: Any
@@ -140,69 +171,5 @@ def create_patch_extractor(
140
171
  -------
141
172
  PatchExtractor
142
173
  """
143
-
144
-
145
- # final overload to match the implentation function signature
146
- # Need this so it works later in the code
147
- # (bec there aren't created overloads for create_patch_extractors below)
148
- @overload
149
- def create_patch_extractor(
150
- source: Any,
151
- axes: str,
152
- data_type: Union[SupportedData, SupportedDataDev],
153
- image_stack_loader: Optional[ImageStackLoader[P]] = None,
154
- *args: P.args,
155
- **kwargs: P.kwargs,
156
- ) -> PatchExtractor: ...
157
-
158
-
159
- def create_patch_extractor(
160
- source: Any,
161
- axes: str,
162
- data_type: Union[SupportedData, SupportedDataDev],
163
- image_stack_loader: Optional[ImageStackLoader[P]] = None,
164
- *args: P.args,
165
- **kwargs: P.kwargs,
166
- ) -> PatchExtractor:
167
- # TODO: Do we need to catch data_config.data_type and source mismatches?
168
- # e.g. data_config.data_type is "array" but source is not Sequence[NDArray]
169
- loader: ImageStackLoader[P] = get_image_stack_loader(data_type, image_stack_loader)
170
- image_stacks = loader(source, axes, *args, **kwargs)
174
+ image_stacks = image_stack_loader(source, axes, *args, **kwargs)
171
175
  return PatchExtractor(image_stacks)
172
-
173
-
174
- # TODO: Remove this and just call `create_patch_extractor` within the Dataset class
175
- # Keeping for consistency for now
176
- def create_patch_extractors(
177
- source: Any,
178
- target_source: Optional[Any],
179
- axes: str,
180
- data_type: Union[SupportedData, SupportedDataDev],
181
- image_stack_loader: Optional[ImageStackLoader] = None,
182
- *args,
183
- **kwargs,
184
- ) -> tuple[PatchExtractor, Optional[PatchExtractor]]:
185
-
186
- # --- data extractor
187
- patch_extractor: PatchExtractor = create_patch_extractor(
188
- source,
189
- axes,
190
- data_type,
191
- image_stack_loader,
192
- *args,
193
- **kwargs,
194
- )
195
- # --- optional target extractor
196
- if target_source is not None:
197
- target_patch_extractor = create_patch_extractor(
198
- target_source,
199
- axes,
200
- data_type,
201
- image_stack_loader,
202
- *args,
203
- **kwargs,
204
- )
205
-
206
- return patch_extractor, target_patch_extractor
207
-
208
- return patch_extractor, None
@@ -4,8 +4,13 @@ __all__ = [
4
4
  "PatchingStrategy",
5
5
  "RandomPatchingStrategy",
6
6
  "SequentialPatchingStrategy",
7
+ "TileSpecs",
8
+ "TilingStrategy",
9
+ "WholeSamplePatchingStrategy",
7
10
  ]
8
11
 
9
- from .patching_strategy_protocol import PatchingStrategy, PatchSpecs
12
+ from .patching_strategy_protocol import PatchingStrategy, PatchSpecs, TileSpecs
10
13
  from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
11
14
  from .sequential_patching import SequentialPatchingStrategy
15
+ from .tiling_strategy import TilingStrategy
16
+ from .whole_sample import WholeSamplePatchingStrategy
@@ -28,6 +28,37 @@ class PatchSpecs(TypedDict):
28
28
  patch_size: Sequence[int]
29
29
 
30
30
 
31
class TileSpecs(PatchSpecs):
    """A dictionary that specifies a single tile in a series of `ImageStacks`.

    It extends `PatchSpecs` with the information needed to crop a tile and stitch
    it back into an image.

    Attributes
    ----------
    data_idx: int
        Which `ImageStack` the tile belongs to, within a series of `ImageStack`s.
    sample_idx: int
        Which sample the tile belongs to, within an `ImageStack`.
    coords: sequence of int
        The top-left (and first z-slice for 3D data) of the tile. The sequence has
        length 2 or 3, for 2D and 3D data respectively.
    patch_size: sequence of int
        The size of the tile. The sequence has length 2 or 3, for 2D and 3D data
        respectively.
    crop_coords: sequence of int
        The top-left side of where the tile will be cropped, in coordinates
        relative to the tile.
    crop_size: sequence of int
        The size of the cropped tile.
    stitch_coords: sequence of int
        Where the cropped tile will be stitched back into an image, in coordinates
        relative to the image.
    """

    crop_coords: Sequence[int]
    crop_size: Sequence[int]
    stitch_coords: Sequence[int]
60
+
61
+
31
62
  class PatchingStrategy(Protocol):
32
63
  """
33
64
  An interface for patching strategies.