careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98)
  1. careamics/careamist.py +24 -7
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +55 -4
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +41 -4
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/optimizer_models.py +1 -3
  20. careamics/config/support/supported_data.py +7 -0
  21. careamics/config/support/supported_patching_strategies.py +22 -0
  22. careamics/config/training_model.py +0 -2
  23. careamics/config/validators/validator_utils.py +4 -3
  24. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  25. careamics/dataset/in_memory_dataset.py +2 -1
  26. careamics/dataset/iterable_dataset.py +2 -2
  27. careamics/dataset/iterable_pred_dataset.py +2 -2
  28. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  29. careamics/dataset/patching/patching.py +3 -2
  30. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  31. careamics/dataset/tiling/tiled_patching.py +2 -1
  32. careamics/dataset_ng/README.md +212 -0
  33. careamics/dataset_ng/dataset.py +229 -0
  34. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  35. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  36. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  37. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  38. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
  39. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  40. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  41. careamics/dataset_ng/factory.py +451 -0
  42. careamics/dataset_ng/legacy_interoperability.py +170 -0
  43. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  44. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
  45. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
  46. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  47. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  48. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  49. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  50. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  51. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
  52. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  53. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  54. careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
  55. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  56. careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
  57. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  58. careamics/file_io/read/get_func.py +2 -1
  59. careamics/lightning/dataset_ng/__init__.py +1 -0
  60. careamics/lightning/dataset_ng/data_module.py +678 -0
  61. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  62. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  63. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  64. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
  65. careamics/lightning/lightning_module.py +5 -1
  66. careamics/lightning/predict_data_module.py +2 -1
  67. careamics/lightning/train_data_module.py +2 -1
  68. careamics/losses/loss_factory.py +2 -1
  69. careamics/lvae_training/dataset/__init__.py +8 -3
  70. careamics/lvae_training/dataset/config.py +3 -3
  71. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  72. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  73. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  74. careamics/lvae_training/dataset/types.py +3 -3
  75. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  76. careamics/lvae_training/eval_utils.py +93 -3
  77. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  78. careamics/model_io/bioimage/model_description.py +1 -1
  79. careamics/model_io/bmz_io.py +1 -1
  80. careamics/model_io/model_io_utils.py +2 -2
  81. careamics/models/activation.py +2 -1
  82. careamics/prediction_utils/prediction_outputs.py +1 -1
  83. careamics/prediction_utils/stitch_prediction.py +1 -1
  84. careamics/transforms/compose.py +1 -0
  85. careamics/transforms/n2v_manipulate_torch.py +15 -9
  86. careamics/transforms/normalize.py +18 -7
  87. careamics/transforms/pixel_manipulation_torch.py +59 -92
  88. careamics/utils/lightning_utils.py +25 -11
  89. careamics/utils/metrics.py +2 -1
  90. careamics/utils/torch_utils.py +23 -0
  91. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
  92. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
  93. careamics/dataset_ng/dataset/__init__.py +0 -3
  94. careamics/dataset_ng/dataset/dataset.py +0 -184
  95. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  96. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,29 +1,30 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Literal, Optional, Union, overload
3
+ from typing import Any, Literal
4
4
 
5
5
  from numpy.typing import NDArray
6
6
  from typing_extensions import ParamSpec
7
7
 
8
- from careamics.config.support import SupportedData
9
8
  from careamics.dataset_ng.patch_extractor import PatchExtractor
10
9
  from careamics.file_io.read import ReadFunc
11
10
 
11
+ from .image_stack import (
12
+ CziImageStack,
13
+ GenericImageStack,
14
+ InMemoryImageStack,
15
+ ZarrImageStack,
16
+ )
12
17
  from .image_stack_loader import (
13
18
  ImageStackLoader,
14
- SupportedDataDev,
15
- get_image_stack_loader,
16
19
  )
17
20
 
18
21
  P = ParamSpec("P")
19
22
 
20
23
 
21
- # Define overloads for each implemented ImageStackLoader case
22
24
  # Array case
23
- @overload
24
- def create_patch_extractor(
25
- source: Sequence[NDArray], axes: str, data_type: Literal[SupportedData.ARRAY]
26
- ) -> PatchExtractor:
25
+ def create_array_extractor(
26
+ source: Sequence[NDArray[Any]], axes: str
27
+ ) -> PatchExtractor[InMemoryImageStack]:
27
28
  """
28
29
  Create a patch extractor from a sequence of numpy arrays.
29
30
 
@@ -31,57 +32,119 @@ def create_patch_extractor(
31
32
  ----------
32
33
  source: sequence of numpy.ndarray
33
34
  The source arrays of the data.
34
- data_config: DataConfig
35
- The data configuration, `data_config.data_type` should have the value "array",
36
- and `data_config.axes` should describe the axes of every array in the `source`.
35
+ axes: str
36
+ The original axes of the data, must be a subset of "STCZYX".
37
37
 
38
38
  Returns
39
39
  -------
40
40
  PatchExtractor
41
41
  """
42
+ image_stacks = [
43
+ InMemoryImageStack.from_array(data=array, axes=axes) for array in source
44
+ ]
45
+ return PatchExtractor(image_stacks)
46
+
47
+
48
+ # TIFF case
49
def create_tiff_extractor(
    source: Sequence[Path], axes: str
) -> PatchExtractor[InMemoryImageStack]:
    """
    Build a patch extractor over a sequence of TIFF files.

    Each file is loaded with `InMemoryImageStack.from_tiff`, using the same `axes`
    for every file.

    Parameters
    ----------
    source: sequence of Path
        The TIFF files to load.
    axes: str
        The original axes of the data, must be a subset of "STCZYX".

    Returns
    -------
    PatchExtractor
        A patch extractor holding one image stack per file, in `source` order.
    """
    load_tiff = InMemoryImageStack.from_tiff
    image_stacks = [load_tiff(path=tiff_path, axes=axes) for tiff_path in source]
    return PatchExtractor(image_stacks)
43
70
 
44
- # TIFF and ZARR case
45
- @overload
46
- def create_patch_extractor(
71
+
72
+ # ZARR case
73
+ def create_ome_zarr_extractor(
47
74
  source: Sequence[Path],
48
75
  axes: str,
49
- data_type: Literal[SupportedData.TIFF, SupportedDataDev.ZARR],
50
- ) -> PatchExtractor:
76
+ ) -> PatchExtractor[ZarrImageStack]:
77
+ """
78
+ Create a patch extractor from a sequence of OME ZARR files.
79
+
80
+ If you have ZARR files that do not follow the OME standard, see documentation on
81
+ how to create a custom `image_stack_loader`. (TODO: Add link).
82
+
83
+ Parameters
84
+ ----------
85
+ source: sequence of Path
86
+ The source files for the data.
87
+ axes: str
88
+ The original axes of the data, must be a subset of "STCZYX".
89
+
90
+ Returns
91
+ -------
92
+ PatchExtractor
51
93
  """
52
- Create a patch extractor from a sequence of files that match our supported types.
94
+ # NOTE: axes is unused here, in from_ome_zarr the axes are automatically retrieved
95
+ image_stacks = [ZarrImageStack.from_ome_zarr(path) for path in source]
96
+ return PatchExtractor(image_stacks)
97
+
53
98
 
54
- Supported file types include TIFF and ZARR.
99
+ # CZI case
100
+ def create_czi_extractor(
101
+ source: Sequence[Path],
102
+ axes: str,
103
+ ) -> PatchExtractor[CziImageStack]:
104
+ """
105
+ Create a patch extractor from a sequence of CZI files.
55
106
 
56
- If the files are ZARR files they must follow the OME standard. If you have ZARR
57
- files that do not follow the OME standard, see documentation on how to create
58
- a custom `image_stack_loader`. (TODO: Add link).
107
+ If the CZI files contain multiple scenes, one image stack will be created for
108
+ each scene; all image stacks are collected in a single patch extractor.
59
109
 
60
110
  Parameters
61
111
  ----------
62
112
  source: sequence of Path
63
113
  The source files for the data.
64
- data_config: DataConfig
65
- The data configuration, `data_config.data_type` should have the value "tiff" or
66
- "zarr", and `data_config.axes` should describe the axes of every image in the
67
- `source`.
114
+ axes: str
115
+ Specifies which axes of the data to use and how.
116
+ If this string ends with `"ZYX"` or `"TYX"`, the data will consist of 3-D
117
+ patches, using `Z` or `T` as third dimension, respectively.
118
+ If the string ends with neither "ZYX" nor "TYX", the data will consist of 2-D patches.
68
119
 
69
120
  Returns
70
121
  -------
71
122
  PatchExtractor
72
123
  """
124
+ depth_axis: Literal["none", "Z", "T"] = "none"
125
+ if axes.endswith("TYX"):
126
+ depth_axis = "T"
127
+ elif axes.endswith("ZYX"):
128
+ depth_axis = "Z"
129
+
130
+ image_stacks: list[CziImageStack] = []
131
+ for path in source:
132
+ scene_rectangles = CziImageStack.get_bounding_rectangles(path)
133
+ image_stacks.extend(
134
+ CziImageStack(path, scene=scene, depth_axis=depth_axis)
135
+ for scene in scene_rectangles.keys()
136
+ )
137
+ return PatchExtractor(image_stacks)
73
138
 
74
139
 
75
140
  # Custom file type case (loaded into memory)
76
- @overload
77
- def create_patch_extractor(
78
- source: Any,
141
+ def create_custom_file_extractor(
142
+ source: Sequence[Path],
79
143
  axes: str,
80
- data_type: Literal[SupportedData.CUSTOM],
81
144
  *,
82
145
  read_func: ReadFunc,
83
146
  read_kwargs: dict[str, Any],
84
- ) -> PatchExtractor:
147
+ ) -> PatchExtractor[InMemoryImageStack]:
85
148
  """
86
149
  Create a patch extractor from a sequence of files of a custom type.
87
150
 
@@ -89,8 +152,8 @@ def create_patch_extractor(
89
152
  ----------
90
153
  source: sequence of Path
91
154
  The source files for the data.
92
- data_config: DataConfig
93
- The data configuration, `data_config.data_type` should have the value "custom".
155
+ axes: str
156
+ The original axes of the data, must be a subset of "STCZYX".
94
157
  read_func : ReadFunc
95
158
  A function to read the custom file type, see the `ReadFunc` protocol.
96
159
  read_kwargs : dict of {str: Any}
@@ -100,18 +163,28 @@ def create_patch_extractor(
100
163
  -------
101
164
  PatchExtractor
102
165
  """
166
+ # TODO: lazy loading custom files
167
+ image_stacks = [
168
+ InMemoryImageStack.from_custom_file_type(
169
+ path=path,
170
+ axes=axes,
171
+ read_func=read_func,
172
+ **read_kwargs,
173
+ )
174
+ for path in source
175
+ ]
176
+
177
+ return PatchExtractor(image_stacks)
103
178
 
104
179
 
105
180
  # Custom ImageStackLoader case
106
- @overload
107
- def create_patch_extractor(
181
+ def create_custom_image_stack_extractor(
108
182
  source: Any,
109
183
  axes: str,
110
- data_type: Literal[SupportedData.CUSTOM],
111
- image_stack_loader: ImageStackLoader[P],
184
+ image_stack_loader: ImageStackLoader[P, GenericImageStack],
112
185
  *args: P.args,
113
186
  **kwargs: P.kwargs,
114
- ) -> PatchExtractor:
187
+ ) -> PatchExtractor[GenericImageStack]:
115
188
  """
116
189
  Create a patch extractor using a custom `ImageStackLoader`.
117
190
 
@@ -127,8 +200,8 @@ def create_patch_extractor(
127
200
  ----------
128
201
  source: sequence of Path
129
202
  The source files for the data.
130
- data_config: DataConfig
131
- The data configuration, `data_config.data_type` should have the value "custom".
203
+ axes: str
204
+ The original axes of the data, must be a subset of "STCZYX".
132
205
  image_stack_loader: ImageStackLoader
133
206
  A custom image stack loader callable.
134
207
  *args: Any
@@ -140,69 +213,5 @@ def create_patch_extractor(
140
213
  -------
141
214
  PatchExtractor
142
215
  """
143
-
144
-
145
- # final overload to match the implentation function signature
146
- # Need this so it works later in the code
147
- # (bec there aren't created overloads for create_patch_extractors below)
148
- @overload
149
- def create_patch_extractor(
150
- source: Any,
151
- axes: str,
152
- data_type: Union[SupportedData, SupportedDataDev],
153
- image_stack_loader: Optional[ImageStackLoader[P]] = None,
154
- *args: P.args,
155
- **kwargs: P.kwargs,
156
- ) -> PatchExtractor: ...
157
-
158
-
159
- def create_patch_extractor(
160
- source: Any,
161
- axes: str,
162
- data_type: Union[SupportedData, SupportedDataDev],
163
- image_stack_loader: Optional[ImageStackLoader[P]] = None,
164
- *args: P.args,
165
- **kwargs: P.kwargs,
166
- ) -> PatchExtractor:
167
- # TODO: Do we need to catch data_config.data_type and source mismatches?
168
- # e.g. data_config.data_type is "array" but source is not Sequence[NDArray]
169
- loader: ImageStackLoader[P] = get_image_stack_loader(data_type, image_stack_loader)
170
- image_stacks = loader(source, axes, *args, **kwargs)
216
+ image_stacks = image_stack_loader(source, axes, *args, **kwargs)
171
217
  return PatchExtractor(image_stacks)
172
-
173
-
174
- # TODO: Remove this and just call `create_patch_extractor` within the Dataset class
175
- # Keeping for consistency for now
176
- def create_patch_extractors(
177
- source: Any,
178
- target_source: Optional[Any],
179
- axes: str,
180
- data_type: Union[SupportedData, SupportedDataDev],
181
- image_stack_loader: Optional[ImageStackLoader] = None,
182
- *args,
183
- **kwargs,
184
- ) -> tuple[PatchExtractor, Optional[PatchExtractor]]:
185
-
186
- # --- data extractor
187
- patch_extractor: PatchExtractor = create_patch_extractor(
188
- source,
189
- axes,
190
- data_type,
191
- image_stack_loader,
192
- *args,
193
- **kwargs,
194
- )
195
- # --- optional target extractor
196
- if target_source is not None:
197
- target_patch_extractor = create_patch_extractor(
198
- target_source,
199
- axes,
200
- data_type,
201
- image_stack_loader,
202
- *args,
203
- **kwargs,
204
- )
205
-
206
- return patch_extractor, target_patch_extractor
207
-
208
- return patch_extractor, None
@@ -4,8 +4,13 @@ __all__ = [
4
4
  "PatchingStrategy",
5
5
  "RandomPatchingStrategy",
6
6
  "SequentialPatchingStrategy",
7
+ "TileSpecs",
8
+ "TilingStrategy",
9
+ "WholeSamplePatchingStrategy",
7
10
  ]
8
11
 
9
- from .patching_strategy_protocol import PatchingStrategy, PatchSpecs
12
+ from .patching_strategy_protocol import PatchingStrategy, PatchSpecs, TileSpecs
10
13
  from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
11
14
  from .sequential_patching import SequentialPatchingStrategy
15
+ from .tiling_strategy import TilingStrategy
16
+ from .whole_sample import WholeSamplePatchingStrategy
@@ -28,6 +28,37 @@ class PatchSpecs(TypedDict):
28
28
  patch_size: Sequence[int]
29
29
 
30
30
 
31
class TileSpecs(PatchSpecs):
    """A dictionary that specifies a single tile in a series of `ImageStacks`.

    Extends `PatchSpecs` with the extra information needed to crop a tile and
    stitch the cropped region back into the full image.

    Attributes
    ----------
    data_idx: int
        Determines which `ImageStack` a tile belongs to, within a series of
        `ImageStack`s.
    sample_idx: int
        Determines which sample a tile belongs to, within an `ImageStack`.
    coords: sequence of int
        The coordinates of the top-left corner (and first z-slice for 3D data) of
        the tile, relative to the image. The sequence will have length 2 or 3, for
        2D and 3D data respectively.
    patch_size: sequence of int
        The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    crop_coords: sequence of int
        The top-left side of where the tile will be cropped, in coordinates relative
        to the tile.
    crop_size: sequence of int
        The size of the cropped tile.
    stitch_coords: sequence of int
        Where the cropped tile will be stitched back into an image, in coordinates
        relative to the image.
    """

    # Fields added on top of the inherited `PatchSpecs` keys.
    crop_coords: Sequence[int]
    crop_size: Sequence[int]
    stitch_coords: Sequence[int]
+
61
+
31
62
  class PatchingStrategy(Protocol):
32
63
  """
33
64
  An interface for patching strategies.
@@ -335,4 +335,8 @@ def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) ->
335
335
  f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
336
336
  f"and `spatial_shape={spatial_shape}`."
337
337
  )
338
- return int(np.ceil(np.prod(spatial_shape) / np.prod(patch_size)))
338
+ patches_per_dim = [
339
+ np.ceil(s / p) for s, p in zip(spatial_shape, patch_size, strict=False)
340
+ ]
341
+ total_patches = int(np.prod(patches_per_dim))
342
+ return total_patches
@@ -18,13 +18,13 @@ class SequentialPatchingStrategy:
18
18
  self,
19
19
  data_shapes: Sequence[Sequence[int]],
20
20
  patch_size: Sequence[int],
21
- overlap: Optional[Sequence[int]] = None,
21
+ overlaps: Optional[Sequence[int]] = None,
22
22
  ):
23
23
  self.data_shapes = data_shapes
24
24
  self.patch_size = patch_size
25
- if overlap is None:
26
- overlap = [0] * len(patch_size)
27
- self.overlap = np.asarray(overlap)
25
+ if overlaps is None:
26
+ overlaps = [0] * len(patch_size)
27
+ self.overlaps = np.asarray(overlaps)
28
28
 
29
29
  self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()
30
30
 
@@ -58,7 +58,7 @@ class SequentialPatchingStrategy:
58
58
  data_spatial_shape = data_shape[-len(self.patch_size) :]
59
59
  coords_list = [
60
60
  self._compute_coords_1d(
61
- self.patch_size[i], data_spatial_shape[i], self.overlap[i]
61
+ self.patch_size[i], data_spatial_shape[i], self.overlaps[i]
62
62
  )
63
63
  for i in range(len(self.patch_size))
64
64
  ]
@@ -0,0 +1,172 @@
1
+ """Module for the `TilingStrategy` class."""
2
+
3
+ import itertools
4
+ from collections.abc import Sequence
5
+
6
+ from .patching_strategy_protocol import TileSpecs
7
+
8
+
9
+ class TilingStrategy:
10
+ """
11
+ The tiling strategy should be used for prediction. The `get_patch_specs`
12
+ method returns `TileSpec` dictionaries that contains information on how to
13
+ stitch the tiles back together to create the full image.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ data_shapes: Sequence[Sequence[int]],
19
+ tile_size: Sequence[int],
20
+ overlaps: Sequence[int],
21
+ ):
22
+ """
23
+ The tiling strategy should be used for prediction. The `get_patch_specs`
24
+ method returns `TileSpec` dictionaries that contains information on how to
25
+ stitch the tiles back together to create the full image.
26
+
27
+ Parameters
28
+ ----------
29
+ data_shapes : sequence of (sequence of int)
30
+ The shapes of the underlying data. Each element is the dimension of the
31
+ axes SC(Z)YX.
32
+ tile_size : sequence of int
33
+ The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
34
+ data respectively.
35
+ overlaps : sequence of int
36
+ How much a tile will overlap with adjacent tiles in each spatial dimension.
37
+ """
38
+ self.data_shapes = data_shapes
39
+ self.tile_size = tile_size
40
+ self.overlaps = overlaps
41
+ # tile_size and overlap should have same length validated in pydantic configs
42
+ self.tile_specs: list[TileSpecs] = self._generate_specs()
43
+
44
+ @property
45
+ def n_patches(self) -> int:
46
+ """
47
+ The number of patches that this patching strategy will return.
48
+
49
+ It also determines the maximum index that can be given to `get_patch_spec`.
50
+ """
51
+ return len(self.tile_specs)
52
+
53
+ def get_patch_spec(self, index: int) -> TileSpecs:
54
+ """Return the tile specs for a given index.
55
+
56
+ Parameters
57
+ ----------
58
+ index : int
59
+ A patch index.
60
+
61
+ Returns
62
+ -------
63
+ TileSpecs
64
+ A dictionary that specifies a single patch in a series of `ImageStacks`.
65
+ """
66
+ return self.tile_specs[index]
67
+
68
+ def _generate_specs(self) -> list[TileSpecs]:
69
+ tile_specs: list[TileSpecs] = []
70
+ for data_idx, data_shape in enumerate(self.data_shapes):
71
+ spatial_shape = data_shape[2:]
72
+
73
+ # spec info for each axis
74
+ axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
75
+ self._compute_1d_coords(
76
+ axis_size, self.tile_size[axis_idx], self.overlaps[axis_idx]
77
+ )
78
+ for axis_idx, axis_size in enumerate(spatial_shape)
79
+ ]
80
+
81
+ # combine by using zip
82
+ all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
83
+ *axis_specs, strict=False
84
+ )
85
+ # patches will be the same for each sample in a stack
86
+ for sample_idx in range(data_shape[0]):
87
+ # iterate through all combinations using itertools.product
88
+ for coords, stitch_coords, crop_coords, crop_size in zip(
89
+ itertools.product(*all_coords),
90
+ itertools.product(*all_stitch_coords),
91
+ itertools.product(*all_crop_coords),
92
+ itertools.product(*all_crop_size),
93
+ strict=False,
94
+ ):
95
+ tile_specs.append(
96
+ {
97
+ # PatchSpecs
98
+ "data_idx": data_idx,
99
+ "sample_idx": sample_idx,
100
+ "coords": coords,
101
+ "patch_size": self.tile_size,
102
+ # TileSpecs additional fields
103
+ "crop_coords": crop_coords,
104
+ "crop_size": crop_size,
105
+ "stitch_coords": stitch_coords,
106
+ }
107
+ )
108
+ return tile_specs
109
+
110
+ @staticmethod
111
+ def _compute_1d_coords(
112
+ axis_size: int, tile_size: int, overlap: int
113
+ ) -> tuple[list[int], list[int], list[int], list[int]]:
114
+ """
115
+ Computes the TileSpec information for a single axis.
116
+
117
+ Parameters
118
+ ----------
119
+ axis_size : int
120
+ The size of the axis.
121
+ tile_size : int
122
+ The tile size.
123
+ overlap : int
124
+ The tile overlap.
125
+
126
+ Returns
127
+ -------
128
+ coords: list of int
129
+ The top-left (and first z-slice for 3D data) of a tile, in coords relative
130
+ to the image.
131
+ stitch_coords: list of int
132
+ Where the tile will be stitched back into an image, taking into account
133
+ that the tile will be cropped, in coords relative to the image.
134
+ crop_coords: list of int
135
+ The top-left side of where the tile will be cropped, in coordinates relative
136
+ to the tile.
137
+ crop_size: list of int
138
+ The size of the cropped tile.
139
+ """
140
+ coords: list[int] = []
141
+ stitch_coords: list[int] = []
142
+ crop_coords: list[int] = []
143
+ crop_size: list[int] = []
144
+
145
+ step = tile_size - overlap
146
+ for i in range(0, max(1, axis_size - overlap), step):
147
+ if i == 0:
148
+ coords.append(i)
149
+ crop_coords.append(0)
150
+ stitch_coords.append(0)
151
+ crop_size.append(tile_size - overlap // 2)
152
+ elif (i > 0) and (i + tile_size < axis_size):
153
+ coords.append(i)
154
+ crop_coords.append(overlap // 2)
155
+ stitch_coords.append(coords[-1] + crop_coords[-1])
156
+ crop_size.append(tile_size - overlap)
157
+ else:
158
+ previous_crop_size = crop_size[-1] if crop_size else 1
159
+ previous_stitch_coord = stitch_coords[-1] if stitch_coords else 0
160
+ previous_tile_end = previous_stitch_coord + previous_crop_size
161
+
162
+ coords.append(max(0, axis_size - tile_size))
163
+ stitch_coords.append(previous_tile_end)
164
+ crop_coords.append(stitch_coords[-1] - coords[-1])
165
+ crop_size.append(axis_size - stitch_coords[-1])
166
+
167
+ return (
168
+ coords,
169
+ stitch_coords,
170
+ crop_coords,
171
+ crop_size,
172
+ )
@@ -0,0 +1,36 @@
1
+ from collections.abc import Sequence
2
+
3
+ from .patching_strategy_protocol import PatchSpecs
4
+
5
+
6
class WholeSamplePatchingStrategy:
    """
    Patching strategy that returns every sample in its entirety as one patch.

    Each patch starts at the origin of the sample and its size is the full spatial
    shape of that sample.
    """

    # NOTE: this strategy should only be used with batch size = 1 when the image
    # stacks have different spatial dimensions (TODO: emit a warning).

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        """
        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data, with the sample axis first and the
            spatial axes starting at index 2 (e.g. SC(Z)YX).
        """
        self.data_shapes = data_shapes
        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """The number of patches, i.e. the total number of samples."""
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for the patch at `index`."""
        return self.patch_specs[index]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """Create one whole-sample patch spec per sample of every data shape."""
        specs: list[PatchSpecs] = []
        for stack_index, shape in enumerate(self.data_shapes):
            n_samples, spatial = shape[0], shape[2:]
            origin = (0,) * len(spatial)
            specs.extend(
                {
                    "data_idx": stack_index,
                    "sample_idx": sample,
                    "coords": origin,
                    "patch_size": spatial,
                }
                for sample in range(n_samples)
            )
        return specs
@@ -1,7 +1,8 @@
1
1
  """Module to get read functions."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from pathlib import Path
4
- from typing import Callable, Protocol, Union
5
+ from typing import Protocol, Union
5
6
 
6
7
  from numpy.typing import NDArray
7
8
 
@@ -0,0 +1 @@
1
+ """Next-Generation DataModules for Careamics."""