careamics 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/careamist.py +7 -4
- careamics/config/configuration.py +6 -55
- careamics/config/configuration_factories.py +22 -12
- careamics/config/data/data_model.py +49 -9
- careamics/config/data/ng_data_model.py +167 -2
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/lightning/callbacks/data_stats_callback.py +13 -3
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +4 -3
- careamics/lightning/microsplit_data_module.py +15 -10
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/models/lvae/likelihoods.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +3 -2
- careamics/prediction_utils/stitch_prediction.py +17 -6
- careamics/utils/version.py +4 -4
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
@@ -41,6 +41,7 @@ logger = get_logger(__name__)
 LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
 
 
+# TODO type ignore have been added because of the czi data type in data configuration
 class CAREamist:
     """Main CAREamics class, allowing training and prediction using various algorithms.
 
@@ -674,7 +675,7 @@ class CAREamist:
         # create the prediction
         self.pred_datamodule = create_predict_datamodule(
             pred_data=source,
-            data_type=data_type or self.cfg.data_config.data_type,
+            data_type=data_type or self.cfg.data_config.data_type,  # type: ignore
             axes=axes or self.cfg.data_config.axes,
             image_means=self.cfg.data_config.image_means,
             image_stds=self.cfg.data_config.image_stds,
@@ -817,14 +818,16 @@
 
         # extract file names
         source_path: Union[Path, str, NDArray]
-        source_data_type: Literal["array", "tiff", "
+        source_data_type: Literal["array", "tiff", "custom"]
         if isinstance(source, PredictDataModule):
             source_path = source.pred_data
-            source_data_type = source.data_type
+            source_data_type = source.data_type  # type: ignore
             extension_filter = source.extension_filter
         elif isinstance(source, (str | Path)):
             source_path = source
-            source_data_type =
+            source_data_type = (
+                data_type or self.cfg.data_config.data_type  # type: ignore
+            )
             extension_filter = SupportedData.get_extension_pattern(
                 SupportedData(source_data_type)
             )
careamics/config/configuration.py
CHANGED
@@ -3,14 +3,12 @@
 from __future__ import annotations
 
 import re
-from collections.abc import Callable
 from pprint import pformat
 from typing import Any, Literal, Self, Union
 
 import numpy as np
 from bioimageio.spec.generic.v0_3 import CiteEntry
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
-from pydantic.main import IncEx
 
 from careamics.config.algorithms import (
     CAREAlgorithm,
@@ -343,19 +341,7 @@ class Configuration(BaseModel):
 
     def model_dump(
         self,
-
-        mode: Literal["json", "python"] | str = "python",
-        include: IncEx | None = None,
-        exclude: IncEx | None = None,
-        context: Any | None = None,
-        by_alias: bool | None = False,
-        exclude_unset: bool = False,
-        exclude_defaults: bool = False,
-        exclude_none: bool = True,
-        round_trip: bool = False,
-        warnings: bool | Literal["none", "warn", "error"] = True,
-        fallback: Callable[[Any], Any] | None = None,
-        serialize_as_any: bool = False,
+        **kwargs: Any,
     ) -> dict[str, Any]:
         """
         Override model_dump method in order to set default values.
@@ -365,50 +351,15 @@ class Configuration(BaseModel):
 
         Parameters
         ----------
-
-
-        include : Any | None, default=None
-            Attributes to include.
-        exclude : Any | None, default=None
-            Attributes to exclude.
-        context : Any | None, default=None
-            Additional context to pass to the serialization functions.
-        by_alias : bool, default=False
-            Whether to use attribute aliases.
-        exclude_unset : bool, default=False
-            Whether to exclude fields that are not set.
-        exclude_defaults : bool, default=False
-            Whether to exclude fields that have default values.
-        exclude_none : bool, default=true
-            Whether to exclude fields that have None values.
-        round_trip : bool, default=False
-            Whether to dump and load the data to ensure that the output is a valid
-            representation.
-        warnings : bool | Literal['none', 'warn', 'error'], default=True
-            Whether to emit warnings.
-        fallback : Callable[[Any], Any] | None, default=None
-            A function to call when an unknown value is encountered.
-        serialize_as_any : bool, default=False
-            Whether to serialize all types as Any.
+        **kwargs : Any
+            Additional arguments to pass to the parent model_dump method.
 
         Returns
        -------
         dict
             Dictionary containing the model parameters.
         """
-
-
-            include=include,
-            exclude=exclude,
-            context=context,
-            by_alias=by_alias,
-            exclude_unset=exclude_unset,
-            exclude_defaults=exclude_defaults,
-            exclude_none=exclude_none,
-            round_trip=round_trip,
-            warnings=warnings,
-            fallback=fallback,
-            serialize_as_any=serialize_as_any,
-        )
+        if "exclude_none" not in kwargs:
+            kwargs["exclude_none"] = True
 
-        return
+        return super().model_dump(**kwargs)
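The model_dump override above now simply forwards keyword arguments to pydantic's BaseModel.model_dump and injects exclude_none=True only when the caller has not set it, instead of re-declaring the full pydantic signature. A minimal sketch of the same pattern on a stand-alone pydantic v2 model (the Example class and its fields are illustrative, not part of careamics):

from typing import Any

from pydantic import BaseModel


class Example(BaseModel):
    name: str = "n2v"
    comment: str | None = None

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        # default to dropping None-valued fields unless the caller overrides it
        if "exclude_none" not in kwargs:
            kwargs["exclude_none"] = True
        return super().model_dump(**kwargs)


Example().model_dump()                    # {'name': 'n2v'}
Example().model_dump(exclude_none=False)  # {'name': 'n2v', 'comment': None}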
careamics/config/configuration_factories.py
CHANGED
@@ -311,6 +311,10 @@ def _create_microsplit_data_configuration(
         Axes of the data.
     patch_size : list of int
         Size of the patches along the spatial dimensions.
+    grid_size : int
+        Grid size for patch extraction.
+    multiscale_count : int
+        Number of LC scales.
     batch_size : int
         Batch size.
     augmentations : list of transforms
@@ -1610,6 +1614,12 @@ def get_likelihood_config(
 ]:
     """Get the likelihood configuration for split models.
 
+    Returns a tuple containing the following optional entries:
+    - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
+    - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
+      losses
+    - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
+
     Parameters
     ----------
     loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
@@ -1629,15 +1639,12 @@ def get_likelihood_config(
 
     Returns
     -------
-
-
-
-
-
-
-    - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
-      losses
-    - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses
+    GaussianLikelihoodConfig or None
+        Configuration for the Gaussian likelihood model.
+    MultiChannelNMConfig or None
+        Configuration for the multi-channel noise model.
+    NMLikelihoodConfig or None
+        Configuration for the noise model likelihood.
 
     Raises
     ------
@@ -1647,7 +1654,7 @@ def get_likelihood_config(
     # gaussian likelihood
     if loss_type in ["musplit", "denoisplit_musplit"]:
         # if predict_logvar is None:
-        #
+        #     raise ValueError(f"predict_logvar is required for loss_type '{loss_type}'")
         # TODO validators should be in pydantic models
         gaussian_lik_config = GaussianLikelihoodConfig(
             predict_logvar=predict_logvar,
@@ -1903,7 +1910,7 @@ def create_microsplit_configuration(
     decoder_dropout: float = 0.0,
     nonlinearity: Literal[
         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
-    ] = "ReLU",
+    ] = "ReLU",  # TODO do we need all these?
    analytical_kl: bool = False,
     predict_logvar: Literal["pixelwise"] = "pixelwise",
     logvar_lowerbound: Union[float, None] = None,
@@ -1943,8 +1950,11 @@ def create_microsplit_configuration(
         Strides for the decoder convolutional layers, by default (2, 2).
     multiscale_count : int, optional
         Number of multiscale levels, by default 1.
+    grid_size : int, optional
+        Size of the grid for the lateral context, by default 32.
     z_dims : tuple[int, ...], optional
-        List of latent dimensions for each hierarchy level in the LVAE, by default
+        List of latent dimensions for each hierarchy level in the LVAE, by default
+        (128, 128).
     output_channels : int, optional
         Number of output channels for the model, by default 1.
     encoder_n_filters : int, optional
careamics/config/data/data_model.py
CHANGED
@@ -207,13 +207,12 @@ class DataConfig(BaseModel):
 
     @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
     @classmethod
-    def
+    def set_default_pin_memory(
         cls, dataloader_params: dict[str, Any]
     ) -> dict[str, Any]:
         """
-        Set default dataloader parameters if not provided.
+        Set default pin_memory for dataloader parameters if not provided.
 
-        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
         - If 'pin_memory' is not set, it defaults to True if CUDA is available.
 
         Parameters
@@ -224,21 +223,62 @@ class DataConfig(BaseModel):
         Returns
         -------
         dict of {str: Any}
-            The dataloader parameters with
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
         """
         if "num_workers" not in dataloader_params:
-            # Use
+            # Use 0 workers during tests, otherwise use all available CPU cores
             if "pytest" in sys.modules:
                 dataloader_params["num_workers"] = 0
             else:
                 dataloader_params["num_workers"] = os.cpu_count()
 
-
-            import torch
+        return dataloader_params
 
-
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
 
-
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
 
     @field_validator("train_dataloader_params")
     @classmethod
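Taken together, the three validators now split the old combined defaulting into separate steps: pin_memory is defaulted on both training and validation parameters, num_workers is defaulted on the training parameters only, and an after-validator copies the training num_workers into the validation parameters when it was not set. A stand-alone sketch of the resulting defaulting behaviour, assuming torch is installed (apply_defaults is an illustrative helper, not a careamics function):

import os
import sys

import torch


def apply_defaults(train_params: dict, val_params: dict) -> tuple[dict, dict]:
    # pin_memory defaults to True only when CUDA is available
    for params in (train_params, val_params):
        params.setdefault("pin_memory", torch.cuda.is_available())
    # num_workers: 0 under pytest, otherwise all available CPU cores (train only)
    if "num_workers" not in train_params:
        train_params["num_workers"] = 0 if "pytest" in sys.modules else os.cpu_count()
    # validation inherits the training num_workers unless set explicitly
    val_params.setdefault("num_workers", train_params["num_workers"])
    return train_params, val_params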
careamics/config/data/ng_data_model.py
CHANGED
@@ -2,6 +2,9 @@
 
 from __future__ import annotations
 
+import os
+import random
+import sys
 from collections.abc import Sequence
 from pprint import pformat
 from typing import Annotated, Any, Literal, Self, Union
@@ -20,6 +23,12 @@ from pydantic import (
 
 from ..transformations import XYFlipModel, XYRandomRotate90Model
 from ..validators import check_axes_validity
+from .patch_filter import (
+    MaskFilterModel,
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+)
 from .patching_strategies import (
     RandomPatchingModel,
     TiledPatchingModel,
@@ -37,6 +46,17 @@ from .patching_strategies import (
 # - or is the responsibility of the creator (e.g. conveneince functions)
 
 
+def generate_random_seed() -> int:
+    """Generate a random seed for reproducibility.
+
+    Returns
+    -------
+    int
+        A random integer between 1 and 2^31 - 1.
+    """
+    return random.randint(1, 2**31 - 1)
+
+
 def np_float_to_scientific_str(x: float) -> str:
     """Return a string scientific representation of a float.
 
@@ -67,6 +87,16 @@ PatchingStrategies = Union[
 ]
 """Patching strategies."""
 
+PatchFilters = Union[
+    MaxFilterModel,
+    MeanSTDFilterModel,
+    ShannonFilterModel,
+]
+"""Patch filters."""
+
+CoordFilters = Union[MaskFilterModel]  # add more here as needed
+"""Coordinate filters."""
+
 
 class NGDataConfig(BaseModel):
     """Next-Generation Dataset configuration.
@@ -105,6 +135,18 @@ class NGDataConfig(BaseModel):
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     """Batch size for training."""
 
+    patch_filter: PatchFilters | None = Field(default=None, discriminator="name")
+    """Patch filter to apply when using random patching. Only available during
+    training."""
+
+    coord_filter: CoordFilters | None = Field(default=None, discriminator="name")
+    """Coordinate filter to apply when using random patching. Only available during
+    training."""
+
+    patch_filter_patience: int = Field(default=5, ge=1)
+    """Number of consecutive patches not passing the filter before accepting the next
+    patch."""
+
     image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""
 
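Because patch_filter and coord_filter are discriminated unions keyed on each filter's name literal, they can be provided either as model instances or as plain dictionaries. A hedged sketch of the corresponding configuration fragment (only the new fields are shown; all values are illustrative and the other required NGDataConfig fields are omitted):

data_config_fragment = {
    # keep patches whose max-filtered response exceeds a threshold
    "patch_filter": {"name": "max", "threshold": 0.1, "p": 0.8},
    # keep patch coordinates where at least 30% of the pixels are masked
    "coord_filter": {"name": "mask", "coverage": 0.3},
    # accept a patch after 10 consecutive rejections by the filter
    "patch_filter_patience": 10,
}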
@@ -141,8 +183,8 @@ class NGDataConfig(BaseModel):
     test_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch test dataloader parameters."""
 
-    seed: int | None = Field(
-    """Random seed for reproducibility."""
+    seed: int | None = Field(default_factory=generate_random_seed, gt=0)
+    """Random seed for reproducibility. If not specified, a random seed is generated."""
 
     @field_validator("axes")
     @classmethod
@@ -296,6 +338,129 @@ class NGDataConfig(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def propagate_seed_to_filters(self: Self) -> Self:
+        """
+        Propagate the main seed to patch and coordinate filters that support seeds.
+
+        This ensures that all filters use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            if self.patch_filter is not None:
+                if (
+                    hasattr(self.patch_filter, "seed")
+                    and self.patch_filter.seed is None
+                ):
+                    self.patch_filter.seed = self.seed
+
+            if self.coord_filter is not None:
+                if (
+                    hasattr(self.coord_filter, "seed")
+                    and self.coord_filter.seed is None
+                ):
+                    self.coord_filter.seed = self.seed
+
+        return self
+
+    @model_validator(mode="after")
+    def propagate_seed_to_transforms(self: Self) -> Self:
+        """
+        Propagate the main seed to all transforms that support seeds.
+
+        This ensures that all transforms use the same seed for reproducibility,
+        unless they already have a seed explicitly set.
+
+        Returns
+        -------
+        Self
+            Data model with propagated seeds.
+        """
+        if self.seed is not None:
+            for transform in self.transforms:
+                if hasattr(transform, "seed") and transform.seed is None:
+                    transform.seed = self.seed
+        return self
+
+    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
+    @classmethod
+    def set_default_pin_memory(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default pin_memory for dataloader parameters if not provided.
+
+        - If 'pin_memory' is not set, it defaults to True if CUDA is available.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
+        """
+        if "num_workers" not in dataloader_params:
+            # Use 0 workers during tests, otherwise use all available CPU cores
+            if "pytest" in sys.modules:
+                dataloader_params["num_workers"] = 0
+            else:
+                dataloader_params["num_workers"] = os.cpu_count()
+
+        return dataloader_params
+
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
+
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
+
     def __str__(self) -> str:
         """
         Pretty string reprensenting the configuration.
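The seed handling introduced in this file means a top-level seed is always present (generated by generate_random_seed when not supplied) and is then pushed down into every filter and transform that exposes a seed attribute still set to None. The propagation pattern in isolation, as an illustrative sketch (propagate_seed is not a careamics function):

import random


def propagate_seed(seed: int | None, components: list) -> None:
    # share the top-level seed with every seed-aware component
    # that has not been given one explicitly
    if seed is None:
        return
    for component in components:
        if hasattr(component, "seed") and component.seed is None:
            component.seed = seed


# e.g. seed = random.randint(1, 2**31 - 1) when the user did not provide one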
careamics/config/data/patch_filter/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""Pydantic models representing coordinate and patch filters."""
+
+__all__ = [
+    "FilterModel",
+    "MaskFilterModel",
+    "MaxFilterModel",
+    "MeanSTDFilterModel",
+    "ShannonFilterModel",
+]
+
+from .filter_model import FilterModel
+from .mask_filter_model import MaskFilterModel
+from .max_filter_model import MaxFilterModel
+from .meanstd_filter_model import MeanSTDFilterModel
+from .shannon_filter_model import ShannonFilterModel
careamics/config/data/patch_filter/filter_model.py
ADDED
@@ -0,0 +1,16 @@
+"""Base class for patch and coordinate filtering models."""
+
+from pydantic import BaseModel, Field
+
+
+class FilterModel(BaseModel):
+    """Base class for patch and coordinate filtering models."""
+
+    name: str
+    """Name of the filter."""
+
+    p: float = Field(1.0, ge=0.0, le=1.0)
+    """Probability of applying the filter to a patch or coordinate."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """Seed for the random number generator for reproducibility."""
careamics/config/data/patch_filter/mask_filter_model.py
ADDED
@@ -0,0 +1,17 @@
+"""Pydantic model for the mask coordinate filter."""
+
+from typing import Literal
+
+from pydantic import Field
+
+from .filter_model import FilterModel
+
+
+class MaskFilterModel(FilterModel):
+    """Pydantic model for the mask coordinate filter."""
+
+    name: Literal["mask"] = "mask"
+    """Name of the filter."""
+
+    coverage: float = Field(0.5, ge=0.0, le=1.0)
+    """Percentage of masked pixels required to keep a patch."""
careamics/config/data/patch_filter/max_filter_model.py
ADDED
@@ -0,0 +1,15 @@
+"""Pydantic model for the max patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MaxFilterModel(FilterModel):
+    """Pydantic model for the max patch filter."""
+
+    name: Literal["max"] = "max"
+    """Name of the filter."""
+
+    threshold: float
+    """Threshold for the minimum of the max-filtered patch."""
careamics/config/data/patch_filter/meanstd_filter_model.py
ADDED
@@ -0,0 +1,18 @@
+"""Pydantic model for the mean std patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class MeanSTDFilterModel(FilterModel):
+    """Pydantic model for the mean std patch filter."""
+
+    name: Literal["mean_std"] = "mean_std"
+    """Name of the filter."""
+
+    mean_threshold: float
+    """Minimum mean intensity required to keep a patch."""
+
+    std_threshold: float | None = None
+    """Minimum standard deviation required to keep a patch."""
careamics/config/data/patch_filter/shannon_filter_model.py
ADDED
@@ -0,0 +1,15 @@
+"""Pydantic model for the Shannon entropy patch filter."""
+
+from typing import Literal
+
+from .filter_model import FilterModel
+
+
+class ShannonFilterModel(FilterModel):
+    """Pydantic model for the Shannon entropy patch filter."""
+
+    name: Literal["shannon"] = "shannon"
+    """Name of the filter."""
+
+    threshold: float
+    """Minimum Shannon entropy required to keep a patch."""
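Each new filter model pairs a name discriminator with its filter-specific thresholds, plus the p and seed fields inherited from FilterModel. For instance, assuming the models are imported from the new careamics.config.data.patch_filter package (the threshold values below are illustrative):

from careamics.config.data.patch_filter import MaskFilterModel, MeanSTDFilterModel

# keep patches whose mean is above 0.05 and whose std is above 0.01,
# applying the filter to 90% of candidate patches
intensity_filter = MeanSTDFilterModel(mean_threshold=0.05, std_threshold=0.01, p=0.9)

# keep patch coordinates where at least half of the pixels are masked
mask_filter = MaskFilterModel(coverage=0.5)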
careamics/config/support/supported_filters.py
ADDED
@@ -0,0 +1,17 @@
+"""Coordinate and patch filters supported by CAREamics."""
+
+from careamics.utils import BaseEnum
+
+
+class SupportedPatchFilters(str, BaseEnum):
+    """Supported patch filters."""
+
+    MAX = "max"
+    MEANSTD = "mean_std"
+    SHANNON = "shannon"
+
+
+class SupportedCoordinateFilters(str, BaseEnum):
+    """Supported coordinate filters."""
+
+    MASK = "mask"