careamics: careamics-0.0.16-py3-none-any.whl → careamics-0.0.17-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (36)
  1. careamics/careamist.py +7 -4
  2. careamics/config/configuration.py +6 -55
  3. careamics/config/configuration_factories.py +22 -12
  4. careamics/config/data/data_model.py +49 -9
  5. careamics/config/data/ng_data_model.py +167 -2
  6. careamics/config/data/patch_filter/__init__.py +15 -0
  7. careamics/config/data/patch_filter/filter_model.py +16 -0
  8. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  9. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  10. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  11. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  12. careamics/config/support/supported_filters.py +17 -0
  13. careamics/dataset_ng/dataset.py +57 -5
  14. careamics/dataset_ng/factory.py +101 -18
  15. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  16. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  17. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  18. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  19. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  20. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  21. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  22. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  23. careamics/lightning/callbacks/data_stats_callback.py +13 -3
  24. careamics/lightning/dataset_ng/data_module.py +79 -2
  25. careamics/lightning/lightning_module.py +4 -3
  26. careamics/lightning/microsplit_data_module.py +15 -10
  27. careamics/lvae_training/eval_utils.py +46 -24
  28. careamics/models/lvae/likelihoods.py +2 -1
  29. careamics/prediction_utils/prediction_outputs.py +3 -2
  30. careamics/prediction_utils/stitch_prediction.py +17 -6
  31. careamics/utils/version.py +4 -4
  32. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
  33. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
  34. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  35. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  36. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -41,6 +41,7 @@ logger = get_logger(__name__)
41
41
  LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
42
42
 
43
43
 
44
+ # TODO type ignore have been added because of the czi data type in data configuration
44
45
  class CAREamist:
45
46
  """Main CAREamics class, allowing training and prediction using various algorithms.
46
47
 
@@ -674,7 +675,7 @@ class CAREamist:
674
675
  # create the prediction
675
676
  self.pred_datamodule = create_predict_datamodule(
676
677
  pred_data=source,
677
- data_type=data_type or self.cfg.data_config.data_type,
678
+ data_type=data_type or self.cfg.data_config.data_type, # type: ignore
678
679
  axes=axes or self.cfg.data_config.axes,
679
680
  image_means=self.cfg.data_config.image_means,
680
681
  image_stds=self.cfg.data_config.image_stds,
@@ -817,14 +818,16 @@ class CAREamist:
817
818
 
818
819
  # extract file names
819
820
  source_path: Union[Path, str, NDArray]
820
- source_data_type: Literal["array", "tiff", "czi", "custom"]
821
+ source_data_type: Literal["array", "tiff", "custom"]
821
822
  if isinstance(source, PredictDataModule):
822
823
  source_path = source.pred_data
823
- source_data_type = source.data_type
824
+ source_data_type = source.data_type # type: ignore
824
825
  extension_filter = source.extension_filter
825
826
  elif isinstance(source, (str | Path)):
826
827
  source_path = source
827
- source_data_type = data_type or self.cfg.data_config.data_type
828
+ source_data_type = (
829
+ data_type or self.cfg.data_config.data_type # type: ignore
830
+ )
828
831
  extension_filter = SupportedData.get_extension_pattern(
829
832
  SupportedData(source_data_type)
830
833
  )
careamics/config/configuration.py CHANGED
@@ -3,14 +3,12 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import re
6
- from collections.abc import Callable
7
6
  from pprint import pformat
8
7
  from typing import Any, Literal, Self, Union
9
8
 
10
9
  import numpy as np
11
10
  from bioimageio.spec.generic.v0_3 import CiteEntry
12
11
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
13
- from pydantic.main import IncEx
14
12
 
15
13
  from careamics.config.algorithms import (
16
14
  CAREAlgorithm,
@@ -343,19 +341,7 @@ class Configuration(BaseModel):
343
341
 
344
342
  def model_dump(
345
343
  self,
346
- *,
347
- mode: Literal["json", "python"] | str = "python",
348
- include: IncEx | None = None,
349
- exclude: IncEx | None = None,
350
- context: Any | None = None,
351
- by_alias: bool | None = False,
352
- exclude_unset: bool = False,
353
- exclude_defaults: bool = False,
354
- exclude_none: bool = True,
355
- round_trip: bool = False,
356
- warnings: bool | Literal["none", "warn", "error"] = True,
357
- fallback: Callable[[Any], Any] | None = None,
358
- serialize_as_any: bool = False,
344
+ **kwargs: Any,
359
345
  ) -> dict[str, Any]:
360
346
  """
361
347
  Override model_dump method in order to set default values.
@@ -365,50 +351,15 @@ class Configuration(BaseModel):
365
351
 
366
352
  Parameters
367
353
  ----------
368
- mode : Literal['json', 'python'] | str, default='python'
369
- The serialization format.
370
- include : Any | None, default=None
371
- Attributes to include.
372
- exclude : Any | None, default=None
373
- Attributes to exclude.
374
- context : Any | None, default=None
375
- Additional context to pass to the serialization functions.
376
- by_alias : bool, default=False
377
- Whether to use attribute aliases.
378
- exclude_unset : bool, default=False
379
- Whether to exclude fields that are not set.
380
- exclude_defaults : bool, default=False
381
- Whether to exclude fields that have default values.
382
- exclude_none : bool, default=true
383
- Whether to exclude fields that have None values.
384
- round_trip : bool, default=False
385
- Whether to dump and load the data to ensure that the output is a valid
386
- representation.
387
- warnings : bool | Literal['none', 'warn', 'error'], default=True
388
- Whether to emit warnings.
389
- fallback : Callable[[Any], Any] | None, default=None
390
- A function to call when an unknown value is encountered.
391
- serialize_as_any : bool, default=False
392
- Whether to serialize all types as Any.
354
+ **kwargs : Any
355
+ Additional arguments to pass to the parent model_dump method.
393
356
 
394
357
  Returns
395
358
  -------
396
359
  dict
397
360
  Dictionary containing the model parameters.
398
361
  """
399
- dictionary = super().model_dump(
400
- mode=mode,
401
- include=include,
402
- exclude=exclude,
403
- context=context,
404
- by_alias=by_alias,
405
- exclude_unset=exclude_unset,
406
- exclude_defaults=exclude_defaults,
407
- exclude_none=exclude_none,
408
- round_trip=round_trip,
409
- warnings=warnings,
410
- fallback=fallback,
411
- serialize_as_any=serialize_as_any,
412
- )
362
+ if "exclude_none" not in kwargs:
363
+ kwargs["exclude_none"] = True
413
364
 
414
- return dictionary
365
+ return super().model_dump(**kwargs)
careamics/config/configuration_factories.py CHANGED
@@ -311,6 +311,10 @@ def _create_microsplit_data_configuration(
311
311
  Axes of the data.
312
312
  patch_size : list of int
313
313
  Size of the patches along the spatial dimensions.
314
+ grid_size : int
315
+ Grid size for patch extraction.
316
+ multiscale_count : int
317
+ Number of LC scales.
314
318
  batch_size : int
315
319
  Batch size.
316
320
  augmentations : list of transforms
@@ -1610,6 +1614,12 @@ def get_likelihood_config(
1610
1614
  ]:
1611
1615
  """Get the likelihood configuration for split models.
1612
1616
 
1617
+ Returns a tuple containing the following optional entries:
1618
+ - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
1619
+ - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
1620
+ losses
1621
+ - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
1622
+
1613
1623
  Parameters
1614
1624
  ----------
1615
1625
  loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
@@ -1629,15 +1639,12 @@ def get_likelihood_config(
1629
1639
 
1630
1640
  Returns
1631
1641
  -------
1632
- tuple[GaussianLikelihoodConfig | None, MultiChannelNMConfig | None,
1633
- NMLikelihoodConfig | None]
1634
- A tuple containing the likelihood and noise model configurations for the
1635
- specified loss type.
1636
-
1637
- - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
1638
- - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
1639
- losses
1640
- - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
1642
+ GaussianLikelihoodConfig or None
1643
+ Configuration for the Gaussian likelihood model.
1644
+ MultiChannelNMConfig or None
1645
+ Configuration for the multi-channel noise model.
1646
+ NMLikelihoodConfig or None
1647
+ Configuration for the noise model likelihood.
1641
1648
 
1642
1649
  Raises
1643
1650
  ------
@@ -1647,7 +1654,7 @@ def get_likelihood_config(
1647
1654
  # gaussian likelihood
1648
1655
  if loss_type in ["musplit", "denoisplit_musplit"]:
1649
1656
  # if predict_logvar is None:
1650
- # raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
1657
+ # raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
1651
1658
  # TODO validators should be in pydantic models
1652
1659
  gaussian_lik_config = GaussianLikelihoodConfig(
1653
1660
  predict_logvar=predict_logvar,
@@ -1903,7 +1910,7 @@ def create_microsplit_configuration(
1903
1910
  decoder_dropout: float = 0.0,
1904
1911
  nonlinearity: Literal[
1905
1912
  "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
1906
- ] = "ReLU",
1913
+ ] = "ReLU", # TODO do we need all these?
1907
1914
  analytical_kl: bool = False,
1908
1915
  predict_logvar: Literal["pixelwise"] = "pixelwise",
1909
1916
  logvar_lowerbound: Union[float, None] = None,
@@ -1943,8 +1950,11 @@ def create_microsplit_configuration(
1943
1950
  Strides for the decoder convolutional layers, by default (2, 2).
1944
1951
  multiscale_count : int, optional
1945
1952
  Number of multiscale levels, by default 1.
1953
+ grid_size : int, optional
1954
+ Size of the grid for the lateral context, by default 32.
1946
1955
  z_dims : tuple[int, ...], optional
1947
- List of latent dimensions for each hierarchy level in the LVAE, by default (128, 128).
1956
+ List of latent dimensions for each hierarchy level in the LVAE, by default
1957
+ (128, 128).
1948
1958
  output_channels : int, optional
1949
1959
  Number of output channels for the model, by default 1.
1950
1960
  encoder_n_filters : int, optional
careamics/config/data/data_model.py CHANGED
@@ -207,13 +207,12 @@ class DataConfig(BaseModel):
207
207
 
208
208
  @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
209
209
  @classmethod
210
- def set_default_dataloader_params(
210
+ def set_default_pin_memory(
211
211
  cls, dataloader_params: dict[str, Any]
212
212
  ) -> dict[str, Any]:
213
213
  """
214
- Set default dataloader parameters if not provided.
214
+ Set default pin_memory for dataloader parameters if not provided.
215
215
 
216
- - If 'num_workers' is not set, it defaults to the number of available CPU cores.
217
216
  - If 'pin_memory' is not set, it defaults to True if CUDA is available.
218
217
 
219
218
  Parameters
@@ -224,21 +223,62 @@ class DataConfig(BaseModel):
224
223
  Returns
225
224
  -------
226
225
  dict of {str: Any}
227
- The dataloader parameters with defaults applied.
226
+ The dataloader parameters with pin_memory default applied.
227
+ """
228
+ if "pin_memory" not in dataloader_params:
229
+ import torch
230
+
231
+ dataloader_params["pin_memory"] = torch.cuda.is_available()
232
+
233
+ return dataloader_params
234
+
235
+ @field_validator("train_dataloader_params", mode="before")
236
+ @classmethod
237
+ def set_default_train_workers(
238
+ cls, dataloader_params: dict[str, Any]
239
+ ) -> dict[str, Any]:
240
+ """
241
+ Set default num_workers for training dataloader if not provided.
242
+
243
+ - If 'num_workers' is not set, it defaults to the number of available CPU cores.
244
+
245
+ Parameters
246
+ ----------
247
+ dataloader_params : dict of {str: Any}
248
+ The training dataloader parameters.
249
+
250
+ Returns
251
+ -------
252
+ dict of {str: Any}
253
+ The dataloader parameters with num_workers default applied.
228
254
  """
229
255
  if "num_workers" not in dataloader_params:
230
- # Use 1 worker during tests, otherwise use all available CPU cores
256
+ # Use 0 workers during tests, otherwise use all available CPU cores
231
257
  if "pytest" in sys.modules:
232
258
  dataloader_params["num_workers"] = 0
233
259
  else:
234
260
  dataloader_params["num_workers"] = os.cpu_count()
235
261
 
236
- if "pin_memory" not in dataloader_params:
237
- import torch
262
+ return dataloader_params
238
263
 
239
- dataloader_params["pin_memory"] = torch.cuda.is_available()
264
+ @model_validator(mode="after")
265
+ def set_val_workers_to_match_train(self: Self) -> Self:
266
+ """
267
+ Set validation dataloader num_workers to match training dataloader.
240
268
 
241
- return dataloader_params
269
+ If num_workers is not specified in val_dataloader_params, it will be set to the
270
+ same value as train_dataloader_params["num_workers"].
271
+
272
+ Returns
273
+ -------
274
+ Self
275
+ Validated data model with synchronized num_workers.
276
+ """
277
+ if "num_workers" not in self.val_dataloader_params:
278
+ self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
279
+ "num_workers"
280
+ ]
281
+ return self
242
282
 
243
283
  @field_validator("train_dataloader_params")
244
284
  @classmethod
careamics/config/data/ng_data_model.py CHANGED
@@ -2,6 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import os
6
+ import random
7
+ import sys
5
8
  from collections.abc import Sequence
6
9
  from pprint import pformat
7
10
  from typing import Annotated, Any, Literal, Self, Union
@@ -20,6 +23,12 @@ from pydantic import (
20
23
 
21
24
  from ..transformations import XYFlipModel, XYRandomRotate90Model
22
25
  from ..validators import check_axes_validity
26
+ from .patch_filter import (
27
+ MaskFilterModel,
28
+ MaxFilterModel,
29
+ MeanSTDFilterModel,
30
+ ShannonFilterModel,
31
+ )
23
32
  from .patching_strategies import (
24
33
  RandomPatchingModel,
25
34
  TiledPatchingModel,
@@ -37,6 +46,17 @@ from .patching_strategies import (
37
46
  # - or is the responsibility of the creator (e.g. conveneince functions)
38
47
 
39
48
 
49
+ def generate_random_seed() -> int:
50
+ """Generate a random seed for reproducibility.
51
+
52
+ Returns
53
+ -------
54
+ int
55
+ A random integer between 1 and 2^31 - 1.
56
+ """
57
+ return random.randint(1, 2**31 - 1)
58
+
59
+
40
60
  def np_float_to_scientific_str(x: float) -> str:
41
61
  """Return a string scientific representation of a float.
42
62
 
@@ -67,6 +87,16 @@ PatchingStrategies = Union[
67
87
  ]
68
88
  """Patching strategies."""
69
89
 
90
+ PatchFilters = Union[
91
+ MaxFilterModel,
92
+ MeanSTDFilterModel,
93
+ ShannonFilterModel,
94
+ ]
95
+ """Patch filters."""
96
+
97
+ CoordFilters = Union[MaskFilterModel] # add more here as needed
98
+ """Coordinate filters."""
99
+
70
100
 
71
101
  class NGDataConfig(BaseModel):
72
102
  """Next-Generation Dataset configuration.
@@ -105,6 +135,18 @@ class NGDataConfig(BaseModel):
105
135
  batch_size: int = Field(default=1, ge=1, validate_default=True)
106
136
  """Batch size for training."""
107
137
 
138
+ patch_filter: PatchFilters | None = Field(default=None, discriminator="name")
139
+ """Patch filter to apply when using random patching. Only available during
140
+ training."""
141
+
142
+ coord_filter: CoordFilters | None = Field(default=None, discriminator="name")
143
+ """Coordinate filter to apply when using random patching. Only available during
144
+ training."""
145
+
146
+ patch_filter_patience: int = Field(default=5, ge=1)
147
+ """Number of consecutive patches not passing the filter before accepting the next
148
+ patch."""
149
+
108
150
  image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
109
151
  """Means of the data across channels, used for normalization."""
110
152
 
@@ -141,8 +183,8 @@ class NGDataConfig(BaseModel):
141
183
  test_dataloader_params: dict[str, Any] = Field(default={})
142
184
  """Dictionary of PyTorch test dataloader parameters."""
143
185
 
144
- seed: int | None = Field(default=None, gt=0)
145
- """Random seed for reproducibility."""
186
+ seed: int | None = Field(default_factory=generate_random_seed, gt=0)
187
+ """Random seed for reproducibility. If not specified, a random seed is generated."""
146
188
 
147
189
  @field_validator("axes")
148
190
  @classmethod
@@ -296,6 +338,129 @@ class NGDataConfig(BaseModel):
296
338
 
297
339
  return self
298
340
 
341
+ @model_validator(mode="after")
342
+ def propagate_seed_to_filters(self: Self) -> Self:
343
+ """
344
+ Propagate the main seed to patch and coordinate filters that support seeds.
345
+
346
+ This ensures that all filters use the same seed for reproducibility,
347
+ unless they already have a seed explicitly set.
348
+
349
+ Returns
350
+ -------
351
+ Self
352
+ Data model with propagated seeds.
353
+ """
354
+ if self.seed is not None:
355
+ if self.patch_filter is not None:
356
+ if (
357
+ hasattr(self.patch_filter, "seed")
358
+ and self.patch_filter.seed is None
359
+ ):
360
+ self.patch_filter.seed = self.seed
361
+
362
+ if self.coord_filter is not None:
363
+ if (
364
+ hasattr(self.coord_filter, "seed")
365
+ and self.coord_filter.seed is None
366
+ ):
367
+ self.coord_filter.seed = self.seed
368
+
369
+ return self
370
+
371
+ @model_validator(mode="after")
372
+ def propagate_seed_to_transforms(self: Self) -> Self:
373
+ """
374
+ Propagate the main seed to all transforms that support seeds.
375
+
376
+ This ensures that all transforms use the same seed for reproducibility,
377
+ unless they already have a seed explicitly set.
378
+
379
+ Returns
380
+ -------
381
+ Self
382
+ Data model with propagated seeds.
383
+ """
384
+ if self.seed is not None:
385
+ for transform in self.transforms:
386
+ if hasattr(transform, "seed") and transform.seed is None:
387
+ transform.seed = self.seed
388
+ return self
389
+
390
+ @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
391
+ @classmethod
392
+ def set_default_pin_memory(
393
+ cls, dataloader_params: dict[str, Any]
394
+ ) -> dict[str, Any]:
395
+ """
396
+ Set default pin_memory for dataloader parameters if not provided.
397
+
398
+ - If 'pin_memory' is not set, it defaults to True if CUDA is available.
399
+
400
+ Parameters
401
+ ----------
402
+ dataloader_params : dict of {str: Any}
403
+ The dataloader parameters.
404
+
405
+ Returns
406
+ -------
407
+ dict of {str: Any}
408
+ The dataloader parameters with pin_memory default applied.
409
+ """
410
+ if "pin_memory" not in dataloader_params:
411
+ import torch
412
+
413
+ dataloader_params["pin_memory"] = torch.cuda.is_available()
414
+ return dataloader_params
415
+
416
+ @field_validator("train_dataloader_params", mode="before")
417
+ @classmethod
418
+ def set_default_train_workers(
419
+ cls, dataloader_params: dict[str, Any]
420
+ ) -> dict[str, Any]:
421
+ """
422
+ Set default num_workers for training dataloader if not provided.
423
+
424
+ - If 'num_workers' is not set, it defaults to the number of available CPU cores.
425
+
426
+ Parameters
427
+ ----------
428
+ dataloader_params : dict of {str: Any}
429
+ The training dataloader parameters.
430
+
431
+ Returns
432
+ -------
433
+ dict of {str: Any}
434
+ The dataloader parameters with num_workers default applied.
435
+ """
436
+ if "num_workers" not in dataloader_params:
437
+ # Use 0 workers during tests, otherwise use all available CPU cores
438
+ if "pytest" in sys.modules:
439
+ dataloader_params["num_workers"] = 0
440
+ else:
441
+ dataloader_params["num_workers"] = os.cpu_count()
442
+
443
+ return dataloader_params
444
+
445
+ @model_validator(mode="after")
446
+ def set_val_workers_to_match_train(self: Self) -> Self:
447
+ """
448
+ Set validation dataloader num_workers to match training dataloader.
449
+
450
+ If num_workers is not specified in val_dataloader_params, it will be set to the
451
+ same value as train_dataloader_params["num_workers"].
452
+
453
+ Returns
454
+ -------
455
+ Self
456
+ Validated data model with synchronized num_workers.
457
+ """
458
+ if "num_workers" not in self.val_dataloader_params:
459
+ self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
460
+ "num_workers"
461
+ ]
462
+ return self
463
+
299
464
  def __str__(self) -> str:
300
465
  """
301
466
  Pretty string reprensenting the configuration.
careamics/config/data/patch_filter/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """Pydantic models representing coordinate and patch filters."""
2
+
3
+ __all__ = [
4
+ "FilterModel",
5
+ "MaskFilterModel",
6
+ "MaxFilterModel",
7
+ "MeanSTDFilterModel",
8
+ "ShannonFilterModel",
9
+ ]
10
+
11
+ from .filter_model import FilterModel
12
+ from .mask_filter_model import MaskFilterModel
13
+ from .max_filter_model import MaxFilterModel
14
+ from .meanstd_filter_model import MeanSTDFilterModel
15
+ from .shannon_filter_model import ShannonFilterModel
careamics/config/data/patch_filter/filter_model.py ADDED
@@ -0,0 +1,16 @@
1
+ """Base class for patch and coordinate filtering models."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class FilterModel(BaseModel):
7
+ """Base class for patch and coordinate filtering models."""
8
+
9
+ name: str
10
+ """Name of the filter."""
11
+
12
+ p: float = Field(1.0, ge=0.0, le=1.0)
13
+ """Probability of applying the filter to a patch or coordinate."""
14
+
15
+ seed: int | None = Field(default=None, gt=0)
16
+ """Seed for the random number generator for reproducibility."""
careamics/config/data/patch_filter/mask_filter_model.py ADDED
@@ -0,0 +1,17 @@
1
+ """Pydantic model for the mask coordinate filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import Field
6
+
7
+ from .filter_model import FilterModel
8
+
9
+
10
+ class MaskFilterModel(FilterModel):
11
+ """Pydantic model for the mask coordinate filter."""
12
+
13
+ name: Literal["mask"] = "mask"
14
+ """Name of the filter."""
15
+
16
+ coverage: float = Field(0.5, ge=0.0, le=1.0)
17
+ """Percentage of masked pixels required to keep a patch."""
careamics/config/data/patch_filter/max_filter_model.py ADDED
@@ -0,0 +1,15 @@
1
+ """Pydantic model for the max patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_model import FilterModel
6
+
7
+
8
+ class MaxFilterModel(FilterModel):
9
+ """Pydantic model for the max patch filter."""
10
+
11
+ name: Literal["max"] = "max"
12
+ """Name of the filter."""
13
+
14
+ threshold: float
15
+ """Threshold for the minimum of the max-filtered patch."""
careamics/config/data/patch_filter/meanstd_filter_model.py ADDED
@@ -0,0 +1,18 @@
1
+ """Pydantic model for the mean std patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_model import FilterModel
6
+
7
+
8
+ class MeanSTDFilterModel(FilterModel):
9
+ """Pydantic model for the mean std patch filter."""
10
+
11
+ name: Literal["mean_std"] = "mean_std"
12
+ """Name of the filter."""
13
+
14
+ mean_threshold: float
15
+ """Minimum mean intensity required to keep a patch."""
16
+
17
+ std_threshold: float | None = None
18
+ """Minimum standard deviation required to keep a patch."""
careamics/config/data/patch_filter/shannon_filter_model.py ADDED
@@ -0,0 +1,15 @@
1
+ """Pydantic model for the Shannon entropy patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_model import FilterModel
6
+
7
+
8
+ class ShannonFilterModel(FilterModel):
9
+ """Pydantic model for the Shannon entropy patch filter."""
10
+
11
+ name: Literal["shannon"] = "shannon"
12
+ """Name of the filter."""
13
+
14
+ threshold: float
15
+ """Minimum Shannon entropy required to keep a patch."""
careamics/config/support/supported_filters.py ADDED
@@ -0,0 +1,17 @@
1
+ """Coordinate and patch filters supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
+ class SupportedPatchFilters(str, BaseEnum):
7
+ """Supported patch filters."""
8
+
9
+ MAX = "max"
10
+ MEANSTD = "mean_std"
11
+ SHANNON = "shannon"
12
+
13
+
14
+ class SupportedCoordinateFilters(str, BaseEnum):
15
+ """Supported coordinate filters."""
16
+
17
+ MASK = "mask"