careamics 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the release's page on the registry for more details.
- careamics/careamist.py +24 -7
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +55 -4
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +41 -4
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/optimizer_models.py +1 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/training_model.py +0 -2
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/dataset.py +229 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +60 -53
- careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
- careamics/dataset_ng/factory.py +451 -0
- careamics/dataset_ng/legacy_interoperability.py +170 -0
- careamics/dataset_ng/patch_extractor/__init__.py +3 -8
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +7 -5
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +4 -1
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
- careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +114 -105
- careamics/dataset_ng/patching_strategies/__init__.py +6 -1
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +5 -1
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +172 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +678 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +212 -0
- careamics/lightning/lightning_module.py +5 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/__init__.py +8 -3
- careamics/lvae_training/dataset/config.py +3 -3
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +46 -17
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/types.py +3 -3
- careamics/lvae_training/dataset/utils/index_manager.py +259 -0
- careamics/lvae_training/eval_utils.py +93 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/compose.py +1 -0
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/normalize.py +18 -7
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +25 -11
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/METADATA +12 -11
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/RECORD +95 -69
- careamics/dataset_ng/dataset/__init__.py +0 -3
- careamics/dataset_ng/dataset/dataset.py +0 -184
- careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/WHEEL +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.11.dist-info → careamics-0.0.13.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""A class to train, predict and export models in CAREamics."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union, overload
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
from numpy.typing import NDArray
|
|
@@ -52,6 +53,9 @@ class CAREamist:
|
|
|
52
53
|
by default None.
|
|
53
54
|
callbacks : list of Callback, optional
|
|
54
55
|
List of callbacks to use during training and prediction, by default None.
|
|
56
|
+
enable_progress_bar : bool
|
|
57
|
+
Whether a progress bar will be displayed during training, validation and
|
|
58
|
+
prediction.
|
|
55
59
|
|
|
56
60
|
Attributes
|
|
57
61
|
----------
|
|
@@ -77,6 +81,7 @@ class CAREamist:
|
|
|
77
81
|
source: Union[Path, str],
|
|
78
82
|
work_dir: Optional[Union[Path, str]] = None,
|
|
79
83
|
callbacks: Optional[list[Callback]] = None,
|
|
84
|
+
enable_progress_bar: bool = True,
|
|
80
85
|
) -> None: ...
|
|
81
86
|
|
|
82
87
|
@overload
|
|
@@ -85,6 +90,7 @@ class CAREamist:
|
|
|
85
90
|
source: Configuration,
|
|
86
91
|
work_dir: Optional[Union[Path, str]] = None,
|
|
87
92
|
callbacks: Optional[list[Callback]] = None,
|
|
93
|
+
enable_progress_bar: bool = True,
|
|
88
94
|
) -> None: ...
|
|
89
95
|
|
|
90
96
|
def __init__(
|
|
@@ -92,6 +98,7 @@ class CAREamist:
|
|
|
92
98
|
source: Union[Path, str, Configuration],
|
|
93
99
|
work_dir: Optional[Union[Path, str]] = None,
|
|
94
100
|
callbacks: Optional[list[Callback]] = None,
|
|
101
|
+
enable_progress_bar: bool = True,
|
|
95
102
|
) -> None:
|
|
96
103
|
"""
|
|
97
104
|
Initialize CAREamist with a configuration object or a path.
|
|
@@ -112,6 +119,9 @@ class CAREamist:
|
|
|
112
119
|
by default None.
|
|
113
120
|
callbacks : list of Callback, optional
|
|
114
121
|
List of callbacks to use during training and prediction, by default None.
|
|
122
|
+
enable_progress_bar : bool
|
|
123
|
+
Whether a progress bar will be displayed during training, validation and
|
|
124
|
+
prediction.
|
|
115
125
|
|
|
116
126
|
Raises
|
|
117
127
|
------
|
|
@@ -169,7 +179,7 @@ class CAREamist:
|
|
|
169
179
|
self.model, self.cfg = load_pretrained(source)
|
|
170
180
|
|
|
171
181
|
# define the checkpoint saving callback
|
|
172
|
-
self._define_callbacks(callbacks)
|
|
182
|
+
self._define_callbacks(callbacks, enable_progress_bar)
|
|
173
183
|
|
|
174
184
|
# instantiate logger
|
|
175
185
|
csv_logger = CSVLogger(
|
|
@@ -202,7 +212,7 @@ class CAREamist:
|
|
|
202
212
|
precision=self.cfg.training_config.precision,
|
|
203
213
|
max_steps=self.cfg.training_config.max_steps,
|
|
204
214
|
check_val_every_n_epoch=self.cfg.training_config.check_val_every_n_epoch,
|
|
205
|
-
enable_progress_bar=
|
|
215
|
+
enable_progress_bar=enable_progress_bar,
|
|
206
216
|
accumulate_grad_batches=self.cfg.training_config.accumulate_grad_batches,
|
|
207
217
|
gradient_clip_val=self.cfg.training_config.gradient_clip_val,
|
|
208
218
|
gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
|
|
@@ -215,13 +225,19 @@ class CAREamist:
|
|
|
215
225
|
self.train_datamodule: Optional[TrainDataModule] = None
|
|
216
226
|
self.pred_datamodule: Optional[PredictDataModule] = None
|
|
217
227
|
|
|
218
|
-
def _define_callbacks(
|
|
228
|
+
def _define_callbacks(
|
|
229
|
+
self, callbacks: Optional[list[Callback]], enable_progress_bar: bool
|
|
230
|
+
) -> None:
|
|
219
231
|
"""Define the callbacks for the training loop.
|
|
220
232
|
|
|
221
233
|
Parameters
|
|
222
234
|
----------
|
|
223
235
|
callbacks : list of Callback, optional
|
|
224
236
|
List of callbacks to use during training and prediction, by default None.
|
|
237
|
+
enable_progress_bar : bool
|
|
238
|
+
Whether a progress bar will be displayed during training, validation and
|
|
239
|
+
prediction. It controls whether a `ProgressBarCallback` is added to the
|
|
240
|
+
callback list.
|
|
225
241
|
"""
|
|
226
242
|
self.callbacks = [] if callbacks is None else callbacks
|
|
227
243
|
|
|
@@ -251,9 +267,10 @@ class CAREamist:
|
|
|
251
267
|
filename=self.cfg.experiment_name,
|
|
252
268
|
**self.cfg.training_config.checkpoint_callback.model_dump(),
|
|
253
269
|
),
|
|
254
|
-
ProgressBarCallback(),
|
|
255
270
|
]
|
|
256
271
|
)
|
|
272
|
+
if enable_progress_bar:
|
|
273
|
+
self.callbacks.append(ProgressBarCallback())
|
|
257
274
|
|
|
258
275
|
# early stopping callback
|
|
259
276
|
if self.cfg.training_config.early_stopping_callback is not None:
|
|
@@ -811,7 +828,7 @@ class CAREamist:
|
|
|
811
828
|
source_path = source.pred_data
|
|
812
829
|
source_data_type = source.data_type
|
|
813
830
|
extension_filter = source.extension_filter
|
|
814
|
-
elif isinstance(source,
|
|
831
|
+
elif isinstance(source, str | Path):
|
|
815
832
|
source_path = source
|
|
816
833
|
source_data_type = data_type or self.cfg.data_config.data_type
|
|
817
834
|
extension_filter = SupportedData.get_extension_pattern(
|
|
@@ -824,7 +841,7 @@ class CAREamist:
|
|
|
824
841
|
raise ValueError(
|
|
825
842
|
"Predicting to disk is not supported for input type 'array'."
|
|
826
843
|
)
|
|
827
|
-
assert isinstance(source_path,
|
|
844
|
+
assert isinstance(source_path, str | Path) # because data_type != "array"
|
|
828
845
|
source_path = Path(source_path)
|
|
829
846
|
|
|
830
847
|
file_paths = list_files(source_path, source_data_type, extension_filter)
|
careamics/cli/utils.py
CHANGED
|
@@ -63,6 +63,9 @@ class UNetModel(ArchitectureModel):
|
|
|
63
63
|
"""Whether information is processed independently in each channel, used to train
|
|
64
64
|
channels independently."""
|
|
65
65
|
|
|
66
|
+
use_batch_norm: bool = Field(default=True, validate_default=True)
|
|
67
|
+
"""Whether to use batch normalization in the model."""
|
|
68
|
+
|
|
66
69
|
@field_validator("num_channels_init")
|
|
67
70
|
@classmethod
|
|
68
71
|
def validate_num_channels_init(cls, num_channels_init: int) -> int:
|
|
@@ -22,52 +22,42 @@ class CheckpointModel(BaseModel):
|
|
|
22
22
|
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
model_config = ConfigDict(
|
|
26
|
-
validate_assignment=True,
|
|
27
|
-
)
|
|
25
|
+
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
28
26
|
|
|
29
|
-
monitor: Literal["val_loss"] = Field(default="val_loss"
|
|
30
|
-
"""Quantity to monitor
|
|
27
|
+
monitor: Literal["val_loss"] = Field(default="val_loss")
|
|
28
|
+
"""Quantity to monitor, currently only `val_loss`."""
|
|
31
29
|
|
|
32
|
-
verbose: bool = Field(default=False
|
|
30
|
+
verbose: bool = Field(default=False)
|
|
33
31
|
"""Verbosity mode."""
|
|
34
32
|
|
|
35
|
-
save_weights_only: bool = Field(default=False
|
|
33
|
+
save_weights_only: bool = Field(default=False)
|
|
36
34
|
"""When `True`, only the model's weights will be saved (model.save_weights)."""
|
|
37
35
|
|
|
38
|
-
save_last: Optional[Literal[True, False, "link"]] = Field(
|
|
39
|
-
default=True, validate_default=True
|
|
40
|
-
)
|
|
36
|
+
save_last: Optional[Literal[True, False, "link"]] = Field(default=True)
|
|
41
37
|
"""When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
|
|
42
38
|
|
|
43
|
-
save_top_k: int = Field(default=3, ge
|
|
39
|
+
save_top_k: int = Field(default=3, ge=-1, le=100)
|
|
44
40
|
"""If `save_top_k == kz, the best k models according to the quantity monitored
|
|
45
41
|
will be saved. If `save_top_k == 0`, no models are saved. if `save_top_k == -1`,
|
|
46
42
|
all models are saved."""
|
|
47
43
|
|
|
48
|
-
mode: Literal["min", "max"] = Field(default="min"
|
|
44
|
+
mode: Literal["min", "max"] = Field(default="min")
|
|
49
45
|
"""One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
|
|
50
46
|
save file is made based on either the maximization or the minimization of the
|
|
51
47
|
monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
|
|
52
48
|
be 'min', etc.
|
|
53
49
|
"""
|
|
54
50
|
|
|
55
|
-
auto_insert_metric_name: bool = Field(default=False
|
|
51
|
+
auto_insert_metric_name: bool = Field(default=False)
|
|
56
52
|
"""When `True`, the checkpoints filenames will contain the metric name."""
|
|
57
53
|
|
|
58
|
-
every_n_train_steps: Optional[int] = Field(
|
|
59
|
-
default=None, ge=1, le=10, validate_default=True
|
|
60
|
-
)
|
|
54
|
+
every_n_train_steps: Optional[int] = Field(default=None, ge=1, le=1000)
|
|
61
55
|
"""Number of training steps between checkpoints."""
|
|
62
56
|
|
|
63
|
-
train_time_interval: Optional[timedelta] = Field(
|
|
64
|
-
default=None, validate_default=True
|
|
65
|
-
)
|
|
57
|
+
train_time_interval: Optional[timedelta] = Field(default=None)
|
|
66
58
|
"""Checkpoints are monitored at the specified time interval."""
|
|
67
59
|
|
|
68
|
-
every_n_epochs: Optional[int] = Field(
|
|
69
|
-
default=None, ge=1, le=10, validate_default=True
|
|
70
|
-
)
|
|
60
|
+
every_n_epochs: Optional[int] = Field(default=None, ge=1, le=100)
|
|
71
61
|
"""Number of epochs between checkpoints."""
|
|
72
62
|
|
|
73
63
|
|
|
@@ -83,41 +73,40 @@ class EarlyStoppingModel(BaseModel):
|
|
|
83
73
|
|
|
84
74
|
model_config = ConfigDict(
|
|
85
75
|
validate_assignment=True,
|
|
76
|
+
validate_default=True,
|
|
86
77
|
)
|
|
87
78
|
|
|
88
|
-
monitor: Literal["val_loss"] = Field(default="val_loss"
|
|
79
|
+
monitor: Literal["val_loss"] = Field(default="val_loss")
|
|
89
80
|
"""Quantity to monitor."""
|
|
90
81
|
|
|
91
|
-
min_delta: float = Field(default=0.0, ge=0.0, le=1.0
|
|
82
|
+
min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
|
|
92
83
|
"""Minimum change in the monitored quantity to qualify as an improvement, i.e. an
|
|
93
84
|
absolute change of less than or equal to min_delta, will count as no improvement."""
|
|
94
85
|
|
|
95
|
-
patience: int = Field(default=3, ge=1, le=10
|
|
86
|
+
patience: int = Field(default=3, ge=1, le=10)
|
|
96
87
|
"""Number of checks with no improvement after which training will be stopped."""
|
|
97
88
|
|
|
98
|
-
verbose: bool = Field(default=False
|
|
89
|
+
verbose: bool = Field(default=False)
|
|
99
90
|
"""Verbosity mode."""
|
|
100
91
|
|
|
101
|
-
mode: Literal["min", "max", "auto"] = Field(default="min"
|
|
92
|
+
mode: Literal["min", "max", "auto"] = Field(default="min")
|
|
102
93
|
"""One of {min, max, auto}."""
|
|
103
94
|
|
|
104
|
-
check_finite: bool = Field(default=True
|
|
95
|
+
check_finite: bool = Field(default=True)
|
|
105
96
|
"""When `True`, stops training when the monitored quantity becomes `NaN` or
|
|
106
97
|
`inf`."""
|
|
107
98
|
|
|
108
|
-
stopping_threshold: Optional[float] = Field(default=None
|
|
99
|
+
stopping_threshold: Optional[float] = Field(default=None)
|
|
109
100
|
"""Stop training immediately once the monitored quantity reaches this threshold."""
|
|
110
101
|
|
|
111
|
-
divergence_threshold: Optional[float] = Field(default=None
|
|
102
|
+
divergence_threshold: Optional[float] = Field(default=None)
|
|
112
103
|
"""Stop training as soon as the monitored quantity becomes worse than this
|
|
113
104
|
threshold."""
|
|
114
105
|
|
|
115
|
-
check_on_train_epoch_end: Optional[bool] = Field(
|
|
116
|
-
default=False, validate_default=True
|
|
117
|
-
)
|
|
106
|
+
check_on_train_epoch_end: Optional[bool] = Field(default=False)
|
|
118
107
|
"""Whether to run early stopping at the end of the training epoch. If this is
|
|
119
108
|
`False`, then the check runs at the end of the validation."""
|
|
120
109
|
|
|
121
|
-
log_rank_zero_only: bool = Field(default=False
|
|
110
|
+
log_rank_zero_only: bool = Field(default=False)
|
|
122
111
|
"""When set `True`, logs the status of the early stopping callback only for rank 0
|
|
123
112
|
process."""
|
|
@@ -3,11 +3,14 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import re
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from pprint import pformat
|
|
7
8
|
from typing import Any, Literal, Union
|
|
8
9
|
|
|
10
|
+
import numpy as np
|
|
9
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
10
12
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
13
|
+
from pydantic.main import IncEx
|
|
11
14
|
from typing_extensions import Self
|
|
12
15
|
|
|
13
16
|
from careamics.config.algorithms import (
|
|
@@ -182,6 +185,50 @@ class Configuration(BaseModel):
|
|
|
182
185
|
|
|
183
186
|
return name
|
|
184
187
|
|
|
188
|
+
@model_validator(mode="after")
|
|
189
|
+
def validate_n2v_mask_pixel_perc(self: Self) -> Self:
|
|
190
|
+
"""
|
|
191
|
+
Validate that there will always be at least one blind-spot pixel in every patch.
|
|
192
|
+
|
|
193
|
+
The probability of creating a blind-spot pixel is a function of the chosen
|
|
194
|
+
masked pixel percentage and patch size.
|
|
195
|
+
|
|
196
|
+
Returns
|
|
197
|
+
-------
|
|
198
|
+
Self
|
|
199
|
+
Validated configuration.
|
|
200
|
+
|
|
201
|
+
Raises
|
|
202
|
+
------
|
|
203
|
+
ValueError
|
|
204
|
+
If the probability of masking a pixel within a patch is less than 1 for the
|
|
205
|
+
chosen masked pixel percentage and patch size.
|
|
206
|
+
"""
|
|
207
|
+
# No validation needed for non n2v algorithms
|
|
208
|
+
if not isinstance(self.algorithm_config, N2VAlgorithm):
|
|
209
|
+
return self
|
|
210
|
+
|
|
211
|
+
mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
|
|
212
|
+
patch_size = self.data_config.patch_size
|
|
213
|
+
expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
|
|
214
|
+
|
|
215
|
+
n_dims = 3 if self.algorithm_config.model.is_3D() else 2
|
|
216
|
+
patch_size_lower_bound = int(np.ceil(expected_area_per_pixel ** (1 / n_dims)))
|
|
217
|
+
required_patch_size = tuple(
|
|
218
|
+
2 ** int(np.ceil(np.log2(patch_size_lower_bound))) for _ in range(n_dims)
|
|
219
|
+
)
|
|
220
|
+
required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
|
|
221
|
+
if expected_area_per_pixel > np.prod(patch_size):
|
|
222
|
+
raise ValueError(
|
|
223
|
+
"The probability of creating a blind-spot pixel within a patch is "
|
|
224
|
+
f"below 1, for a patch size of {patch_size} with a masked pixel "
|
|
225
|
+
f"percentage of {mask_pixel_perc}%. Either increase the patch size to "
|
|
226
|
+
f"{required_patch_size} or increase the masked pixel percentage to "
|
|
227
|
+
f"at least {required_mask_pixel_perc}%."
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
return self
|
|
231
|
+
|
|
185
232
|
@model_validator(mode="after")
|
|
186
233
|
def validate_3D(self: Self) -> Self:
|
|
187
234
|
"""
|
|
@@ -297,17 +344,18 @@ class Configuration(BaseModel):
|
|
|
297
344
|
self,
|
|
298
345
|
*,
|
|
299
346
|
mode: Literal["json", "python"] | str = "python",
|
|
300
|
-
include:
|
|
301
|
-
exclude:
|
|
347
|
+
include: IncEx | None = None,
|
|
348
|
+
exclude: IncEx | None = None,
|
|
302
349
|
context: Any | None = None,
|
|
303
|
-
by_alias: bool = False,
|
|
350
|
+
by_alias: bool | None = False,
|
|
304
351
|
exclude_unset: bool = False,
|
|
305
352
|
exclude_defaults: bool = False,
|
|
306
353
|
exclude_none: bool = True,
|
|
307
354
|
round_trip: bool = False,
|
|
308
355
|
warnings: bool | Literal["none", "warn", "error"] = True,
|
|
356
|
+
fallback: Callable[[Any], Any] | None = None,
|
|
309
357
|
serialize_as_any: bool = False,
|
|
310
|
-
) -> dict:
|
|
358
|
+
) -> dict[str, Any]:
|
|
311
359
|
"""
|
|
312
360
|
Override model_dump method in order to set default values.
|
|
313
361
|
|
|
@@ -337,6 +385,8 @@ class Configuration(BaseModel):
|
|
|
337
385
|
representation.
|
|
338
386
|
warnings : bool | Literal['none', 'warn', 'error'], default=True
|
|
339
387
|
Whether to emit warnings.
|
|
388
|
+
fallback : Callable[[Any], Any] | None, default=None
|
|
389
|
+
A function to call when an unknown value is encountered.
|
|
340
390
|
serialize_as_any : bool, default=False
|
|
341
391
|
Whether to serialize all types as Any.
|
|
342
392
|
|
|
@@ -356,6 +406,7 @@ class Configuration(BaseModel):
|
|
|
356
406
|
exclude_none=exclude_none,
|
|
357
407
|
round_trip=round_trip,
|
|
358
408
|
warnings=warnings,
|
|
409
|
+
fallback=fallback,
|
|
359
410
|
serialize_as_any=serialize_as_any,
|
|
360
411
|
)
|
|
361
412
|
|