careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -16,6 +16,7 @@ from careamics.config.transformations import NormalizeModel
16
16
  from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
17
17
  from careamics.dataset.patching.patching import Stats
18
18
  from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
19
+ from careamics.dataset_ng.patch_filter import create_coord_filter, create_patch_filter
19
20
  from careamics.dataset_ng.patching_strategies import (
20
21
  FixedRandomPatchingStrategy,
21
22
  PatchingStrategy,
@@ -52,13 +53,26 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
52
53
  mode: Mode,
53
54
  input_extractor: PatchExtractor[GenericImageStack],
54
55
  target_extractor: PatchExtractor[GenericImageStack] | None = None,
55
- ):
56
+ mask_extractor: PatchExtractor[GenericImageStack] | None = None,
57
+ ) -> None:
56
58
  self.config = data_config
57
59
  self.mode = mode
58
60
 
59
61
  self.input_extractor = input_extractor
60
62
  self.target_extractor = target_extractor
61
63
 
64
+ self.patch_filter = (
65
+ create_patch_filter(self.config.patch_filter)
66
+ if self.config.patch_filter is not None
67
+ else None
68
+ )
69
+ self.coord_filter = (
70
+ create_coord_filter(self.config.coord_filter, mask=mask_extractor)
71
+ if self.config.coord_filter is not None and mask_extractor is not None
72
+ else None
73
+ )
74
+ self.patch_filter_patience = self.config.patch_filter_patience
75
+
62
76
  self.patching_strategy = self._initialize_patching_strategy()
63
77
 
64
78
  self.input_stats, self.target_stats = self._initialize_statistics()
@@ -183,10 +197,10 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
183
197
  region_spec=patch_spec,
184
198
  )
185
199
 
186
- def __getitem__(
187
- self, index: int
188
- ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
189
- patch_spec = self.patching_strategy.get_patch_spec(index)
200
+ def _extract_patches(
201
+ self, patch_spec: PatchSpecs
202
+ ) -> tuple[NDArray, NDArray | None]:
203
+ """Extract input and target patches based on patch specifications."""
190
204
  input_patch = self.input_extractor.extract_patch(
191
205
  data_idx=patch_spec["data_idx"],
192
206
  sample_idx=patch_spec["sample_idx"],
@@ -204,7 +218,45 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
204
218
  if self.target_extractor is not None
205
219
  else None
206
220
  )
221
+ return input_patch, target_patch
222
+
223
def _get_filtered_patch(
    self, index: int
) -> tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]:
    """Extract a patch that passes the configured filters, retrying if needed.

    Filtering is only applied in training mode, and only when a patch filter
    and/or a coordinate filter is configured. Each rejection (by either
    filter) consumes one unit of `self.patch_filter_patience`.

    When patience runs out, the most recently extracted patch is returned
    even though it was rejected. If no patch was extracted at all (every
    candidate was rejected by the coordinate filter, or patience was not
    positive to begin with), a new patch is extracted without filtering
    instead of failing.

    Parameters
    ----------
    index : int
        Dataset index, forwarded to the patching strategy.

    Returns
    -------
    tuple of (NDArray, NDArray or None, PatchSpecs)
        The input patch, the target patch (None when there is no target
        extractor), and the patch specification both patches were extracted
        from.
    """
    should_filter = self.mode == Mode.TRAINING and (
        self.patch_filter is not None or self.coord_filter is not None
    )

    # Keep the last extracted patches together with the spec they came from.
    # This way the returned triple is always self-consistent, even if a later
    # candidate is rejected by the coordinate filter before extraction.
    last_extracted: (
        tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs] | None
    ) = None
    patience = self.patch_filter_patience

    while patience > 0:
        patch_spec = self.patching_strategy.get_patch_spec(index)

        # Reject on coordinates before extracting the patch data.
        if (
            should_filter
            and self.coord_filter is not None
            and self.coord_filter.filter_out(patch_spec)
        ):
            patience -= 1
            continue

        input_patch, target_patch = self._extract_patches(patch_spec)
        last_extracted = (input_patch, target_patch, patch_spec)

        # Value-based filtering on the input patch; accept when not filtering.
        if (
            not should_filter
            or self.patch_filter is None
            or not self.patch_filter.filter_out(input_patch)
        ):
            return last_extracted
        patience -= 1

    if last_extracted is None:
        # Nothing was ever extracted: fall back to an unfiltered patch rather
        # than returning unbound values.
        patch_spec = self.patching_strategy.get_patch_spec(index)
        input_patch, target_patch = self._extract_patches(patch_spec)
        last_extracted = (input_patch, target_patch, patch_spec)

    return last_extracted
253
+
254
+ def __getitem__(
255
+ self, index: int
256
+ ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
257
+ input_patch, target_patch, patch_spec = self._get_filtered_patch(index)
207
258
 
259
+ # apply transforms
208
260
  if self.transforms is not None:
209
261
  if self.target_extractor is not None:
210
262
  input_patch, target_patch = self.transforms(input_patch, target_patch)
@@ -121,6 +121,7 @@ def create_dataset(
121
121
  inputs: Any,
122
122
  targets: Any,
123
123
  in_memory: bool,
124
+ masks: Any = None,
124
125
  read_func: ReadFunc | None = None,
125
126
  read_kwargs: dict[str, Any] | None = None,
126
127
  image_stack_loader: ImageStackLoader | None = None,
@@ -142,6 +143,8 @@ def create_dataset(
142
143
  in_memory : bool
143
144
  Whether all the data should be loaded into memory. This is argument is ignored
144
145
  unless the `data_type` in `config` is "tiff" or "custom".
146
+ masks : Any, optional
147
+ The mask sources used to filter patches.
145
148
  read_func : ReadFunc, optional
146
149
  A function that can that can be used to load custom data. This argument is
147
150
  ignored unless the `data_type` in the `config` is "custom".
@@ -168,18 +171,24 @@ def create_dataset(
168
171
  data_type, in_memory, read_func, image_stack_loader
169
172
  )
170
173
  if dataset_type == DatasetType.ARRAY:
171
- return create_array_dataset(config, mode, inputs, targets)
174
+ return create_array_dataset(config, mode, inputs, targets, masks)
172
175
  elif dataset_type == DatasetType.IN_MEM_TIFF:
173
- return create_tiff_dataset(config, mode, inputs, targets)
176
+ return create_tiff_dataset(config, mode, inputs, targets, masks)
174
177
  # TODO: Lazy tiff
175
178
  elif dataset_type == DatasetType.CZI:
176
- return create_czi_dataset(config, mode, inputs, targets)
179
+ return create_czi_dataset(config, mode, inputs, targets, masks)
177
180
  elif dataset_type == DatasetType.IN_MEM_CUSTOM_FILE:
178
181
  if read_kwargs is None:
179
182
  read_kwargs = {}
180
183
  assert read_func is not None # should be true from `determine_dataset_type`
181
184
  return create_custom_file_dataset(
182
- config, mode, inputs, targets, read_func=read_func, read_kwargs=read_kwargs
185
+ config,
186
+ mode,
187
+ inputs,
188
+ targets,
189
+ masks,
190
+ read_func=read_func,
191
+ read_kwargs=read_kwargs,
183
192
  )
184
193
  elif dataset_type == DatasetType.CUSTOM_IMAGE_STACK:
185
194
  if image_stack_loader_kwargs is None:
@@ -191,6 +200,7 @@ def create_dataset(
191
200
  inputs,
192
201
  targets,
193
202
  image_stack_loader,
203
+ masks,
194
204
  **image_stack_loader_kwargs,
195
205
  )
196
206
  else:
@@ -202,6 +212,7 @@ def create_array_dataset(
202
212
  mode: Mode,
203
213
  inputs: Sequence[NDArray[Any]],
204
214
  targets: Sequence[NDArray[Any]] | None,
215
+ masks: Sequence[NDArray[Any]] | None = None,
205
216
  ) -> CareamicsDataset[InMemoryImageStack]:
206
217
  """
207
218
  Create a CAREamicsDataset from array data.
@@ -216,6 +227,8 @@ def create_array_dataset(
216
227
  The input sources to the dataset.
217
228
  targets : Any, optional
218
229
  The target sources to the dataset.
230
+ masks : Any, optional
231
+ The mask sources used to filter patches.
219
232
 
220
233
  Returns
221
234
  -------
@@ -228,7 +241,14 @@ def create_array_dataset(
228
241
  target_extractor = create_array_extractor(source=targets, axes=config.axes)
229
242
  else:
230
243
  target_extractor = None
231
- return CareamicsDataset(config, mode, input_extractor, target_extractor)
244
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
245
+ if masks is not None:
246
+ mask_extractor = create_array_extractor(source=masks, axes=config.axes)
247
+ else:
248
+ mask_extractor = None
249
+ return CareamicsDataset(
250
+ config, mode, input_extractor, target_extractor, mask_extractor
251
+ )
232
252
 
233
253
 
234
254
  def create_tiff_dataset(
@@ -236,9 +256,10 @@ def create_tiff_dataset(
236
256
  mode: Mode,
237
257
  inputs: Sequence[Path],
238
258
  targets: Sequence[Path] | None,
259
+ masks: Sequence[Path] | None = None,
239
260
  ) -> CareamicsDataset[InMemoryImageStack]:
240
261
  """
241
- Create a CAREamicsDataset from tiff files that will be all loaded into memory.
262
+ Create a CAREamicsDataset from tiff files that will be loaded into memory.
242
263
 
243
264
  Parameters
244
265
  ----------
@@ -246,10 +267,12 @@ def create_tiff_dataset(
246
267
  The data configuration.
247
268
  mode : Mode
248
269
  Whether to create the dataset in "training", "validation" or "predicting" mode.
249
- inputs : Any
270
+ inputs : Sequence[Path]
250
271
  The input sources to the dataset.
251
- targets : Any, optional
272
+ targets : Sequence[Path], optional
252
273
  The target sources to the dataset.
274
+ masks : Sequence[Path], optional
275
+ The mask sources used to filter patches.
253
276
 
254
277
  Returns
255
278
  -------
@@ -265,8 +288,15 @@ def create_tiff_dataset(
265
288
  target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
266
289
  else:
267
290
  target_extractor = None
268
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
269
- return dataset
291
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
292
+ if masks is not None:
293
+ mask_extractor = create_tiff_extractor(source=masks, axes=config.axes)
294
+ else:
295
+ mask_extractor = None
296
+
297
+ return CareamicsDataset(
298
+ config, mode, input_extractor, target_extractor, mask_extractor
299
+ )
270
300
 
271
301
 
272
302
  def create_czi_dataset(
@@ -274,6 +304,7 @@ def create_czi_dataset(
274
304
  mode: Mode,
275
305
  inputs: Sequence[Path],
276
306
  targets: Sequence[Path] | None,
307
+ masks: Sequence[Path] | None = None,
277
308
  ) -> CareamicsDataset[CziImageStack]:
278
309
  """
279
310
  Create a dataset from CZI files.
@@ -288,6 +319,8 @@ def create_czi_dataset(
288
319
  The input sources to the dataset.
289
320
  targets : Any, optional
290
321
  The target sources to the dataset.
322
+ masks : Any, optional
323
+ The mask sources used to filter patches.
291
324
 
292
325
  Returns
293
326
  -------
@@ -301,8 +334,15 @@ def create_czi_dataset(
301
334
  target_extractor = create_czi_extractor(source=targets, axes=config.axes)
302
335
  else:
303
336
  target_extractor = None
304
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
305
- return dataset
337
+ mask_extractor: PatchExtractor[CziImageStack] | None
338
+ if masks is not None:
339
+ mask_extractor = create_czi_extractor(source=masks, axes=config.axes)
340
+ else:
341
+ mask_extractor = None
342
+
343
+ return CareamicsDataset(
344
+ config, mode, input_extractor, target_extractor, mask_extractor
345
+ )
306
346
 
307
347
 
308
348
  def create_ome_zarr_dataset(
@@ -310,6 +350,7 @@ def create_ome_zarr_dataset(
310
350
  mode: Mode,
311
351
  inputs: Sequence[Path],
312
352
  targets: Sequence[Path] | None,
353
+ masks: Sequence[Path] | None = None,
313
354
  ) -> CareamicsDataset[ZarrImageStack]:
314
355
  """
315
356
  Create a dataset from OME ZARR files.
@@ -324,6 +365,8 @@ def create_ome_zarr_dataset(
324
365
  The input sources to the dataset.
325
366
  targets : Any, optional
326
367
  The target sources to the dataset.
368
+ masks : Any, optional
369
+ The mask sources used to filter patches.
327
370
 
328
371
  Returns
329
372
  -------
@@ -337,8 +380,15 @@ def create_ome_zarr_dataset(
337
380
  target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
338
381
  else:
339
382
  target_extractor = None
340
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
341
- return dataset
383
+ mask_extractor: PatchExtractor[ZarrImageStack] | None
384
+ if masks is not None:
385
+ mask_extractor = create_ome_zarr_extractor(source=masks, axes=config.axes)
386
+ else:
387
+ mask_extractor = None
388
+
389
+ return CareamicsDataset(
390
+ config, mode, input_extractor, target_extractor, mask_extractor
391
+ )
342
392
 
343
393
 
344
394
  def create_custom_file_dataset(
@@ -346,6 +396,7 @@ def create_custom_file_dataset(
346
396
  mode: Mode,
347
397
  inputs: Sequence[Path],
348
398
  targets: Sequence[Path] | None,
399
+ masks: Sequence[Path] | None = None,
349
400
  *,
350
401
  read_func: ReadFunc,
351
402
  read_kwargs: dict[str, Any],
@@ -363,6 +414,8 @@ def create_custom_file_dataset(
363
414
  The input sources to the dataset.
364
415
  targets : Any, optional
365
416
  The target sources to the dataset.
417
+ masks : Any, optional
418
+ The mask sources used to filter patches.
366
419
  read_func : Optional[ReadFunc], optional
367
420
  A function that can that can be used to load custom data. This argument is
368
421
  ignored unless the `data_type` is "custom".
@@ -388,8 +441,21 @@ def create_custom_file_dataset(
388
441
  )
389
442
  else:
390
443
  target_extractor = None
391
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
392
- return dataset
444
+
445
+ mask_extractor: PatchExtractor[InMemoryImageStack] | None
446
+ if masks is not None:
447
+ mask_extractor = create_custom_file_extractor(
448
+ source=masks,
449
+ axes=config.axes,
450
+ read_func=read_func,
451
+ read_kwargs=read_kwargs,
452
+ )
453
+ else:
454
+ mask_extractor = None
455
+
456
+ return CareamicsDataset(
457
+ config, mode, input_extractor, target_extractor, mask_extractor
458
+ )
393
459
 
394
460
 
395
461
  def create_custom_image_stack_dataset(
@@ -398,6 +464,7 @@ def create_custom_image_stack_dataset(
398
464
  inputs: Any,
399
465
  targets: Any | None,
400
466
  image_stack_loader: ImageStackLoader[P, GenericImageStack],
467
+ masks: Any | None = None,
401
468
  *args: P.args,
402
469
  **kwargs: P.kwargs,
403
470
  ) -> CareamicsDataset[GenericImageStack]:
@@ -419,6 +486,8 @@ def create_custom_image_stack_dataset(
419
486
  image_stack_loader : ImageStackLoader
420
487
  A function for custom image stack loading. This argument is ignored unless the
421
488
  `data_type` is "custom".
489
+ masks : Any, optional
490
+ The mask sources used to filter patches.
422
491
  *args : Any
423
492
  Positional arguments to pass to the `image_stack_loader`.
424
493
  **kwargs : Any
@@ -447,5 +516,19 @@ def create_custom_image_stack_dataset(
447
516
  )
448
517
  else:
449
518
  target_extractor = None
450
- dataset = CareamicsDataset(config, mode, input_extractor, target_extractor)
451
- return dataset
519
+
520
+ mask_extractor: PatchExtractor[GenericImageStack] | None
521
+ if masks is not None:
522
+ mask_extractor = create_custom_image_stack_extractor(
523
+ masks,
524
+ config.axes,
525
+ image_stack_loader,
526
+ *args,
527
+ **kwargs,
528
+ )
529
+ else:
530
+ mask_extractor = None
531
+
532
+ return CareamicsDataset(
533
+ config, mode, input_extractor, target_extractor, mask_extractor
534
+ )
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
7
7
  import numpy as np
8
8
  import zarr
9
9
  from numpy.typing import NDArray
10
- from zarr.storage import FSStore
10
+ from zarr.storage import FsspecStore
11
11
 
12
12
  from careamics.config import DataConfig
13
13
  from careamics.config.support import SupportedData
@@ -20,7 +20,7 @@ from careamics.dataset_ng.patch_extractor.patch_extractor_factory import (
20
20
 
21
21
  # %%
22
22
  def create_zarr_array(file_path: Path, data_path: str, data: NDArray):
23
- store = FSStore(url=file_path.resolve())
23
+ store = FsspecStore.from_url(url=file_path.resolve())
24
24
  # create array
25
25
  array = zarr.create(
26
26
  store=store,
@@ -61,7 +61,7 @@ if not file_path.is_file() and not file_path.is_dir():
61
61
  # ### Make sure file exists
62
62
 
63
63
  # %%
64
- store = FSStore(url=file_path.resolve(), mode="r")
64
+ store = FsspecStore.from_url(url=file_path.resolve(), mode="r")
65
65
 
66
66
  # %%
67
67
  list(store.keys())
@@ -72,7 +72,7 @@ list(store.keys())
72
72
 
73
73
  # %%
74
74
  class ZarrSource(TypedDict):
75
- store: FSStore
75
+ store: FsspecStore
76
76
  data_paths: Sequence[str]
77
77
 
78
78
 
@@ -1,9 +1,8 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Any, Literal, Union
3
+ from typing import Any, Literal, Self, Union
4
4
 
5
5
  from numpy.typing import DTypeLike, NDArray
6
- from typing_extensions import Self
7
6
 
8
7
  from careamics.dataset.dataset_utils import reshape_array
9
8
  from careamics.file_io.read import ReadFunc, read_tiff
@@ -1,11 +1,11 @@
1
1
  from collections.abc import Sequence
2
2
  from pathlib import Path
3
- from typing import Union
3
+ from typing import Self, Union
4
4
 
5
+ import validators
5
6
  import zarr
6
- import zarr.storage
7
7
  from numpy.typing import NDArray
8
- from typing_extensions import Self
8
+ from zarr.storage import FsspecStore, LocalStore
9
9
 
10
10
  from careamics.dataset.dataset_utils import reshape_array
11
11
 
@@ -15,9 +15,10 @@ class ZarrImageStack:
15
15
  A class for extracting patches from an image stack that is stored as a zarr array.
16
16
  """
17
17
 
18
- # TODO: keeping store type narrow so that it has the path attribute
19
- # base zarr store is zarr.storage.Store, includes MemoryStore
20
- def __init__(self, store: zarr.storage.FSStore, data_path: str, axes: str):
18
+ # TODO: We should keep store type narrow
19
+ # - in zarr v3, does zarr.storage.Store exists and has the path attribute?
20
+ # - can we declare a narrow type rather than a union?
21
+ def __init__(self, store: LocalStore | FsspecStore, data_path: str, axes: str):
21
22
  self._store = store
22
23
  self._array = zarr.open_array(store=self._store, path=data_path, mode="r")
23
24
  # TODO: validate axes
@@ -46,8 +47,33 @@ class ZarrImageStack:
46
47
  Assumes the path only contains 1 image.
47
48
 
48
49
  Path can be to a local file, or it can be a URL to a zarr stored in the cloud.
50
+
51
+ Parameters
52
+ ----------
53
+ path : Union[Path, str]
54
+ Path to the root of the OME-Zarr, local file or url.
55
+
56
+ Returns
57
+ -------
58
+ ZarrImageStack
59
+ Initialised ZarrImageStack.
60
+
61
+ Raises
62
+ ------
63
+ ValueError
64
+ If the path does not exist or is not a valid URL.
65
+ ValueError
66
+ If the OME-Zarr at the path does not contain the attribute 'multiscales'.
49
67
  """
50
- store = zarr.storage.FSStore(url=path)
68
+ if Path(path).is_file():
69
+ store = zarr.storage.LocalStore(root=Path(path).resolve())
70
+ elif validators.url(path):
71
+ store = zarr.storage.FsspecStore.from_url(url=path)
72
+ else:
73
+ raise ValueError(
74
+ f"Path '{path}' is neither an existing file nor a valid URL."
75
+ )
76
+
51
77
  group = zarr.open_group(store=store, mode="r")
52
78
  if "multiscales" not in group.attrs:
53
79
  raise ValueError(
@@ -38,7 +38,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
38
38
 
39
39
  >>> from typing import TypedDict
40
40
 
41
- >>> from zarr.storage import FSStore
41
+ >>> from zarr.storage import FsspecStore
42
42
 
43
43
  >>> from careamics.config import DataConfig
44
44
  >>> from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack
@@ -46,7 +46,7 @@ class ImageStackLoader(Protocol[P, GenericImageStack]):
46
46
  >>> # Define a zarr source
47
47
  >>> # It encompasses multiple arguments that determine what data will be loaded
48
48
  >>> class ZarrSource(TypedDict):
49
- ... store: FSStore
49
+ ... store: FsspecStore
50
50
  ... data_paths: Sequence[str]
51
51
 
52
52
  >>> def custom_image_stack_loader(
@@ -0,0 +1,20 @@
1
+ """Patch filtering strategies."""
2
+
3
+ __all__ = [
4
+ "CoordinateFilterProtocol",
5
+ "MaskCoordFilter",
6
+ "MaxPatchFilter",
7
+ "MeanStdPatchFilter",
8
+ "PatchFilterProtocol",
9
+ "ShannonPatchFilter",
10
+ "create_coord_filter",
11
+ "create_patch_filter",
12
+ ]
13
+
14
+ from .coordinate_filter_protocol import CoordinateFilterProtocol
15
+ from .filter_factory import create_coord_filter, create_patch_filter
16
+ from .mask_filter import MaskCoordFilter
17
+ from .max_filter import MaxPatchFilter
18
+ from .mean_std_filter import MeanStdPatchFilter
19
+ from .patch_filter_protocol import PatchFilterProtocol
20
+ from .shannon_filter import ShannonPatchFilter
@@ -0,0 +1,27 @@
1
+ """A protocol for patch filtering."""
2
+
3
+ from typing import Protocol
4
+
5
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
6
+
7
+
8
class CoordinateFilterProtocol(Protocol):
    """
    Structural interface for filters that judge a patch by its specification.

    Any object exposing a compatible `filter_out` method satisfies this
    protocol; explicit inheritance is not required.
    """

    def filter_out(self, patch: PatchSpecs) -> bool:
        """
        Decide, from its coordinates, whether a candidate patch is discarded.

        Parameters
        ----------
        patch : PatchSpecs
            Specification (coordinates) of the candidate patch.

        Returns
        -------
        bool
            True to exclude the patch, False to keep it.
        """
@@ -0,0 +1,94 @@
1
+ """Factories for coordinate and patch filters."""
2
+
3
+ from typing import Union
4
+
5
+ from careamics.config.data.patch_filter import (
6
+ FilterModel,
7
+ MaskFilterModel,
8
+ MaxFilterModel,
9
+ MeanSTDFilterModel,
10
+ ShannonFilterModel,
11
+ )
12
+ from careamics.config.support.supported_filters import (
13
+ SupportedCoordinateFilters,
14
+ SupportedPatchFilters,
15
+ )
16
+ from careamics.dataset_ng.patch_extractor import GenericImageStack, PatchExtractor
17
+
18
+ from .mask_filter import MaskCoordFilter
19
+ from .max_filter import MaxPatchFilter
20
+ from .mean_std_filter import MeanStdPatchFilter
21
+ from .shannon_filter import ShannonPatchFilter
22
+
23
# Filters that judge an extracted patch by its values; this is the return
# type of `create_patch_filter`.
PatchFilter = Union[MaxPatchFilter, MeanStdPatchFilter, ShannonPatchFilter]

# Filters that judge a patch from its specification (coordinates); this is the
# return type of `create_coord_filter`. A single-member `Union` evaluates to
# the member itself; it is kept as a `Union` so more filters can be added.
CoordFilter = Union[MaskCoordFilter]
31
+
32
+
33
def create_coord_filter(
    filter_model: FilterModel, mask: PatchExtractor[GenericImageStack]
) -> CoordFilter:
    """Factory function to create coordinate filter instances based on the filter name.

    Parameters
    ----------
    filter_model : FilterModel
        Pydantic model of the filter to be created. Its `name` selects the
        filter class.
    mask : PatchExtractor[GenericImageStack]
        Mask extractor to be used for the mask filter.

    Returns
    -------
    CoordFilter
        Instance of the requested coordinate filter.

    Raises
    ------
    TypeError
        If `filter_model.name` selects a filter but `filter_model` is not an
        instance of the matching configuration model.
    ValueError
        If `filter_model.name` is not a supported coordinate filter.
    """
    if filter_model.name == SupportedCoordinateFilters.MASK:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and this still narrows the type for static checkers.
        if not isinstance(filter_model, MaskFilterModel):
            raise TypeError(
                "Expected MaskFilterModel, got "
                f"{type(filter_model).__name__}."
            )
        return MaskCoordFilter(
            mask_extractor=mask,
            coverage=filter_model.coverage,
            p=filter_model.p,
            seed=filter_model.seed,
        )

    # Report the offending name rather than the full model repr.
    raise ValueError(f"Unknown filter name: {filter_model.name}")
60
+
61
+
62
def create_patch_filter(filter_model: FilterModel) -> PatchFilter:
    """Factory function to create patch filter instances based on the filter name.

    Parameters
    ----------
    filter_model : FilterModel
        Pydantic model of the filter to be created. Its `name` selects the
        filter class, and its fields are forwarded to the filter constructor.

    Returns
    -------
    PatchFilter
        Instance of the requested patch filter.

    Raises
    ------
    TypeError
        If `filter_model.name` selects a filter but `filter_model` is not an
        instance of the matching configuration model.
    ValueError
        If `filter_model.name` is not a supported patch filter.
    """
    name = filter_model.name

    # Explicit isinstance checks instead of `assert` (stripped under
    # `python -O`); they also narrow the model type for static checkers.
    if name == SupportedPatchFilters.MAX:
        if not isinstance(filter_model, MaxFilterModel):
            raise TypeError(
                f"Expected MaxFilterModel, got {type(filter_model).__name__}."
            )
        return MaxPatchFilter(
            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
        )

    if name == SupportedPatchFilters.MEANSTD:
        if not isinstance(filter_model, MeanSTDFilterModel):
            raise TypeError(
                "Expected MeanSTDFilterModel, got "
                f"{type(filter_model).__name__}."
            )
        return MeanStdPatchFilter(
            mean_threshold=filter_model.mean_threshold,
            std_threshold=filter_model.std_threshold,
            p=filter_model.p,
            seed=filter_model.seed,
        )

    if name == SupportedPatchFilters.SHANNON:
        if not isinstance(filter_model, ShannonFilterModel):
            raise TypeError(
                "Expected ShannonFilterModel, got "
                f"{type(filter_model).__name__}."
            )
        return ShannonPatchFilter(
            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
        )

    # Report the offending name rather than the full model repr.
    raise ValueError(f"Unknown filter name: {name}")