careamics 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (36)
  1. careamics/careamist.py +7 -4
  2. careamics/config/configuration.py +6 -55
  3. careamics/config/configuration_factories.py +22 -12
  4. careamics/config/data/data_model.py +49 -9
  5. careamics/config/data/ng_data_model.py +167 -2
  6. careamics/config/data/patch_filter/__init__.py +15 -0
  7. careamics/config/data/patch_filter/filter_model.py +16 -0
  8. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  9. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  10. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  11. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  12. careamics/config/support/supported_filters.py +17 -0
  13. careamics/dataset_ng/dataset.py +57 -5
  14. careamics/dataset_ng/factory.py +101 -18
  15. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  16. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  17. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  18. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  19. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  20. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  21. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  22. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  23. careamics/lightning/callbacks/data_stats_callback.py +13 -3
  24. careamics/lightning/dataset_ng/data_module.py +79 -2
  25. careamics/lightning/lightning_module.py +4 -3
  26. careamics/lightning/microsplit_data_module.py +15 -10
  27. careamics/lvae_training/eval_utils.py +46 -24
  28. careamics/models/lvae/likelihoods.py +2 -1
  29. careamics/prediction_utils/prediction_outputs.py +3 -2
  30. careamics/prediction_utils/stitch_prediction.py +17 -6
  31. careamics/utils/version.py +4 -4
  32. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
  33. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
  34. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  35. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  36. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -16,6 +16,7 @@ from careamics.config.transformations import NormalizeModel
16
16
  from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
17
17
  from careamics.dataset.patching.patching import Stats
18
18
  from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
19
+ from careamics.dataset_ng.patch_filter import create_coord_filter, create_patch_filter
19
20
  from careamics.dataset_ng.patching_strategies import (
20
21
  FixedRandomPatchingStrategy,
21
22
  PatchingStrategy,
@@ -52,13 +53,26 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
52
53
  mode: Mode,
53
54
  input_extractor: PatchExtractor[GenericImageStack],
54
55
  target_extractor: PatchExtractor[GenericImageStack] | None = None,
55
- ):
56
+ mask_extractor: PatchExtractor[GenericImageStack] | None = None,
57
+ ) -> None:
56
58
  self.config = data_config
57
59
  self.mode = mode
58
60
 
59
61
  self.input_extractor = input_extractor
60
62
  self.target_extractor = target_extractor
61
63
 
64
+ self.patch_filter = (
65
+ create_patch_filter(self.config.patch_filter)
66
+ if self.config.patch_filter is not None
67
+ else None
68
+ )
69
+ self.coord_filter = (
70
+ create_coord_filter(self.config.coord_filter, mask=mask_extractor)
71
+ if self.config.coord_filter is not None and mask_extractor is not None
72
+ else None
73
+ )
74
+ self.patch_filter_patience = self.config.patch_filter_patience
75
+
62
76
  self.patching_strategy = self._initialize_patching_strategy()
63
77
 
64
78
  self.input_stats, self.target_stats = self._initialize_statistics()
@@ -183,10 +197,10 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
183
197
  region_spec=patch_spec,
184
198
  )
185
199
 
186
- def __getitem__(
187
- self, index: int
188
- ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
189
- patch_spec = self.patching_strategy.get_patch_spec(index)
200
+ def _extract_patches(
201
+ self, patch_spec: PatchSpecs
202
+ ) -> tuple[NDArray, NDArray | None]:
203
+ """Extract input and target patches based on patch specifications."""
190
204
  input_patch = self.input_extractor.extract_patch(
191
205
  data_idx=patch_spec["data_idx"],
192
206
  sample_idx=patch_spec["sample_idx"],
@@ -204,7 +218,45 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
204
218
  if self.target_extractor is not None
205
219
  else None
206
220
  )
221
+ return input_patch, target_patch
222
+
223
def _get_filtered_patch(
    self, index: int
) -> tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]:
    """Extract a patch that passes the filtering criteria, with retry logic.

    In training mode, candidate patch specifications are drawn from the
    patching strategy. They are then checked by the coordinate filter and/or
    the patch (value) filter, if either is configured. Every rejection
    consumes one unit of ``self.patch_filter_patience``. Outside training
    mode, or when no filter is configured, the first candidate is returned.

    Parameters
    ----------
    index : int
        Dataset index forwarded to ``self.patching_strategy.get_patch_spec``.

    Returns
    -------
    tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]
        The input patch, the target patch (``None`` when there is no target
        extractor) and the patch specification they were extracted from.
        When patience runs out, the last queried candidate is returned even
        if it was rejected by a filter.
    """
    should_filter = self.mode == Mode.TRAINING and (
        self.patch_filter is not None or self.coord_filter is not None
    )
    patience = self.patch_filter_patience  # reset patience for this item

    patch_spec: PatchSpecs | None = None
    # True when `input_patch`/`target_patch` were extracted from `patch_spec`
    extracted = False

    while patience > 0:
        # query a candidate patch
        patch_spec = self.patching_strategy.get_patch_spec(index)
        extracted = False

        # filter on coordinates first, before extracting input/target patches
        if (
            should_filter
            and self.coord_filter is not None
            and self.coord_filter.filter_out(patch_spec)
        ):
            patience -= 1
            continue

        input_patch, target_patch = self._extract_patches(patch_spec)
        extracted = True

        # filter on patch values if needed
        if should_filter and self.patch_filter is not None:
            patience -= 1
            if not self.patch_filter.filter_out(input_patch):
                break
        else:
            break

    if patch_spec is None:
        # non-positive patience: no candidate was queried inside the loop
        patch_spec = self.patching_strategy.get_patch_spec(index)
    if not extracted:
        # Patience ran out on a coordinate rejection (or was never used).
        # Fall back to the last candidate instead of failing with an unbound
        # local, mirroring what happens when the value filter runs out.
        input_patch, target_patch = self._extract_patches(patch_spec)

    return input_patch, target_patch, patch_spec
253
+
254
+ def __getitem__(
255
+ self, index: int
256
+ ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
257
+ input_patch, target_patch, patch_spec = self._get_filtered_patch(index)
207
258
 
259
+ # apply transforms
208
260
  if self.transforms is not None:
209
261
  if self.target_extractor is not None:
210
262
  input_patch, target_patch = self.transforms(input_patch, target_patch)
@@ -121,6 +121,7 @@ def create_dataset(
121
121
  inputs: Any,
122
122
  targets: Any,
123
123
  in_memory: bool,
124
+ masks: Any = None,
124
125
  read_func: ReadFunc | None = None,
125
126
  read_kwargs: dict[str, Any] | None = None,
126
127
  image_stack_loader: ImageStackLoader | None = None,
@@ -142,6 +143,8 @@ def create_dataset(
142
143
  in_memory : bool
143
144
  Whether all the data should be loaded into memory. This is argument is ignored
144
145
  unless the `data_type` in `config` is "tiff" or "custom".
146
+ masks : Any, optional
147
+ The mask sources used to filter patches.
145
148
  read_func : ReadFunc, optional
146
149
  A function that can that can be used to load custom data. This argument is
147
150
  ignored unless the `data_type` in the `config` is "custom".
@@ -168,18 +171,24 @@ def create_dataset(
168
171
  data_type, in_memory, read_func, image_stack_loader
169
172
  )
170
173
  if dataset_type == DatasetType.ARRAY:
171
- return create_array_dataset(config, mode, inputs, targets)
174
+ return create_array_dataset(config, mode, inputs, targets, masks)
172
175
  elif dataset_type == DatasetType.IN_MEM_TIFF:
173
- return create_tiff_dataset(config, mode, inputs, targets)
176
+ return create_tiff_dataset(config, mode, inputs, targets, masks)
174
177
  # TODO: Lazy tiff
175
178
  elif dataset_type == DatasetType.CZI:
176
- return create_czi_dataset(config, mode, inputs, targets)
179
+ return create_czi_dataset(config, mode, inputs, targets, masks)
177
180
  elif dataset_type == DatasetType.IN_MEM_CUSTOM_FILE:
178
181
  if read_kwargs is None:
179
182
  read_kwargs = {}
180
183
  assert read_func is not None # should be true from `determine_dataset_type`
181
184
  return create_custom_file_dataset(
182
- config, mode, inputs, targets, read_func=read_func, read_kwargs=read_kwargs
185
+ config,
186
+ mode,
187
+ inputs,
188
+ targets,
189
+ masks,
190
+ read_func=read_func,
191
+ read_kwargs=read_kwargs,
183
192
  )
184
193
  elif dataset_type == DatasetType.CUSTOM_IMAGE_STACK:
185
194
  if image_stack_loader_kwargs is None:
@@ -191,6 +200,7 @@ def create_dataset(
191
200
  inputs,
192
201
  targets,
193
202
  image_stack_loader,
203
+ masks,
194
204
  **image_stack_loader_kwargs,
195
205
  )
196
206
  else:
@@ -202,6 +212,7 @@ def create_array_dataset(
202
212
  mode: Mode,
203
213
  inputs: Sequence[NDArray[Any]],
204
214
  targets: Sequence[NDArray[Any]] | None,
215
+ masks: Sequence[NDArray[Any]] | None = None,
205
216
  ) -> CareamicsDataset[InMemoryImageStack]:
206
217
  """
207
218
  Create a CAREamicsDataset from array data.
@@ -216,6 +227,8 @@ def create_array_dataset(
216
227
  The input sources to the dataset.
217
228
  targets : Any, optional
218
229
  The target sources to the dataset.
230
+ masks : Any, optional
231
+ The mask sources used to filter patches.
219
232
 
220
233
  Returns
221
234
  -------
@@ -228,7 +241,14 @@ def create_array_dataset(
228
241
  target_extractor = create_array_extractor(source=targets, axes=config.axes)
229
242
  else:
230
243
  target_extractor = None
231
- return CareamicsDataset(config, mode, input_extractor, target_extractor)
244
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
245
+ if masks is not None:
246
+ mask_extractor = create_array_extractor(source=masks, axes=config.axes)
247
+ else:
248
+ mask_extractor = None
249
+ return CareamicsDataset(
250
+ config, mode, input_extractor, target_extractor, mask_extractor
251
+ )
232
252
 
233
253
 
234
254
  def create_tiff_dataset(
@@ -236,9 +256,10 @@ def create_tiff_dataset(
236
256
  mode: Mode,
237
257
  inputs: Sequence[Path],
238
258
  targets: Sequence[Path] | None,
259
+ masks: Sequence[Path] | None = None,
239
260
  ) -> CareamicsDataset[InMemoryImageStack]:
240
261
  """
241
- Create a CAREamicsDataset from tiff files that will be all loaded into memory.
262
+ Create a CAREamicsDataset from tiff files that will be loaded into memory.
242
263
 
243
264
  Parameters
244
265
  ----------
@@ -246,10 +267,12 @@ def create_tiff_dataset(
246
267
  The data configuration.
247
268
  mode : Mode
248
269
  Whether to create the dataset in "training", "validation" or "predicting" mode.
249
- inputs : Any
270
+ inputs : Sequence[Path]
250
271
  The input sources to the dataset.
251
- targets : Any, optional
272
+ targets : Sequence[Path], optional
252
273
  The target sources to the dataset.
274
+ masks : Sequence[Path], optional
275
+ The mask sources used to filter patches.
253
276
 
254
277
  Returns
255
278
  -------
@@ -265,8 +288,15 @@ def create_tiff_dataset(
265
288
  target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
266
289
  else:
267
290
  target_extractor = None
268
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
269
- return dataset
291
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
292
+ if masks is not None:
293
+ mask_extractor = create_tiff_extractor(source=masks, axes=config.axes)
294
+ else:
295
+ mask_extractor = None
296
+
297
+ return CareamicsDataset(
298
+ config, mode, input_extractor, target_extractor, mask_extractor
299
+ )
270
300
 
271
301
 
272
302
  def create_czi_dataset(
@@ -274,6 +304,7 @@ def create_czi_dataset(
274
304
  mode: Mode,
275
305
  inputs: Sequence[Path],
276
306
  targets: Sequence[Path] | None,
307
+ masks: Sequence[Path] | None = None,
277
308
  ) -> CareamicsDataset[CziImageStack]:
278
309
  """
279
310
  Create a dataset from CZI files.
@@ -288,6 +319,8 @@ def create_czi_dataset(
288
319
  The input sources to the dataset.
289
320
  targets : Any, optional
290
321
  The target sources to the dataset.
322
+ masks : Any, optional
323
+ The mask sources used to filter patches.
291
324
 
292
325
  Returns
293
326
  -------
@@ -301,8 +334,15 @@ def create_czi_dataset(
301
334
  target_extractor = create_czi_extractor(source=targets, axes=config.axes)
302
335
  else:
303
336
  target_extractor = None
304
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
305
- return dataset
337
+ mask_extractor: PatchExtractor[CziImageStack] | None
338
+ if masks is not None:
339
+ mask_extractor = create_czi_extractor(source=masks, axes=config.axes)
340
+ else:
341
+ mask_extractor = None
342
+
343
+ return CareamicsDataset(
344
+ config, mode, input_extractor, target_extractor, mask_extractor
345
+ )
306
346
 
307
347
 
308
348
  def create_ome_zarr_dataset(
@@ -310,6 +350,7 @@ def create_ome_zarr_dataset(
310
350
  mode: Mode,
311
351
  inputs: Sequence[Path],
312
352
  targets: Sequence[Path] | None,
353
+ masks: Sequence[Path] | None = None,
313
354
  ) -> CareamicsDataset[ZarrImageStack]:
314
355
  """
315
356
  Create a dataset from OME ZARR files.
@@ -324,6 +365,8 @@ def create_ome_zarr_dataset(
324
365
  The input sources to the dataset.
325
366
  targets : Any, optional
326
367
  The target sources to the dataset.
368
+ masks : Any, optional
369
+ The mask sources used to filter patches.
327
370
 
328
371
  Returns
329
372
  -------
@@ -337,8 +380,15 @@ def create_ome_zarr_dataset(
337
380
  target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
338
381
  else:
339
382
  target_extractor = None
340
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
341
- return dataset
383
+ mask_extractor: PatchExtractor[ZarrImageStack] | None
384
+ if masks is not None:
385
+ mask_extractor = create_ome_zarr_extractor(source=masks, axes=config.axes)
386
+ else:
387
+ mask_extractor = None
388
+
389
+ return CareamicsDataset(
390
+ config, mode, input_extractor, target_extractor, mask_extractor
391
+ )
342
392
 
343
393
 
344
394
  def create_custom_file_dataset(
@@ -346,6 +396,7 @@ def create_custom_file_dataset(
346
396
  mode: Mode,
347
397
  inputs: Sequence[Path],
348
398
  targets: Sequence[Path] | None,
399
+ masks: Sequence[Path] | None = None,
349
400
  *,
350
401
  read_func: ReadFunc,
351
402
  read_kwargs: dict[str, Any],
@@ -363,6 +414,8 @@ def create_custom_file_dataset(
363
414
  The input sources to the dataset.
364
415
  targets : Any, optional
365
416
  The target sources to the dataset.
417
+ masks : Any, optional
418
+ The mask sources used to filter patches.
366
419
  read_func : Optional[ReadFunc], optional
367
420
  A function that can that can be used to load custom data. This argument is
368
421
  ignored unless the `data_type` is "custom".
@@ -388,8 +441,21 @@ def create_custom_file_dataset(
388
441
  )
389
442
  else:
390
443
  target_extractor = None
391
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
392
- return dataset
444
+
445
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
446
+ if masks is not None:
447
+ mask_extractor = create_custom_file_extractor(
448
+ source=masks,
449
+ axes=config.axes,
450
+ read_func=read_func,
451
+ read_kwargs=read_kwargs,
452
+ )
453
+ else:
454
+ mask_extractor = None
455
+
456
+ return CareamicsDataset(
457
+ config, mode, input_extractor, target_extractor, mask_extractor
458
+ )
393
459
 
394
460
 
395
461
  def create_custom_image_stack_dataset(
@@ -398,6 +464,7 @@ def create_custom_image_stack_dataset(
398
464
  inputs: Any,
399
465
  targets: Any | None,
400
466
  image_stack_loader: ImageStackLoader[P, GenericImageStack],
467
+ masks: Any | None = None,
401
468
  *args: P.args,
402
469
  **kwargs: P.kwargs,
403
470
  ) -> CareamicsDataset[GenericImageStack]:
@@ -419,6 +486,8 @@ def create_custom_image_stack_dataset(
419
486
  image_stack_loader : ImageStackLoader
420
487
  A function for custom image stack loading. This argument is ignored unless the
421
488
  `data_type` is "custom".
489
+ masks : Any, optional
490
+ The mask sources used to filter patches.
422
491
  *args : Any
423
492
  Positional arguments to pass to the `image_stack_loader`.
424
493
  **kwargs : Any
@@ -447,5 +516,19 @@ def create_custom_image_stack_dataset(
447
516
  )
448
517
  else:
449
518
  target_extractor = None
450
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
451
- return dataset
519
+
520
+ mask_extractor: PatchExtractor[GenericImageStack] | None
521
+ if masks is not None:
522
+ mask_extractor = create_custom_image_stack_extractor(
523
+ masks,
524
+ config.axes,
525
+ image_stack_loader,
526
+ *args,
527
+ **kwargs,
528
+ )
529
+ else:
530
+ mask_extractor = None
531
+
532
+ return CareamicsDataset(
533
+ config, mode, input_extractor, target_extractor, mask_extractor
534
+ )
@@ -0,0 +1,20 @@
1
+ """Patch filtering strategies."""
2
+
3
__all__ = [
    # protocols and concrete filter implementations
    "CoordinateFilterProtocol",
    "MaskCoordFilter",
    "MaxPatchFilter",
    "MeanStdPatchFilter",
    "PatchFilterProtocol",
    "ShannonPatchFilter",
    # factory functions building filters from configuration models
    "create_coord_filter",
    "create_patch_filter",
]
13
+
14
+ from .coordinate_filter_protocol import CoordinateFilterProtocol
15
+ from .filter_factory import create_coord_filter, create_patch_filter
16
+ from .mask_filter import MaskCoordFilter
17
+ from .max_filter import MaxPatchFilter
18
+ from .mean_std_filter import MeanStdPatchFilter
19
+ from .patch_filter_protocol import PatchFilterProtocol
20
+ from .shannon_filter import ShannonPatchFilter
@@ -0,0 +1,27 @@
1
+ """A protocol for patch filtering."""
2
+
3
+ from typing import Protocol
4
+
5
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
6
+
7
+
8
class CoordinateFilterProtocol(Protocol):
    """
    Structural interface for filters that decide from a patch specification.

    Implementations receive the specification (coordinates) of a candidate
    patch and report whether that patch should be discarded.
    """

    def filter_out(self, patch: PatchSpecs) -> bool:
        """
        Decide whether the patch described by `patch` should be discarded.

        Parameters
        ----------
        patch : PatchSpecs
            Specification (coordinates) of the candidate patch.

        Returns
        -------
        bool
            ``True`` to exclude the patch, ``False`` to keep it.
        """
        ...
@@ -0,0 +1,94 @@
1
+ """Factories for coordinate and patch filters."""
2
+
3
+ from typing import Union
4
+
5
+ from careamics.config.data.patch_filter import (
6
+ FilterModel,
7
+ MaskFilterModel,
8
+ MaxFilterModel,
9
+ MeanSTDFilterModel,
10
+ ShannonFilterModel,
11
+ )
12
+ from careamics.config.support.supported_filters import (
13
+ SupportedCoordinateFilters,
14
+ SupportedPatchFilters,
15
+ )
16
+ from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
17
+
18
+ from .mask_filter import MaskCoordFilter
19
+ from .max_filter import MaxPatchFilter
20
+ from .mean_std_filter import MeanStdPatchFilter
21
+ from .shannon_filter import ShannonPatchFilter
22
+
23
# Value-based filters that `create_patch_filter` can return.
PatchFilter = Union[MaxPatchFilter, MeanStdPatchFilter, ShannonPatchFilter]

# Coordinate-based filters that `create_coord_filter` can return. A `Union`
# with a single member evaluates to that member; it is written as a `Union`
# so further coordinate filters can be appended.
CoordFilter = Union[MaskCoordFilter]
31
+
32
+
33
def create_coord_filter(
    filter_model: FilterModel, mask: PatchExtractor[GenericImageStack]
) -> CoordFilter:
    """Build the coordinate filter described by a filter configuration model.

    Parameters
    ----------
    filter_model : FilterModel
        Pydantic model describing the filter; its `name` selects the filter.
    mask : PatchExtractor[GenericImageStack]
        Extractor providing mask patches to the mask filter.

    Returns
    -------
    CoordFilter
        A `MaskCoordFilter` configured from `filter_model`.

    Raises
    ------
    ValueError
        If `filter_model.name` is not a supported coordinate filter.
    """
    # Only the mask filter exists so far; reject anything else up front.
    if filter_model.name != SupportedCoordinateFilters.MASK:
        raise ValueError(f"Unknown filter name: {filter_model}")

    assert isinstance(filter_model, MaskFilterModel)  # narrows the type
    return MaskCoordFilter(
        mask_extractor=mask,
        coverage=filter_model.coverage,
        p=filter_model.p,
        seed=filter_model.seed,
    )
60
+
61
+
62
+ def create_patch_filter(filter_model: FilterModel) -> PatchFilter:
63
+ """Factory function to create patch filter instances based on the filter name.
64
+
65
+ Parameters
66
+ ----------
67
+ filter_model : FilterModel
68
+ Pydantic model of the filter to be created.
69
+
70
+ Returns
71
+ -------
72
+ PatchFilter
73
+ Instance of the requested patch filter.
74
+ """
75
+ if filter_model.name == SupportedPatchFilters.MAX:
76
+ assert isinstance(filter_model, MaxFilterModel)
77
+ return MaxPatchFilter(
78
+ threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
79
+ )
80
+ elif filter_model.name == SupportedPatchFilters.MEANSTD:
81
+ assert isinstance(filter_model, MeanSTDFilterModel)
82
+ return MeanStdPatchFilter(
83
+ mean_threshold=filter_model.mean_threshold,
84
+ std_threshold=filter_model.std_threshold,
85
+ p=filter_model.p,
86
+ seed=filter_model.seed,
87
+ )
88
+ elif filter_model.name == SupportedPatchFilters.SHANNON:
89
+ assert isinstance(filter_model, ShannonFilterModel)
90
+ return ShannonPatchFilter(
91
+ threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
92
+ )
93
+ else:
94
+ raise ValueError(f"Unknown filter name: {filter_model}")
@@ -0,0 +1,95 @@
1
+ """Filter using an image mask."""
2
+
3
+ import numpy as np
4
+
5
+ from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
6
+ from careamics.dataset_ng.patch_filter.coordinate_filter_protocol import (
7
+ CoordinateFilterProtocol,
8
+ )
9
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
10
+
11
+
12
+ # TODO is it more intuitive to have a negative mask? (mask of what to avoid)
13
class MaskCoordFilter(CoordinateFilterProtocol):
    """
    Filter patch coordinates based on an image mask.

    Attributes
    ----------
    mask_extractor : PatchExtractor[GenericImageStack]
        Patch extractor for the binary mask used for filtering.
    coverage : float
        Minimum fraction (between 0 and 1) of masked pixels required to keep
        a patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mask_extractor: PatchExtractor[GenericImageStack],
        coverage: float,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaskCoordFilter.

        This filter removes patches whose fraction of masked pixels falls
        below `coverage`. The mask is expected to be a positive mask, where
        masked (non-zero) pixels correspond to regions of interest.

        Parameters
        ----------
        mask_extractor : PatchExtractor[GenericImageStack]
            The patch extractor for the mask used for filtering.
        coverage : float
            Minimum fraction of masked pixels required to keep a patch. Must be
            between 0 and 1.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If coverage is not between 0 and 1.
        ValueError
            If p is not between 0 and 1.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Coverage must be between 0 and 1.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mask_extractor = mask_extractor
        self.coverage = coverage

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch_specs: PatchSpecs) -> bool:
        """
        Determine whether to filter out a patch based on an image mask.

        Parameters
        ----------
        patch_specs : PatchSpecs
            The patch coordinates to evaluate. They are passed as keyword
            arguments to `mask_extractor.extract_patch`.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise. With
            probability ``1 - p`` the filter is skipped and False is returned.
        """
        # Apply the filter only with probability `p`; otherwise keep the patch.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        mask_patch = self.mask_extractor.extract_patch(**patch_specs)

        # Count any non-zero value as masked, so that boolean, 0/1 and 0/255
        # masks all yield a fraction in [0, 1] (a plain sum would exceed 1
        # for 0/255 masks and let almost every patch through).
        masked_fraction = np.count_nonzero(mask_patch) / mask_patch.size
        return masked_fraction < self.coverage