careamics 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (63)
  1. careamics/__init__.py +0 -4
  2. careamics/careamist.py +0 -1
  3. careamics/config/__init__.py +1 -13
  4. careamics/config/algorithms/care_algorithm_model.py +84 -0
  5. careamics/config/algorithms/n2n_algorithm_model.py +85 -0
  6. careamics/config/algorithms/n2v_algorithm_model.py +269 -1
  7. careamics/config/configuration.py +21 -13
  8. careamics/config/configuration_factories.py +179 -187
  9. careamics/config/configuration_io.py +2 -2
  10. careamics/config/data/__init__.py +1 -4
  11. careamics/config/data/data_model.py +46 -62
  12. careamics/config/support/supported_transforms.py +1 -1
  13. careamics/config/transformations/__init__.py +0 -2
  14. careamics/config/transformations/n2v_manipulate_model.py +15 -0
  15. careamics/config/transformations/transform_unions.py +0 -13
  16. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  17. careamics/dataset/dataset_utils/running_stats.py +7 -3
  18. careamics/dataset/in_memory_dataset.py +3 -10
  19. careamics/dataset/in_memory_pred_dataset.py +3 -5
  20. careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +2 -2
  22. careamics/dataset/iterable_pred_dataset.py +3 -5
  23. careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
  24. careamics/dataset_ng/dataset/__init__.py +3 -0
  25. careamics/dataset_ng/dataset/dataset.py +184 -0
  26. careamics/dataset_ng/demo_dataset.ipynb +271 -0
  27. careamics/dataset_ng/demo_patch_extractor.py +53 -0
  28. careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
  29. careamics/dataset_ng/patch_extractor/__init__.py +10 -0
  30. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
  31. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
  32. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
  33. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
  34. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
  35. careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
  36. careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
  37. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
  38. careamics/dataset_ng/patching_strategies/__init__.py +11 -0
  39. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
  40. careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
  41. careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
  42. careamics/lightning/lightning_module.py +78 -27
  43. careamics/lightning/train_data_module.py +8 -39
  44. careamics/losses/fcn/losses.py +17 -10
  45. careamics/model_io/bioimage/bioimage_utils.py +5 -3
  46. careamics/model_io/bioimage/model_description.py +3 -3
  47. careamics/model_io/bmz_io.py +2 -2
  48. careamics/model_io/model_io_utils.py +2 -2
  49. careamics/transforms/__init__.py +2 -1
  50. careamics/transforms/compose.py +5 -15
  51. careamics/transforms/n2v_manipulate_torch.py +143 -0
  52. careamics/transforms/pixel_manipulation.py +1 -0
  53. careamics/transforms/pixel_manipulation_torch.py +418 -0
  54. careamics/utils/version.py +38 -0
  55. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/METADATA +7 -8
  56. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/RECORD +59 -42
  57. careamics/config/care_configuration.py +0 -100
  58. careamics/config/data/n2v_data_model.py +0 -193
  59. careamics/config/n2n_configuration.py +0 -101
  60. careamics/config/n2v_configuration.py +0 -266
  61. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/WHEEL +0 -0
  62. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/entry_points.txt +0 -0
  63. {careamics-0.0.9.dist-info → careamics-0.0.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,271 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from pathlib import Path\n",
11
+ "\n",
12
+ "import matplotlib.pyplot as plt\n",
13
+ "import numpy as np\n",
14
+ "import skimage\n",
15
+ "import tifffile\n",
16
+ "\n",
17
+ "from careamics.config import create_n2n_configuration\n",
18
+ "from careamics.dataset_ng.dataset.dataset import CareamicsDataset, Mode"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "1",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "example_data = skimage.data.human_mitosis()\n",
29
+ "\n",
30
+ "markers = np.zeros_like(example_data)\n",
31
+ "markers[example_data < 25] = 1\n",
32
+ "markers[example_data > 50] = 2\n",
33
+ "\n",
34
+ "elevation_map = skimage.filters.sobel(example_data)\n",
35
+ "segmentation = skimage.segmentation.watershed(elevation_map, markers)\n",
36
+ "\n",
37
+ "fig, ax = plt.subplots(1, 2)\n",
38
+ "ax[0].imshow(example_data)\n",
39
+ "ax[1].imshow(segmentation)\n",
40
+ "plt.show()"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "2",
46
+ "metadata": {},
47
+ "source": [
48
+ "### 1. From an array "
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "3",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# 1. Train val from an array\n",
59
+ "\n",
60
+ "train_data_config = create_n2n_configuration(\n",
61
+ " \"test_exp\",\n",
62
+ " data_type=\"array\",\n",
63
+ " axes=\"YX\",\n",
64
+ " patch_size=(32, 32),\n",
65
+ " batch_size=1,\n",
66
+ " num_epochs=1,\n",
67
+ ").data_config\n",
68
+ "\n",
69
+ "val_data_config = create_n2n_configuration(\n",
70
+ " \"test_exp\",\n",
71
+ " data_type=\"array\",\n",
72
+ " axes=\"YX\",\n",
73
+ " patch_size=(32, 32),\n",
74
+ " batch_size=1,\n",
75
+ " num_epochs=1,\n",
76
+ " augmentations=[],\n",
77
+ ").data_config\n",
78
+ "\n",
79
+ "\n",
80
+ "train_dataset = CareamicsDataset(\n",
81
+ " data_config=train_data_config,\n",
82
+ " mode=Mode.TRAINING,\n",
83
+ " inputs=[example_data],\n",
84
+ " targets=[segmentation],\n",
85
+ ")\n",
86
+ "val_dataset = CareamicsDataset(\n",
87
+ " data_config=val_data_config,\n",
88
+ " mode=Mode.VALIDATING,\n",
89
+ " inputs=[example_data],\n",
90
+ " targets=[segmentation],\n",
91
+ ")\n",
92
+ "\n",
93
+ "fig, ax = plt.subplots(2, 5, figsize=(10, 5))\n",
94
+ "ax[0, 0].set_title(\"Train input\")\n",
95
+ "ax[1, 0].set_title(\"Train target\")\n",
96
+ "for i in range(5):\n",
97
+ " sample, target = train_dataset[i]\n",
98
+ " ax[0, i].imshow(sample.data[0])\n",
99
+ " ax[1, i].imshow(target.data[0])"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "4",
105
+ "metadata": {},
106
+ "source": [
107
+ "### 2. From tiff "
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "5",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "tifffile.imwrite(\"example_data1.tiff\", example_data)\n",
118
+ "tifffile.imwrite(\"example_target1.tiff\", segmentation)\n",
119
+ "tifffile.imwrite(\"example_data2.tiff\", example_data[:256, :256])\n",
120
+ "tifffile.imwrite(\"example_target2.tiff\", segmentation[:256, :256])\n",
121
+ "\n",
122
+ "train_data_config = create_n2n_configuration(\n",
123
+ " \"test_exp\",\n",
124
+ " data_type=\"tiff\",\n",
125
+ " axes=\"YX\",\n",
126
+ " patch_size=(32, 32),\n",
127
+ " batch_size=1,\n",
128
+ " num_epochs=1,\n",
129
+ ").data_config\n",
130
+ "\n",
131
+ "val_data_config = create_n2n_configuration(\n",
132
+ " \"test_exp\",\n",
133
+ " data_type=\"tiff\",\n",
134
+ " axes=\"YX\",\n",
135
+ " patch_size=(32, 32),\n",
136
+ " batch_size=1,\n",
137
+ " num_epochs=1,\n",
138
+ " augmentations=[],\n",
139
+ ").data_config\n",
140
+ "\n",
141
+ "data = sorted(Path(\"./\").glob(\"example_data*.tiff\"))\n",
142
+ "targets = sorted(Path(\"./\").glob(\"example_target*.tiff\"))\n",
143
+ "train_dataset = CareamicsDataset(\n",
144
+ " data_config=train_data_config, inputs=data, targets=targets\n",
145
+ ")\n",
146
+ "val_dataset = CareamicsDataset(\n",
147
+ " data_config=val_data_config, inputs=data, targets=targets\n",
148
+ ")\n",
149
+ "\n",
150
+ "fig, ax = plt.subplots(2, 5, figsize=(10, 5))\n",
151
+ "ax[0, 0].set_title(\"Train input\")\n",
152
+ "ax[1, 0].set_title(\"Train target\")\n",
153
+ "for i in range(5):\n",
154
+ " sample, target = train_dataset[i]\n",
155
+ " ax[0, i].imshow(sample.data[0])\n",
156
+ " ax[1, i].imshow(target.data[0])"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "6",
162
+ "metadata": {},
163
+ "source": [
164
+ "### 3. Prediction from array"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "7",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "from careamics.config import InferenceConfig\n",
175
+ "\n",
176
+ "prediction_config = InferenceConfig(\n",
177
+ " data_type=\"array\",\n",
178
+ " tile_size=(32, 32),\n",
179
+ " tile_overlap=(16, 16),\n",
180
+ " axes=\"YX\",\n",
181
+ " image_means=(example_data.mean(),),\n",
182
+ " image_stds=(example_data.std(),),\n",
183
+ " tta_transforms=False,\n",
184
+ " batch_size=1,\n",
185
+ ")\n",
186
+ "prediction_dataset = CareamicsDataset(\n",
187
+ " data_config=prediction_config, mode=Mode.PREDICTING, inputs=[example_data]\n",
188
+ ")\n",
189
+ "\n",
190
+ "fig, ax = plt.subplots(1, 5, figsize=(10, 5))\n",
191
+ "ax[0].set_title(\"Prediction input\")\n",
192
+ "for i in range(5):\n",
193
+ " sample, _ = prediction_dataset[i]\n",
194
+ " ax[i].imshow(sample.data[0])"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "8",
200
+ "metadata": {},
201
+ "source": [
202
+ "### 4. From custom data type "
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "9",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "train_data_config = create_n2n_configuration(\n",
213
+ " \"test_exp\",\n",
214
+ " data_type=\"custom\",\n",
215
+ " axes=\"YX\",\n",
216
+ " patch_size=(32, 32),\n",
217
+ " batch_size=1,\n",
218
+ " num_epochs=1,\n",
219
+ ").data_config\n",
220
+ "\n",
221
+ "\n",
222
+ "def read_data_func_test(data):\n",
223
+ " return 255 - example_data\n",
224
+ "\n",
225
+ "\n",
226
+ "fig, ax = plt.subplots(1, 5, figsize=(10, 5))\n",
227
+ "train_dataset = CareamicsDataset(\n",
228
+ " data_config=train_data_config,\n",
229
+ " mode=Mode.TRAINING,\n",
230
+ " inputs=[example_data],\n",
231
+ " targets=[segmentation],\n",
232
+ " read_func=read_data_func_test,\n",
233
+ " read_kwargs={}\n",
234
+ ")\n",
235
+ "\n",
236
+ "for i in range(5):\n",
237
+ " sample, _ = train_dataset[i]\n",
238
+ " ax[i].imshow(sample.data[0])"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "10",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": []
248
+ }
249
+ ],
250
+ "metadata": {
251
+ "kernelspec": {
252
+ "display_name": "Python 3",
253
+ "language": "python",
254
+ "name": "python3"
255
+ },
256
+ "language_info": {
257
+ "codemirror_mode": {
258
+ "name": "ipython",
259
+ "version": 3
260
+ },
261
+ "file_extension": ".py",
262
+ "mimetype": "text/x-python",
263
+ "name": "python",
264
+ "nbconvert_exporter": "python",
265
+ "pygments_lexer": "ipython3",
266
+ "version": "3.11.11"
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 5
271
+ }
@@ -0,0 +1,53 @@
1
# %%
import numpy as np

# %%
from careamics.config.support import SupportedData
from careamics.dataset_ng.patch_extractor import create_patch_extractor
from careamics.dataset_ng.patch_extractor.image_stack import InMemoryImageStack
from careamics.dataset_ng.patching_strategies import RandomPatchingStrategy

# %%
# Wrap a single 6x6 "YX" image as an in-memory stack and pull out a 3x3 patch
# starting at (2, 2) of sample 0.
array = np.arange(36).reshape(6, 6)
image_stack = InMemoryImageStack.from_array(data=array, axes="YX")
image_stack.extract_patch(sample_idx=0, coords=(2, 2), patch_size=(3, 3))

# %%
rng = np.random.default_rng()

# %%
# Example data: two "SYX" inputs with different sample counts and spatial sizes,
# plus random 0/1 targets with matching shapes.
array1 = np.arange(36).reshape(1, 6, 6)
array2 = np.arange(50).reshape(2, 5, 5)
target1 = rng.integers(0, 1, size=array1.shape, endpoint=True)
target2 = rng.integers(0, 1, size=array2.shape, endpoint=True)

# %%
for example in (array1, array2, target1, target2):
    print(example)

# %%
# One patch extractor for the inputs and one for the targets.
input_patch_extractor = create_patch_extractor(
    [array1, array2], axes="SYX", data_type=SupportedData.ARRAY
)
target_patch_extractor = create_patch_extractor(
    [target1, target2], axes="SYX", data_type=SupportedData.ARRAY
)

# %%
# Generate 2x2 patch specifications from the shapes of every input image stack.
data_shapes = [stack.data_shape for stack in input_patch_extractor.image_stacks]
patch_specs_generator = RandomPatchingStrategy(data_shapes, patch_size=(2, 2))
patch_specs = patch_specs_generator.get_patch_spec(18)

# %%
# The same specification is used for inputs and targets, so both extractors
# read patches from the same locations.
input_patch_extractor.extract_patch(**patch_specs)

# %%
target_patch_extractor.extract_patch(**patch_specs)
@@ -0,0 +1,37 @@
1
# %%
import numpy as np

from careamics.config import create_n2n_configuration
from careamics.config.support import SupportedData
from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
    create_patch_extractors,
)

rng = np.random.default_rng()

# %%
# Example data: two "SYX" inputs of different sizes and random 0/1 targets with
# the same shapes.
array1 = np.arange(36).reshape(1, 6, 6)
array2 = np.arange(50).reshape(2, 5, 5)
target1 = rng.integers(0, 1, size=array1.shape, endpoint=True)
target2 = rng.integers(0, 1, size=array2.shape, endpoint=True)
inputs = [array1, array2]
targets = [target1, target2]

# %%
# Only the data part of this N2N configuration is used below.
config = create_n2n_configuration(
    "test_exp",
    data_type="array",
    axes="SYX",
    patch_size=[8, 8],
    batch_size=1,
    num_epochs=1,
)
data_config = config.data_config

# %%
# Convert the configured data type to the SupportedData enum passed to the factory.
data_type = SupportedData(data_config.data_type)
train_inputs, train_targets = create_patch_extractors(
    inputs, targets, axes=data_config.axes, data_type=data_type
)

# %%
# Extract a 3x3 patch starting at (2, 2) from sample 0 of the first input.
train_inputs.extract_patch(data_idx=0, sample_idx=0, coords=(2, 2), patch_size=(3, 3))
@@ -0,0 +1,10 @@
1
"""Public API of the `dataset_ng.patch_extractor` package."""

from .image_stack_loader import ImageStackLoader, get_image_stack_loader
from .patch_extractor import PatchExtractor
from .patch_extractor_factory import create_patch_extractor

__all__ = [
    "ImageStackLoader",
    "PatchExtractor",
    "create_patch_extractor",
    "get_image_stack_loader",
]
@@ -0,0 +1,111 @@
1
+ # %%
2
+ from collections.abc import Sequence
3
+ from pathlib import Path
4
+ from typing import TypedDict
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import zarr
9
+ from numpy.typing import NDArray
10
+ from zarr.storage import FSStore
11
+
12
+ from careamics.config import DataConfig
13
+ from careamics.config.support import SupportedData
14
+ from careamics.dataset_ng.patch_extractor import create_patch_extractor
15
+ from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
16
+ from careamics.dataset_ng.patch_extractor.image_stack_loader import ImageStackLoader
17
+
18
+
19
+ # %%
20
def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
    """
    Write `data` to a new zarr array at `data_path` inside the store at `file_path`.

    The array is stored as a single chunk and keeps the dtype of `data`.

    Parameters
    ----------
    file_path : Path
        Location of the zarr store on disk.
    data_path : str
        Path of the array within the store, e.g. "group_1/image_1.1".
    data : numpy.ndarray
        Array to write.
    """
    store = FSStore(url=str(file_path.resolve()))
    # Close the store even if creating or writing the array fails.
    try:
        array = zarr.create(
            store=store,
            shape=data.shape,
            chunks=data.shape,  # only 1 chunk
            # Keep the source dtype; a hard-coded dtype silently converts the data.
            dtype=data.dtype,
            path=data_path,
        )
        # write data
        array[...] = data
    finally:
        store.close()
33
+
34
+
35
def create_zarr(file_path: Path, data_paths: Sequence[str], data: Sequence[NDArray]):
    """
    Write several arrays into the same zarr store.

    Parameters
    ----------
    file_path : Path
        Location of the zarr store on disk.
    data_paths : Sequence of str
        Path of each array within the store, paired with `data` by position.
    data : Sequence of numpy.ndarray
        Arrays to write.

    Raises
    ------
    ValueError
        If `data_paths` and `data` do not have the same length.
    """
    # zip() would silently drop the unmatched tail, leaving arrays unwritten.
    if len(data_paths) != len(data):
        raise ValueError(
            f"Got {len(data_paths)} data paths but {len(data)} arrays; "
            "they must have the same length."
        )
    for data_path, array in zip(data_paths, data):
        create_zarr_array(file_path=file_path, data_path=data_path, data=array)
38
+
39
+
40
# %% [markdown]
# ### Create example ZARR file

# %%
# Write the example store relative to the current working directory instead of a
# user-specific absolute path, so the demo runs on any machine.
dir_path = Path(".")
file_name = "test_ngff_image.zarr"
file_path = dir_path / file_name

data_paths = [
    "image_1",
    "group_1/image_1.1",
    "group_1/image_1.2",
]
data_shapes = [(1, 3, 64, 64), (1, 3, 32, 48), (1, 3, 32, 32)]
data = [np.random.randint(1, 255, size=shape, dtype=np.uint8) for shape in data_shapes]
# Only create the store once; re-running the cell reuses the existing one.
if not file_path.exists():
    create_zarr(file_path, data_paths, data)

# %% [markdown]
# ### Make sure file exists

# %%
store = FSStore(url=str(file_path.resolve()), mode="r")

# %%
list(store.keys())
66
+
67
+ # %% [markdown]
68
+ # ### Define custom loading function
69
+
70
+
71
+ # %%
72
# Shape of the `source` argument taken by `custom_image_stack_loader`: an open
# zarr store plus the paths of the arrays to load from it.
ZarrSource = TypedDict(
    "ZarrSource",
    {
        "store": FSStore,
        "data_paths": Sequence[str],
    },
)
75
+
76
+
77
def custom_image_stack_loader(source: ZarrSource, axes: str, *args, **kwargs):
    """
    Build one `ZarrImageStack` per array path listed in `source`.

    Parameters
    ----------
    source : ZarrSource
        Dictionary holding the zarr `store` and the `data_paths` to load.
    axes : str
        Axes of the stored arrays, forwarded to every image stack.
    *args, **kwargs
        Accepted but unused here.

    Returns
    -------
    list of ZarrImageStack
        One image stack per entry of `source["data_paths"]`, in the same order.
    """
    return [
        ZarrImageStack(store=source["store"], data_path=path_in_store, axes=axes)
        for path_in_store in source["data_paths"]
    ]
83
+
84
+
85
# %% [markdown]
# ### Test custom loading func

# %%
# dummy data config
data_config = DataConfig(data_type="custom", patch_size=[64, 64], axes="SCYX")

# %%
image_stack_loader: ImageStackLoader = custom_image_stack_loader

# %%
# `data_config.data_type` is set from the plain string "custom", so an identity
# check against the enum member can fail. Convert to `SupportedData` first (as
# the patch extractor factory demo does). The assert also lets pylance match the
# custom-data overload of `create_patch_extractor`.
data_type = SupportedData(data_config.data_type)
assert data_type is SupportedData.CUSTOM

patch_extractor = create_patch_extractor(
    source={"store": store, "data_paths": data_paths},
    axes=data_config.axes,
    data_type=data_type,
    image_stack_loader=image_stack_loader,
)

# %%
# extract patch and display
# arguments: data_idx=2 (presumably data_paths[2]), sample_idx=0,
# coords=(8, 16), patch_size=(16, 16)
patch = patch_extractor.extract_patch(2, 0, (8, 16), (16, 16))
# the patch is expected to be channel-first (CYX); move channels last for imshow
plt.imshow(np.moveaxis(patch, 0, -1))

# %%
@@ -0,0 +1,9 @@
1
"""Public API of the `dataset_ng.patch_extractor.image_stack` package."""

from .image_stack_protocol import ImageStack
from .in_memory_image_stack import InMemoryImageStack
from .zarr_image_stack import ZarrImageStack

__all__ = [
    "ImageStack",
    "InMemoryImageStack",
    "ZarrImageStack",
]
@@ -0,0 +1,53 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Literal, Protocol, Union
4
+
5
+ from numpy.typing import DTypeLike, NDArray
6
+
7
+
8
class ImageStack(Protocol):
    """
    An interface for extracting patches from an image stack.

    Attributes
    ----------
    source : Path or "array"
        Origin of the image data.
    data_shape : Sequence[int]
        The shape of the data, it is expected to be in the order (SC(Z)YX).
    data_dtype : DTypeLike
        The data type of the image data.
    """

    # TODO: not sure how compatible using Path will be for a zarr array
    # (for a zarr array need to specify file path and internal zarr path)
    @property
    def source(self) -> Union[Path, Literal["array"]]: ...

    @property
    def data_shape(self) -> Sequence[int]: ...

    @property
    def data_dtype(self) -> DTypeLike: ...

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """
        Extracts a patch for a given sample within the image stack.

        Parameters
        ----------
        sample_idx : int
            Sample index. The first dimension of the image data will be indexed at
            this value.
        coords : Sequence of int
            The coordinates that define the start of a patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            A patch of the image data from a particular sample. It will have the
            dimensions C(Z)YX.
        """
@@ -0,0 +1,55 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Any, Literal, Union
4
+
5
+ from numpy.typing import DTypeLike, NDArray
6
+ from typing_extensions import Self
7
+
8
+ from careamics.dataset.dataset_utils import reshape_array
9
+ from careamics.file_io.read import ReadFunc, read_tiff
10
+
11
+
12
class InMemoryImageStack:
    """
    A class for extracting patches from an image stack that has been loaded into memory.

    The stored data is expected to be in SC(Z)YX order; the `from_array`,
    `from_tiff` and `from_custom_file_type` constructors reshape raw data with
    `reshape_array` before storing it.

    Attributes
    ----------
    source : Path or "array"
        Origin of the image data.
    data_shape : Sequence[int]
        The shape of the stored data.
    data_dtype : DTypeLike
        The data type of the stored data.
    """

    def __init__(self, source: Union[Path, Literal["array"]], data: NDArray):
        self.source: Union[Path, Literal["array"]] = source
        # data expected to be in SC(Z)YX shape, reason to use from_array constructor
        self._data: NDArray = data
        self.data_shape: Sequence[int] = self._data.shape
        self.data_dtype: DTypeLike = self._data.dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """
        Extract a patch from one sample of the image stack.

        Parameters
        ----------
        sample_idx : int
            Index along the first (S) axis of the data.
        coords : Sequence of int
            Start of the patch in each spatial dimension.
        patch_size : Sequence of int
            Size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The patch with all channels kept, i.e. in C(Z)YX order. Slices are not
            bounds-checked: a patch extending past the end of an axis is truncated.

        Raises
        ------
        ValueError
            If `coords` and `patch_size` differ in length, or if their length does
            not match the number of spatial dimensions of the data.
        """
        if len(coords) != len(patch_size):
            raise ValueError("Length of coords and patch_size must match.")
        # Everything after the S and C axes is spatial. Without this check, too few
        # coords would let the Ellipsis swallow a spatial axis and silently return
        # a wrongly shaped patch.
        n_spatial_dims = len(self.data_shape) - 2
        if len(coords) != n_spatial_dims:
            raise ValueError(
                f"Expected {n_spatial_dims} spatial coordinates for data of shape "
                f"{tuple(self.data_shape)} (SC(Z)YX), got {len(coords)}."
            )
        spatial_slices = tuple(
            slice(start, start + size) for start, size in zip(coords, patch_size)
        )
        return self._data[(sample_idx, ..., *spatial_slices)]

    @classmethod
    def from_array(cls, data: NDArray, axes: str) -> Self:
        """Create an image stack from an in-memory array with the given `axes`."""
        data = reshape_array(data, axes)
        return cls(source="array", data=data)

    @classmethod
    def from_tiff(cls, path: Path, axes: str) -> Self:
        """Create an image stack by reading the TIFF file at `path`."""
        data = read_tiff(path)
        data = reshape_array(data, axes)
        return cls(source=path, data=data)

    @classmethod
    def from_custom_file_type(
        cls, path: Path, axes: str, read_func: ReadFunc, **read_kwargs: Any
    ) -> Self:
        """
        Create an image stack by reading `path` with a user-supplied function.

        `read_func` is called as `read_func(path, **read_kwargs)` and its result is
        reshaped according to `axes`.
        """
        data = read_func(path, **read_kwargs)
        data = reshape_array(data, axes)
        return cls(source=path, data=data)