careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/bioimage/__init__.py
DELETED
@@ -1,15 +0,0 @@
-"""Provide utilities for exporting models to BioImage model zoo."""
-
-__all__ = [
-    "save_bioimage_model",
-    "import_bioimage_model",
-    "get_default_model_specs",
-    "PYTORCH_STATE_DICT",
-]
-
-from .io import (
-    PYTORCH_STATE_DICT,
-    import_bioimage_model,
-    save_bioimage_model,
-)
-from .rdf import get_default_model_specs
careamics/bioimage/docs/Noise2Void.md
DELETED
@@ -1,5 +0,0 @@
-## Noise2Void
-Learning Denoising From Single Noisy Images
-
-## Cite Noise2Void
-A. Krull, T.-O. Buchholz and F. Jug, "Noise2Void - Learning Denoising From Single Noisy Images," 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 2124-2132
careamics/bioimage/docs/__init__.py
DELETED
@@ -1 +0,0 @@
-"""Default algorithm READMEs for bioimage.io format export."""
careamics/bioimage/io.py
DELETED
@@ -1,182 +0,0 @@
-"""Export to bioimage.io format."""
-from pathlib import Path
-from typing import Union
-
-import torch
-from bioimageio.core import load_resource_description
-from bioimageio.core.build_spec import build_model
-
-from careamics.config.config import Configuration
-from careamics.utils.context import cwd
-
-PYTORCH_STATE_DICT = "pytorch_state_dict"
-
-
-def save_bioimage_model(
-    path: Union[str, Path],
-    config: Configuration,
-    specs: dict,
-) -> None:
-    """
-    Build bioimage model zip file from model RDF data.
-
-    Parameters
-    ----------
-    path : Union[str, Path]
-        Path to the model zip file.
-    config : Configuration
-        Configuration object.
-    specs : dict
-        Model RDF dict.
-    """
-    workdir = config.working_directory
-
-    # temporary folder
-    temp_folder = Path.home().joinpath(".careamics", "bmz_tmp")
-    temp_folder.mkdir(exist_ok=True, parents=True)
-
-    # change working directory to the temp folder
-    with cwd(temp_folder):
-        # load best checkpoint
-        checkpoint_path = workdir.joinpath(
-            f"{config.experiment_name}_best.pth"
-        ).absolute()
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-
-        # save checkpoint entries in separate files
-        weight_path = Path("model_weights.pth")
-        torch.save(checkpoint["model_state_dict"], weight_path)
-
-        optim_path = Path("optim.pth")
-        torch.save(checkpoint["optimizer_state_dict"], optim_path)
-
-        scheduler_path = Path("scheduler.pth")
-        torch.save(checkpoint["scheduler_state_dict"], scheduler_path)
-
-        grad_path = Path("grad.pth")
-        torch.save(checkpoint["grad_scaler_state_dict"], grad_path)
-
-        config_path = Path("config.pth")
-        torch.save(config.model_dump(), config_path)
-
-        # create attachments
-        attachments = [
-            str(optim_path),
-            str(scheduler_path),
-            str(grad_path),
-            str(config_path),
-        ]
-
-        # create requirements file
-        requirements = Path("requirements.txt")
-        with open(requirements, "w") as f:
-            f.write("git+https://github.com/CAREamics/careamics.git")
-
-        algo_config = config.algorithm
-        specs.update(
-            {
-                "weight_type": PYTORCH_STATE_DICT,
-                "weight_uri": str(weight_path),
-                "architecture": "careamics.models.unet.UNet",
-                "pytorch_version": torch.__version__,
-                "model_kwargs": {
-                    "conv_dim": algo_config.get_conv_dim(),
-                    "depth": algo_config.model_parameters.depth,
-                    "num_channels_init": algo_config.model_parameters.num_channels_init,
-                },
-                "dependencies": "pip:" + str(requirements),
-                "attachments": {"files": attachments},
-            }
-        )
-
-        if config.algorithm.is_3D:
-            specs["tags"].append("3D")
-        else:
-            specs["tags"].append("2D")
-
-        # build model zip
-        build_model(
-            output_path=Path(path).absolute(),
-            **specs,
-        )
-
-    # remove temporary files
-    for file in temp_folder.glob("*"):
-        file.unlink()
-
-    # delete temporary folder
-    temp_folder.rmdir()
-
-
-def import_bioimage_model(model_path: Union[str, Path]) -> Path:
-    """
-    Load configuration and weights from a bioimage zip model.
-
-    Parameters
-    ----------
-    model_path : Union[str, Path]
-        Path to the bioimage.io archive.
-
-    Returns
-    -------
-    Path
-        Path to the checkpoint.
-
-    Raises
-    ------
-    ValueError
-        If the model format is invalid.
-    FileNotFoundError
-        If the checkpoint file was not found.
-    """
-    model_path = Path(model_path)
-
-    # check the model extension (should be a zip file).
-    if model_path.suffix != ".zip":
-        raise ValueError("Invalid model format. Expected bioimage model zip file.")
-
-    # load the model
-    rdf = load_resource_description(model_path)
-
-    # create a valid checkpoint file from weights and attached files
-    basedir = model_path.parent.joinpath("rdf_model")
-    basedir.mkdir(exist_ok=True)
-    optim_path = None
-    scheduler_path = None
-    grad_path = None
-    config_path = None
-    weight_path = None
-
-    if rdf.weights.get(PYTORCH_STATE_DICT) is not None:
-        weight_path = rdf.weights.get(PYTORCH_STATE_DICT).source
-
-    for file in rdf.attachments.files:
-        if file.name.endswith("optim.pth"):
-            optim_path = file
-        elif file.name.endswith("scheduler.pth"):
-            scheduler_path = file
-        elif file.name.endswith("grad.pth"):
-            grad_path = file
-        elif file.name.endswith("config.pth"):
-            config_path = file
-
-    if (
-        weight_path is None
-        or optim_path is None
-        or scheduler_path is None
-        or grad_path is None
-        or config_path is None
-    ):
-        raise FileNotFoundError(f"No valid checkpoint file was found in {model_path}.")
-
-    checkpoint = {
-        "model_state_dict": torch.load(weight_path, map_location="cpu"),
-        "optimizer_state_dict": torch.load(optim_path, map_location="cpu"),
-        "scheduler_state_dict": torch.load(scheduler_path, map_location="cpu"),
-        "grad_scaler_state_dict": torch.load(grad_path, map_location="cpu"),
-        "config": torch.load(config_path, map_location="cpu"),
-    }
-    checkpoint_path = basedir.joinpath("checkpoint.pth")
-    torch.save(checkpoint, checkpoint_path)
-
-    return checkpoint_path
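Note on the removed import path: the sketch below is not part of the package diff; it only illustrates, from the deleted signatures above, how `import_bioimage_model` was called in rc2. The archive name is a placeholder. According to the file list, this functionality appears to move to careamics/model_io/bmz_io.py in rc3.

```python
# Sketch only (rc2 API removed in rc3): unpack a bioimage.io zip back into a
# single training checkpoint, then inspect its contents.
import torch

from careamics.bioimage import import_bioimage_model  # module deleted in rc3

checkpoint_path = import_bioimage_model("model.zip")  # placeholder path
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Keys mirror the dictionary assembled in the deleted io.py above.
print(sorted(checkpoint))
# ['config', 'grad_scaler_state_dict', 'model_state_dict',
#  'optimizer_state_dict', 'scheduler_state_dict']
```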
careamics/bioimage/rdf.py
DELETED
@@ -1,105 +0,0 @@
-"""RDF related methods."""
-from pathlib import Path
-
-
-def _get_model_doc(name: str) -> str:
-    """
-    Return markdown documentation path for the provided model.
-
-    Parameters
-    ----------
-    name : str
-        Model's name.
-
-    Returns
-    -------
-    str
-        Path to the model's markdown documentation.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the documentation file was not found.
-    """
-    doc = Path(__file__).parent.joinpath("docs").joinpath(f"{name}.md")
-    if doc.exists():
-        return str(doc.absolute())
-    else:
-        raise FileNotFoundError(f"Documentation for {name} was not found.")
-
-
-def get_default_model_specs(
-    name: str, mean: float, std: float, is_3D: bool = False
-) -> dict:
-    """
-    Return the default bioimage.io specs for the provided model's name.
-
-    Currently only supports `Noise2Void` model.
-
-    Parameters
-    ----------
-    name : str
-        Algorithm's name.
-    mean : float
-        Mean of the dataset.
-    std : float
-        Std of the dataset.
-    is_3D : bool, optional
-        Whether the model is 3D or not, by default False.
-
-    Returns
-    -------
-    dict
-        Model specs compatible with bioimage.io export.
-    """
-    rdf = {
-        "name": "Noise2Void",
-        "description": "Self-supervised denoising.",
-        "license": "BSD-3-Clause",
-        "authors": [
-            {"name": "Alexander Krull"},
-            {"name": "Tim-Oliver Buchholz"},
-            {"name": "Florian Jug"},
-        ],
-        "cite": [
-            {
-                "doi": "10.48550/arXiv.1811.10980",
-                "text": 'A. Krull, T.-O. Buchholz and F. Jug, "Noise2Void - Learning '
-                'Denoising From Single Noisy Images," 2019 IEEE/CVF '
-                "Conference on Computer Vision and Pattern Recognition "
-                "(CVPR), 2019, pp. 2124-2132",
-            }
-        ],
-        # "input_axes": ["bcyx"], <- overridden in save_as_bioimage
-        "preprocessing": [  # for multiple inputs
-            [  # multiple processes per input
-                {
-                    "kwargs": {
-                        "axes": "zyx" if is_3D else "yx",
-                        "mean": [mean],
-                        "mode": "fixed",
-                        "std": [std],
-                    },
-                    "name": "zero_mean_unit_variance",
-                }
-            ]
-        ],
-        # "output_axes": ["bcyx"], <- overridden in save_as_bioimage
-        "postprocessing": [  # for multiple outputs
-            [  # multiple processes per input
-                {
-                    "kwargs": {
-                        "axes": "zyx" if is_3D else "yx",
-                        "gain": [std],
-                        "offset": [mean],
-                    },
-                    "name": "scale_linear",
-                }
-            ]
-        ],
-        "tags": ["unet", "denoising", "Noise2Void", "tensorflow", "napari"],
-    }
-
-    rdf["documentation"] = _get_model_doc(name)
-
-    return rdf
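For context, a minimal sketch of how the two deleted modules fit together as the rc2 export path. It is reconstructed only from the signatures above; `config` stands for an rc2 `Configuration` from a finished run (its working directory must contain `<experiment_name>_best.pth`), and the statistics and output path are placeholders.

```python
# Sketch only (rc2 API removed in rc3): build a bioimage.io model zip.
from careamics.bioimage import get_default_model_specs, save_bioimage_model  # deleted in rc3

config = ...  # rc2 Configuration from a finished training run (not shown here)

# Default Noise2Void RDF specs, parameterized by the dataset statistics.
specs = get_default_model_specs("Noise2Void", mean=128.0, std=32.0, is_3D=False)

# Packs weights, optimizer/scheduler/grad-scaler states and the configuration
# into a single bioimage.io archive at the given path.
save_bioimage_model("n2v_model.zip", config, specs)
```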
careamics/config/algorithm.py
DELETED
@@ -1,231 +0,0 @@
-"""Algorithm configuration."""
-from enum import Enum
-from typing import Dict, List
-
-from pydantic import BaseModel, ConfigDict, Field, field_validator
-
-from .config_filter import remove_default_optionals
-
-
-# python 3.11: https://docs.python.org/3/library/enum.html
-class Loss(str, Enum):
-    """
-    Available loss functions.
-
-    Currently supported losses:
-
-    - n2v: Noise2Void loss.
-    """
-
-    N2V = "n2v"
-
-
-class Models(str, Enum):
-    """
-    Available models.
-
-    Currently supported models:
-    - UNet: U-Net model.
-    """
-
-    UNET = "UNet"
-
-
-class MaskingStrategy(str, Enum):
-    """
-    Available masking strategy.
-
-    Currently supported strategies:
-
-    - default: default masking strategy of Noise2Void (uniform sampling of neighbors).
-    - median: median masking strategy of N2V2.
-    """
-
-    DEFAULT = "default"
-    MEDIAN = "median"
-
-
-class ModelParameters(BaseModel):
-    """
-    Deep-learning model parameters.
-
-    The number of filters (base) must be even and minimum 8.
-
-    Attributes
-    ----------
-    depth : int
-        Depth of the model, between 1 and 10 (default 2).
-    num_channels_init : int
-        Number of filters of the first level of the network, should be even
-        and minimum 8 (default 96).
-    """
-
-    model_config = ConfigDict(validate_assignment=True)
-
-    depth: int = Field(default=2, ge=1, le=10)
-    num_channels_init: int = Field(default=32, ge=8)
-
-    # TODO revisit the constraints on num_channels_init
-    @field_validator("num_channels_init")
-    def even(cls, num_channels: int) -> int:
-        """
-        Validate that num_channels_init is even.
-
-        Parameters
-        ----------
-        num_channels : int
-            Number of channels.
-
-        Returns
-        -------
-        int
-            Validated number of channels.
-
-        Raises
-        ------
-        ValueError
-            If the number of channels is odd.
-        """
-        # if odd
-        if num_channels % 2 != 0:
-            raise ValueError(
-                f"Number of channels (init) must be even (got {num_channels})."
-            )
-
-        return num_channels
-
-
-class Algorithm(BaseModel):
-    """
-    Algorithm configuration.
-
-    The minimum algorithm configuration is composed of the following fields:
-    - loss:
-        Loss to use, currently only supports n2v.
-    - model:
-        Model to use, currently only supports UNet.
-    - is_3D:
-        Whether to use a 3D model or not, this should be coherent with the
-        data configuration (axes).
-
-    Other optional fields are:
-    - masking_strategy:
-        Masking strategy to use, currently only supports default masking.
-    - masked_pixel_percentage:
-        Percentage of pixels to be masked in each patch.
-    - roi_size:
-        Size of the region of interest to use in the masking algorithm.
-    - model_parameters:
-        Model parameters, see ModelParameters for more details.
-
-    Attributes
-    ----------
-    loss : List[Losses]
-        List of losses to use, currently only supports n2v.
-    model : Models
-        Model to use, currently only supports UNet.
-    is_3D : bool
-        Whether to use a 3D model or not.
-    masking_strategy : MaskingStrategies
-        Masking strategy to use, currently only supports default masking.
-    masked_pixel_percentage : float
-        Percentage of pixels to be masked in each patch.
-    roi_size : int
-        Size of the region of interest used in the masking scheme.
-    model_parameters : ModelParameters
-        Model parameters, see ModelParameters for more details.
-    """
-
-    # Pydantic class configuration
-    model_config = ConfigDict(
-        use_enum_values=True,
-        protected_namespaces=(),  # allows to use model_* as a field name
-        validate_assignment=True,
-    )
-
-    # Mandatory fields
-    loss: Loss
-    model: Models
-    is_3D: bool
-
-    # Optional fields, define a default value
-    masking_strategy: MaskingStrategy = MaskingStrategy.DEFAULT
-    masked_pixel_percentage: float = Field(default=0.2, ge=0.1, le=20)
-    roi_size: int = Field(default=11, ge=3, le=21)
-    model_parameters: ModelParameters = ModelParameters()
-
-    def get_conv_dim(self) -> int:
-        """
-        Get the convolution layers dimension (2D or 3D).
-
-        Returns
-        -------
-        int
-            Dimension (2 or 3).
-        """
-        return 3 if self.is_3D else 2
-
-    @field_validator("roi_size")
-    def even(cls, roi_size: int) -> int:
-        """
-        Validate that roi_size is odd.
-
-        Parameters
-        ----------
-        roi_size : int
-            Size of the region of interest in the masking scheme.
-
-        Returns
-        -------
-        int
-            Validated size of the region of interest.
-
-        Raises
-        ------
-        ValueError
-            If the size of the region of interest is even.
-        """
-        # if even
-        if roi_size % 2 == 0:
-            raise ValueError(f"ROI size must be odd (got {roi_size}).")
-
-        return roi_size
-
-    def model_dump(
-        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
-    ) -> Dict:
-        """
-        Override model_dump method.
-
-        The purpose is to ensure export smooth import to yaml. It includes:
-        - remove entries with None value.
-        - remove optional values if they have the default value.
-
-        Parameters
-        ----------
-        exclude_optionals : bool, optional
-            Whether to exclude optional arguments if they are default, by default True.
-        *args : List
-            Positional arguments, unused.
-        **kwargs : Dict
-            Keyword arguments, unused.
-
-        Returns
-        -------
-        Dict
-            Dictionary representation of the model.
-        """
-        dictionary = super().model_dump(exclude_none=True)
-
-        if exclude_optionals is True:
-            # remove optional arguments if they are default
-            defaults = {
-                "masking_strategy": MaskingStrategy.DEFAULT.value,
-                "masked_pixel_percentage": 0.2,
-                "roi_size": 11,
-                "model_parameters": ModelParameters().model_dump(exclude_none=True),
-            }
-
-            remove_default_optionals(dictionary, defaults)
-
-        return dictionary
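The rc2 `Algorithm` model above is superseded in rc3 by careamics/config/algorithm_model.py (see the file list). For readers comparing the two, here is a minimal sketch of the removed rc2 behaviour, based solely on the fields and the overridden `model_dump` shown above; the chosen values are illustrative.

```python
# Sketch only (rc2 API removed in rc3): instantiate the old Algorithm model.
from careamics.config.algorithm import Algorithm, ModelParameters  # rc2-only import path

algo = Algorithm(
    loss="n2v",    # Loss.N2V
    model="UNet",  # Models.UNET
    is_3D=False,
    model_parameters=ModelParameters(depth=3, num_channels_init=64),
)

# The overridden model_dump drops optional fields left at their defaults,
# keeping YAML exports of the configuration short.
print(algo.model_dump())
# e.g. {'loss': 'n2v', 'model': 'UNet', 'is_3D': False,
#       'model_parameters': {'depth': 3, 'num_channels_init': 64}}
```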