careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +11 -14
- careamics/cli/conf.py +18 -3
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +51 -16
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +7 -3
- careamics/config/configuration.py +15 -63
- careamics/config/configuration_factories.py +853 -29
- careamics/config/data/data_model.py +50 -11
- careamics/config/data/ng_data_model.py +168 -4
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/inference_model.py +1 -2
- careamics/config/likelihood_model.py +2 -2
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +26 -1
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +6 -36
- careamics/config/transformations/normalize_model.py +1 -2
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +162 -61
- careamics/lightning/microsplit_data_module.py +636 -0
- careamics/lightning/predict_data_module.py +8 -1
- careamics/lightning/train_data_module.py +19 -8
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +85 -0
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/model_io/bmz_io.py +9 -5
- careamics/models/lvae/likelihoods.py +31 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/prediction_outputs.py +49 -3
- careamics/prediction_utils/stitch_prediction.py +83 -1
- careamics/transforms/xy_random_rotate90.py +1 -1
- careamics/utils/version.py +4 -4
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/dataset_ng/dataset.py
CHANGED
@@ -16,6 +16,7 @@ from careamics.config.transformations import NormalizeModel
 from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
 from careamics.dataset.patching.patching import Stats
 from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
+from careamics.dataset_ng.patch_filter import create_coord_filter, create_patch_filter
 from careamics.dataset_ng.patching_strategies import (
     FixedRandomPatchingStrategy,
     PatchingStrategy,
@@ -52,13 +53,26 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
         mode: Mode,
         input_extractor: PatchExtractor[GenericImageStack],
         target_extractor: PatchExtractor[GenericImageStack] | None = None,
-
+        mask_extractor: PatchExtractor[GenericImageStack] | None = None,
+    ) -> None:
         self.config = data_config
         self.mode = mode

         self.input_extractor = input_extractor
         self.target_extractor = target_extractor

+        self.patch_filter = (
+            create_patch_filter(self.config.patch_filter)
+            if self.config.patch_filter is not None
+            else None
+        )
+        self.coord_filter = (
+            create_coord_filter(self.config.coord_filter, mask=mask_extractor)
+            if self.config.coord_filter is not None and mask_extractor is not None
+            else None
+        )
+        self.patch_filter_patience = self.config.patch_filter_patience
+
         self.patching_strategy = self._initialize_patching_strategy()

         self.input_stats, self.target_stats = self._initialize_statistics()
@@ -183,10 +197,10 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
             region_spec=patch_spec,
         )

-    def
-        self,
-    ) ->
-
+    def _extract_patches(
+        self, patch_spec: PatchSpecs
+    ) -> tuple[NDArray, NDArray | None]:
+        """Extract input and target patches based on patch specifications."""
         input_patch = self.input_extractor.extract_patch(
             data_idx=patch_spec["data_idx"],
             sample_idx=patch_spec["sample_idx"],
@@ -204,7 +218,45 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
             if self.target_extractor is not None
             else None
         )
+        return input_patch, target_patch
+
+    def _get_filtered_patch(
+        self, index: int
+    ) -> tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]:
+        """Extract a patch that passes filtering criteria with retry logic."""
+        should_filter = self.mode == Mode.TRAINING and (
+            self.patch_filter is not None or self.coord_filter is not None
+        )
+        empty_patch = True
+        patch_filter_patience = self.patch_filter_patience  # reset patience
+
+        while empty_patch and patch_filter_patience > 0:
+            # query patches
+            patch_spec = self.patching_strategy.get_patch_spec(index)
+
+            # filter patch based on coordinates if needed
+            if should_filter and self.coord_filter is not None:
+                if self.coord_filter.filter_out(patch_spec):
+                    patch_filter_patience -= 1
+                    continue
+
+            input_patch, target_patch = self._extract_patches(patch_spec)
+
+            # filter patch based on values if needed
+            if should_filter and self.patch_filter is not None:
+                empty_patch = self.patch_filter.filter_out(input_patch)
+                patch_filter_patience -= 1  # decrease patience
+            else:
+                empty_patch = False
+
+        return input_patch, target_patch, patch_spec
+
+    def __getitem__(
+        self, index: int
+    ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
+        input_patch, target_patch, patch_spec = self._get_filtered_patch(index)

+        # apply transforms
         if self.transforms is not None:
             if self.target_extractor is not None:
                 input_patch, target_patch = self.transforms(input_patch, target_patch)
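The `_get_filtered_patch` addition above boils down to a bounded retry loop: draw a patch specification, drop it if the coordinate or value filter rejects it, and stop once the patience budget is spent. A minimal standalone sketch of that pattern follows; `sample_patch` and `too_flat` are toy stand-ins, not careamics APIs.

import numpy as np

rng = np.random.default_rng(0)
image = rng.random((256, 256))


def sample_patch(img, size=64):
    # draw a random square patch from the image
    y = rng.integers(0, img.shape[0] - size)
    x = rng.integers(0, img.shape[1] - size)
    return img[y : y + size, x : x + size]


def too_flat(patch, std_threshold=0.05):
    # value-based filter: reject patches with little intensity variation
    return patch.std() < std_threshold


patience = 5  # how many rejected draws to tolerate before giving up
patch = sample_patch(image)
while too_flat(patch) and patience > 0:
    patch = sample_patch(image)
    patience -= 1
# as in the dataset code above, the last draw is returned even if patience runs out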
careamics/dataset_ng/factory.py
CHANGED
@@ -121,6 +121,7 @@ def create_dataset(
     inputs: Any,
     targets: Any,
     in_memory: bool,
+    masks: Any = None,
     read_func: ReadFunc | None = None,
     read_kwargs: dict[str, Any] | None = None,
     image_stack_loader: ImageStackLoader | None = None,
@@ -142,6 +143,8 @@ def create_dataset(
     in_memory : bool
         Whether all the data should be loaded into memory. This is argument is ignored
         unless the `data_type` in `config` is "tiff" or "custom".
+    masks : Any, optional
+        The mask sources used to filter patches.
     read_func : ReadFunc, optional
         A function that can that can be used to load custom data. This argument is
         ignored unless the `data_type` in the `config` is "custom".
@@ -168,18 +171,24 @@ def create_dataset(
         data_type, in_memory, read_func, image_stack_loader
     )
     if dataset_type == DatasetType.ARRAY:
-        return create_array_dataset(config, mode, inputs, targets)
+        return create_array_dataset(config, mode, inputs, targets, masks)
     elif dataset_type == DatasetType.IN_MEM_TIFF:
-        return create_tiff_dataset(config, mode, inputs, targets)
+        return create_tiff_dataset(config, mode, inputs, targets, masks)
     # TODO: Lazy tiff
     elif dataset_type == DatasetType.CZI:
-        return create_czi_dataset(config, mode, inputs, targets)
+        return create_czi_dataset(config, mode, inputs, targets, masks)
     elif dataset_type == DatasetType.IN_MEM_CUSTOM_FILE:
         if read_kwargs is None:
             read_kwargs = {}
         assert read_func is not None  # should be true from `determine_dataset_type`
         return create_custom_file_dataset(
-            config,
+            config,
+            mode,
+            inputs,
+            targets,
+            masks,
+            read_func=read_func,
+            read_kwargs=read_kwargs,
         )
     elif dataset_type == DatasetType.CUSTOM_IMAGE_STACK:
         if image_stack_loader_kwargs is None:
@@ -191,6 +200,7 @@ def create_dataset(
             inputs,
             targets,
             image_stack_loader,
+            masks,
             **image_stack_loader_kwargs,
         )
     else:
@@ -202,6 +212,7 @@ def create_array_dataset(
     mode: Mode,
     inputs: Sequence[NDArray[Any]],
     targets: Sequence[NDArray[Any]] | None,
+    masks: Sequence[NDArray[Any]] | None = None,
 ) -> CareamicsDataset[InMemoryImageStack]:
     """
     Create a CAREamicsDataset from array data.
@@ -216,6 +227,8 @@ def create_array_dataset(
         The input sources to the dataset.
     targets : Any, optional
         The target sources to the dataset.
+    masks : Any, optional
+        The mask sources used to filter patches.

     Returns
     -------
@@ -228,7 +241,14 @@ def create_array_dataset(
         target_extractor = create_array_extractor(source=targets, axes=config.axes)
     else:
         target_extractor = None
-
+    mask_extractor: PatchExtractor[InMemoryImageStack] | None
+    if masks is not None:
+        mask_extractor = create_array_extractor(source=masks, axes=config.axes)
+    else:
+        mask_extractor = None
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )


 def create_tiff_dataset(
@@ -236,9 +256,10 @@ def create_tiff_dataset(
     mode: Mode,
     inputs: Sequence[Path],
     targets: Sequence[Path] | None,
+    masks: Sequence[Path] | None = None,
 ) -> CareamicsDataset[InMemoryImageStack]:
     """
-    Create a CAREamicsDataset from tiff files that will be
+    Create a CAREamicsDataset from tiff files that will be loaded into memory.

     Parameters
     ----------
@@ -246,10 +267,12 @@ def create_tiff_dataset(
         The data configuration.
     mode : Mode
         Whether to create the dataset in "training", "validation" or "predicting" mode.
-    inputs :
+    inputs : Sequence[Path]
         The input sources to the dataset.
-    targets :
+    targets : Sequence[Path], optional
         The target sources to the dataset.
+    masks : Sequence[Path], optional
+        The mask sources used to filter patches.

     Returns
     -------
@@ -265,8 +288,15 @@ def create_tiff_dataset(
         target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
     else:
         target_extractor = None
-
-
+    mask_extractor: PatchExtractor[InMemoryImageStack] | None
+    if masks is not None:
+        mask_extractor = create_tiff_extractor(source=masks, axes=config.axes)
+    else:
+        mask_extractor = None
+
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )


 def create_czi_dataset(
@@ -274,6 +304,7 @@ def create_czi_dataset(
     mode: Mode,
     inputs: Sequence[Path],
     targets: Sequence[Path] | None,
+    masks: Sequence[Path] | None = None,
 ) -> CareamicsDataset[CziImageStack]:
     """
     Create a dataset from CZI files.
@@ -288,6 +319,8 @@ def create_czi_dataset(
         The input sources to the dataset.
     targets : Any, optional
         The target sources to the dataset.
+    masks : Any, optional
+        The mask sources used to filter patches.

     Returns
     -------
@@ -301,8 +334,15 @@ def create_czi_dataset(
         target_extractor = create_czi_extractor(source=targets, axes=config.axes)
     else:
         target_extractor = None
-
-
+    mask_extractor: PatchExtractor[CziImageStack] | None
+    if masks is not None:
+        mask_extractor = create_czi_extractor(source=masks, axes=config.axes)
+    else:
+        mask_extractor = None
+
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )


 def create_ome_zarr_dataset(
@@ -310,6 +350,7 @@ def create_ome_zarr_dataset(
     mode: Mode,
     inputs: Sequence[Path],
     targets: Sequence[Path] | None,
+    masks: Sequence[Path] | None = None,
 ) -> CareamicsDataset[ZarrImageStack]:
     """
     Create a dataset from OME ZARR files.
@@ -324,6 +365,8 @@ def create_ome_zarr_dataset(
         The input sources to the dataset.
     targets : Any, optional
         The target sources to the dataset.
+    masks : Any, optional
+        The mask sources used to filter patches.

     Returns
     -------
@@ -337,8 +380,15 @@ def create_ome_zarr_dataset(
         target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
     else:
         target_extractor = None
-
-
+    mask_extractor: PatchExtractor[ZarrImageStack] | None
+    if masks is not None:
+        mask_extractor = create_ome_zarr_extractor(source=masks, axes=config.axes)
+    else:
+        mask_extractor = None
+
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )


 def create_custom_file_dataset(
@@ -346,6 +396,7 @@ def create_custom_file_dataset(
     mode: Mode,
     inputs: Sequence[Path],
     targets: Sequence[Path] | None,
+    masks: Sequence[Path] | None = None,
     *,
     read_func: ReadFunc,
     read_kwargs: dict[str, Any],
@@ -363,6 +414,8 @@ def create_custom_file_dataset(
         The input sources to the dataset.
     targets : Any, optional
         The target sources to the dataset.
+    masks : Any, optional
+        The mask sources used to filter patches.
     read_func : Optional[ReadFunc], optional
         A function that can that can be used to load custom data. This argument is
         ignored unless the `data_type` is "custom".
@@ -388,8 +441,21 @@ def create_custom_file_dataset(
         )
     else:
         target_extractor = None
-
-
+
+    mask_extractor: PatchExtractor[InMemoryImageStack] | None
+    if masks is not None:
+        mask_extractor = create_custom_file_extractor(
+            source=masks,
+            axes=config.axes,
+            read_func=read_func,
+            read_kwargs=read_kwargs,
+        )
+    else:
+        mask_extractor = None
+
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )


 def create_custom_image_stack_dataset(
@@ -398,6 +464,7 @@ def create_custom_image_stack_dataset(
     inputs: Any,
     targets: Any | None,
     image_stack_loader: ImageStackLoader[P, GenericImageStack],
+    masks: Any | None = None,
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> CareamicsDataset[GenericImageStack]:
@@ -419,6 +486,8 @@ def create_custom_image_stack_dataset(
     image_stack_loader : ImageStackLoader
         A function for custom image stack loading. This argument is ignored unless the
         `data_type` is "custom".
+    masks : Any, optional
+        The mask sources used to filter patches.
     *args : Any
         Positional arguments to pass to the `image_stack_loader`.
     **kwargs : Any
@@ -447,5 +516,19 @@ def create_custom_image_stack_dataset(
         )
     else:
         target_extractor = None
-
-
+
+    mask_extractor: PatchExtractor[GenericImageStack] | None
+    if masks is not None:
+        mask_extractor = create_custom_image_stack_extractor(
+            masks,
+            config.axes,
+            image_stack_loader,
+            *args,
+            **kwargs,
+        )
+    else:
+        mask_extractor = None
+
+    return CareamicsDataset(
+        config, mode, input_extractor, target_extractor, mask_extractor
+    )
careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py
CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import zarr
 from numpy.typing import NDArray
-from zarr.storage import
+from zarr.storage import FsspecStore

 from careamics.config import DataConfig
 from careamics.config.support import SupportedData
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (

 # %%
 def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
-    store =
+    store = FsspecStore.from_url(url=file_path.resolve())
     # create array
     array = zarr.create(
         store=store,
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
 # ### Make sure file exists

 # %%
-store =
+store = FsspecStore.from_url(url=file_path.resolve(), mode="r")

 # %%
 list(store.keys())
@@ -72,7 +72,7 @@ list(store.keys())

 # %%
 class ZarrSource(TypedDict):
-    store:
+    store: FsspecStore
     data_paths: Sequence[str]

careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py
CHANGED
@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Union
+from typing import Any, Literal, Self, Union

 from numpy.typing import DTypeLike, NDArray
-from typing_extensions import Self

 from careamics.dataset.dataset_utils import reshape_array
 from careamics.file_io.read import ReadFunc, read_tiff
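For reference, `Self` has been part of the standard `typing` module since Python 3.11, which is what the import swap above relies on. A small generic illustration of the annotation (toy class, not careamics code):

from pathlib import Path
from typing import Self


class Loader:
    def __init__(self, path: Path) -> None:
        self.path = path

    @classmethod
    def from_path(cls, path: Path) -> Self:
        # a subclass calling from_path is typed as returning that subclass
        return cls(path)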
careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py
CHANGED
@@ -1,11 +1,11 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Union
+from typing import Self, Union

+import validators
 import zarr
-import zarr.storage
 from numpy.typing import NDArray
-from
+from zarr.storage import FsspecStore, LocalStore

 from careamics.dataset.dataset_utils import reshape_array

@@ -15,9 +15,10 @@ class ZarrImageStack:
     A class for extracting patches from an image stack that is stored as a zarr array.
     """

-    # TODO:
-    #
-
+    # TODO: We should keep store type narrow
+    #   - in zarr v3, does zarr.storage.Store exists and has the path attribute?
+    #   - can we declare a narrow type rather than a union?
+    def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
         self._store = store
         self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
         # TODO: validate axes
@@ -46,8 +47,33 @@ class ZarrImageStack:
         Assumes the path only contains 1 image.

         Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
+
+        Parameters
+        ----------
+        path : Union[Path, str]
+            Path to the root of the OME-Zarr, local file or url.
+
+        Returns
+        -------
+        ZarrImageStack
+            Initialised ZarrImageStack.
+
+        Raises
+        ------
+        ValueError
+            If the path does not exist or is not a valid URL.
+        ValueError
+            If the OME-Zarr at the path does not contain the attribute 'multiscales'.
         """
-
+        if Path(path).is_file():
+            store = zarr.storage.LocalStore(root=Path(path).resolve())
+        elif validators.url(path):
+            store = zarr.storage.FsspecStore.from_url(url=path)
+        else:
+            raise ValueError(
+                f"Path '{path}' is neither an existing file nor a valid URL."
+            )
+
         group = zarr.open_group(store=store, mode="r")
         if "multiscales" not in group.attrs:
             raise ValueError(
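The new `from_ome_zarr` branch dispatches on whether the given path is an existing local file or a URL. A rough standalone sketch of that dispatch, reusing the `validators` and `zarr.storage` calls from the hunk above (the helper name `open_store` is illustrative, not part of careamics):

from pathlib import Path

import validators
from zarr.storage import FsspecStore, LocalStore


def open_store(path: str) -> LocalStore | FsspecStore:
    # local path on disk -> LocalStore, remote URL -> fsspec-backed store
    if Path(path).is_file():
        return LocalStore(root=Path(path).resolve())
    if validators.url(path):
        return FsspecStore.from_url(url=path)
    raise ValueError(f"Path '{path}' is neither an existing file nor a valid URL.")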
careamics/dataset_ng/patch_extractor/image_stack_loader.py
CHANGED
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):

    >>> from typing import TypedDict

-    >>> from zarr.storage import
+    >>> from zarr.storage import FsspecStore

    >>> from careamics.config import DataConfig
    >>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
    >>> # Define a zarr source
    >>> # It encompasses multiple arguments that determine what data will be loaded
    >>> class ZarrSource(TypedDict):
-    ...     store:
+    ...     store: FsspecStore
    ...     data_paths: Sequence[str]

    >>> def custom_image_stack_loader(
careamics/dataset_ng/patch_filter/__init__.py
ADDED
@@ -0,0 +1,20 @@
+"""Patch filtering strategies."""
+
+__all__ = [
+    "CoordinateFilterProtocol",
+    "MaskCoordFilter",
+    "MaxPatchFilter",
+    "MeanStdPatchFilter",
+    "PatchFilterProtocol",
+    "ShannonPatchFilter",
+    "create_coord_filter",
+    "create_patch_filter",
+]
+
+from .coordinate_filter_protocol import CoordinateFilterProtocol
+from .filter_factory import create_coord_filter, create_patch_filter
+from .mask_filter import MaskCoordFilter
+from .max_filter import MaxPatchFilter
+from .mean_std_filter import MeanStdPatchFilter
+from .patch_filter_protocol import PatchFilterProtocol
+from .shannon_filter import ShannonPatchFilter
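Combined with the factory further down, the new public API can presumably be used along these lines. The keyword arguments are inferred from filter_factory.py (threshold, p, seed) and from the filter_out(patch) call in dataset.py, and the threshold values are arbitrary, so treat this as a sketch rather than documented usage:

import numpy as np

from careamics.dataset_ng.patch_filter import MaxPatchFilter, ShannonPatchFilter

patch = np.random.default_rng(0).random((64, 64))

max_filter = MaxPatchFilter(threshold=0.5, p=1.0, seed=0)
entropy_filter = ShannonPatchFilter(threshold=4.0, p=1.0, seed=0)

# filter_out returns True when the patch should be discarded
discard = max_filter.filter_out(patch) or entropy_filter.filter_out(patch)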
careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py
ADDED
@@ -0,0 +1,27 @@
+"""A protocol for patch filtering."""
+
+from typing import Protocol
+
+from careamics.dataset_ng.patching_strategies import PatchSpecs
+
+
+class CoordinateFilterProtocol(Protocol):
+    """
+    An interface for implementing coordinate filtering strategies.
+    """
+
+    def filter_out(self, patch: PatchSpecs) -> bool:
+        """
+        Determine whether to filter out a given patch based on its coordinates.
+
+        Parameters
+        ----------
+        patch : PatchSpecs
+            The patch coordinates to evaluate.
+
+        Returns
+        -------
+        bool
+            True if the patch should be filtered out (excluded), False otherwise.
+        """
+        ...
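Because CoordinateFilterProtocol is a typing.Protocol, any object exposing a compatible filter_out(patch) -> bool method satisfies it structurally, with no inheritance required. A toy, self-contained illustration using a local stand-in spec (not careamics' PatchSpecs, and unrelated to the shipped MaskCoordFilter):

from typing import Protocol, TypedDict


class ToyPatchSpec(TypedDict):
    sample_idx: int


class ToyCoordinateFilter(Protocol):
    def filter_out(self, patch: ToyPatchSpec) -> bool: ...


class EvenSampleFilter:
    # rejects patches drawn from even-numbered samples, purely for illustration
    def filter_out(self, patch: ToyPatchSpec) -> bool:
        return patch["sample_idx"] % 2 == 0


def keep(filter_: ToyCoordinateFilter, patch: ToyPatchSpec) -> bool:
    # EvenSampleFilter never subclasses the protocol, yet type-checks against it
    return not filter_.filter_out(patch)


print(keep(EvenSampleFilter(), {"sample_idx": 3}))  # True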
careamics/dataset_ng/patch_filter/filter_factory.py
ADDED
@@ -0,0 +1,94 @@
+"""Factories for coordinate and patch filters."""
+
+from typing import Union
+
+from careamics.config.data.patch_filter import (
+    FilterModel,
+    MaskFilterModel,
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+)
+from careamics.config.support.supported_filters import (
+    SupportedCoordinateFilters,
+    SupportedPatchFilters,
+)
+from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
+
+from .mask_filter import MaskCoordFilter
+from .max_filter import MaxPatchFilter
+from .mean_std_filter import MeanStdPatchFilter
+from .shannon_filter import ShannonPatchFilter
+
+PatchFilter = Union[
+    MaxPatchFilter,
+    MeanStdPatchFilter,
+    ShannonPatchFilter,
+]
+
+
+CoordFilter = Union[MaskCoordFilter]
+
+
+def create_coord_filter(
+    filter_model: FilterModel, mask: PatchExtractor[GenericImageStack]
+) -> CoordFilter:
+    """Factory function to create coordinate filter instances based on the filter name.
+
+    Parameters
+    ----------
+    filter_model : FilterModel
+        Pydantic model of the filter to be created.
+    mask : PatchExtractor[GenericImageStack]
+        Mask extractor to be used for the mask filter.
+
+    Returns
+    -------
+    CoordFilter
+        Instance of the mask patch filter.
+    """
+    if filter_model.name == SupportedCoordinateFilters.MASK:
+        assert isinstance(filter_model, MaskFilterModel)
+        return MaskCoordFilter(
+            mask_extractor=mask,
+            coverage=filter_model.coverage,
+            p=filter_model.p,
+            seed=filter_model.seed,
+        )
+    else:
+        raise ValueError(f"Unknown filter name: {filter_model}")
+
+
+def create_patch_filter(filter_model: FilterModel) -> PatchFilter:
+    """Factory function to create patch filter instances based on the filter name.
+
+    Parameters
+    ----------
+    filter_model : FilterModel
+        Pydantic model of the filter to be created.
+
+    Returns
+    -------
+    PatchFilter
+        Instance of the requested patch filter.
+    """
+    if filter_model.name == SupportedPatchFilters.MAX:
+        assert isinstance(filter_model, MaxFilterModel)
+        return MaxPatchFilter(
+            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
+        )
+    elif filter_model.name == SupportedPatchFilters.MEANSTD:
+        assert isinstance(filter_model, MeanSTDFilterModel)
+        return MeanStdPatchFilter(
+            mean_threshold=filter_model.mean_threshold,
+            std_threshold=filter_model.std_threshold,
+            p=filter_model.p,
+            seed=filter_model.seed,
+        )
+    elif filter_model.name == SupportedPatchFilters.SHANNON:
+        assert isinstance(filter_model, ShannonFilterModel)
+        return ShannonPatchFilter(
+            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
+        )
+    else:
+        raise ValueError(f"Unknown filter name: {filter_model}")