careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
"""Convenience functions to create configurations for training and inference."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from albumentations import Compose
|
|
6
|
+
|
|
7
|
+
from .algorithm_model import AlgorithmConfig
|
|
8
|
+
from .architectures import UNetModel
|
|
9
|
+
from .configuration_model import Configuration
|
|
10
|
+
from .data_model import DataConfig
|
|
11
|
+
from .inference_model import InferenceConfig
|
|
12
|
+
from .support import (
|
|
13
|
+
SupportedAlgorithm,
|
|
14
|
+
SupportedArchitecture,
|
|
15
|
+
SupportedLoss,
|
|
16
|
+
SupportedPixelManipulation,
|
|
17
|
+
SupportedTransform,
|
|
18
|
+
)
|
|
19
|
+
from .training_model import TrainingConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _create_supervised_configuration(
    algorithm: Literal["care", "n2n"],
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels: int = 1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training CARE or Noise2Noise.

    Parameters
    ----------
    algorithm : Literal["care", "n2n"]
        Algorithm to use.
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True. When False, only
        normalization is applied.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None. The dictionary is copied, so the
        caller's object is never modified. The entries `conv_dims`,
        `in_channels` and `num_classes` are always overridden from `axes` and
        `n_channels`.

    Returns
    -------
    Configuration
        Configuration for training CARE or Noise2Noise.

    Raises
    ------
    ValueError
        If `n_channels` is smaller than 1, if "C" is in `axes` but `n_channels`
        is 1, or if "C" is not in `axes` but `n_channels` is larger than 1.
    """
    # a non-positive channel count can never describe real data
    if n_channels < 1:
        raise ValueError(
            f"Number of channels must be at least 1 (got {n_channels} channel)."
        )

    # if there are channels, we need to specify their number
    if "C" in axes and n_channels == 1:
        raise ValueError(
            "Number of channels must be specified when using channels "
            f"(got {n_channels} channel)."
        )
    elif "C" not in axes and n_channels > 1:
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )

    # model: work on a copy so the overrides below do not leak into the
    # dictionary the caller passed in
    unet_parameters: Dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
    unet_parameters["conv_dims"] = 3 if "Z" in axes else 2
    unet_parameters["in_channels"] = n_channels
    unet_parameters["num_classes"] = n_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **unet_parameters,
    )

    # algorithm model (distinct name to avoid shadowing the `algorithm` argument)
    algorithm_config = AlgorithmConfig(
        algorithm=algorithm,
        loss=loss,
        model=unet_model,
    )

    # augmentations: normalization is always applied first
    transforms: List[Dict[str, Any]] = [
        {
            "name": SupportedTransform.NORMALIZE.value,
        },
    ]
    if use_augmentations:
        transforms.extend(
            [
                {
                    "name": SupportedTransform.NDFLIP.value,
                },
                {
                    "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
                },
            ]
        )

    # data model
    data = DataConfig(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is the user-facing spelling of "no logger")
    training = TrainingConfig(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data,
        training_config=training,
    )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def create_care_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels: int = 1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None. The dictionary is not modified;
        `conv_dims`, `in_channels` and `num_classes` are derived from `axes`
        and `n_channels` and override any value given here.

    Returns
    -------
    Configuration
        Configuration for training CARE.
    """
    return _create_supervised_configuration(
        algorithm="care",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        use_augmentations=use_augmentations,
        loss=loss,
        # TODO in the future we might support different in and out channels for CARE
        n_channels=n_channels,
        logger=logger,
        # pass a copy so the caller's dictionary is never mutated downstream
        model_kwargs=None if model_kwargs is None else dict(model_kwargs),
    )
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def create_n2n_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels: int = 1,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None. The dictionary is not modified;
        `conv_dims`, `in_channels` and `num_classes` are derived from `axes`
        and `n_channels` and override any value given here.

    Returns
    -------
    Configuration
        Configuration for training Noise2Noise.
    """
    return _create_supervised_configuration(
        algorithm="n2n",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        use_augmentations=use_augmentations,
        loss=loss,
        n_channels=n_channels,
        logger=logger,
        # pass a copy so the caller's dictionary is never mutated downstream
        model_kwargs=None if model_kwargs is None else dict(model_kwargs),
    )
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def create_n2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: List[int],
    batch_size: int,
    num_epochs: int,
    use_augmentations: bool = True,
    use_n2v2: bool = False,
    n_channels: int = 1,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_kwargs: Optional[dict] = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated. With `use_n2v2` the manipulation uses the
    "median" pixel-replacement strategy, otherwise the "uniform" one.

    The parameters of the UNet can be specified in the `model_kwargs` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `n_channels` and the presence
    of "Z" in `axes` override the corresponding parameters passed in `model_kwargs`.
    The dictionary itself is not modified.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int
        Number of epochs.
    use_augmentations : bool, optional
        Whether to use augmentations, by default True.
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
    n_channels : int, optional
        Number of channels (in and out), by default 1.
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_kwargs : dict, optional
        UNetModel parameters, by default None.

    Returns
    -------
    Configuration
        Configuration for training N2V.

    Raises
    ------
    ValueError
        If `n_channels` is smaller than 1, if "C" is in `axes` but `n_channels`
        is 1, or if "C" is not in `axes` but `n_channels` is larger than 1.

    Examples
    --------
    Minimum example:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    To use N2V2, simply pass the `use_n2v2` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v2_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_n2v2=True
    ... )

    For structN2V, there are two parameters to set, `struct_n2v_axis` and
    `struct_n2v_span`:
    >>> config = create_n2v_configuration(
    ...     experiment_name="structn2v_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     struct_n2v_axis="horizontal",
    ...     struct_n2v_span=7
    ... )

    If you are training multiple channels together, then you need to specify the number
    of channels:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels=3
    ... )

    To turn off the augmentations, except normalization and N2V manipulation, use the
    relevant keyword argument:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_augmentations=False
    ... )
    """
    # a non-positive channel count can never describe real data
    if n_channels < 1:
        raise ValueError(
            f"Number of channels must be at least 1 (got {n_channels} channel)."
        )

    # if there are channels, we need to specify their number
    if "C" in axes and n_channels == 1:
        raise ValueError(
            "Number of channels must be specified when using channels "
            f"(got {n_channels} channel)."
        )
    elif "C" not in axes and n_channels > 1:
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )

    # model: work on a copy so the overrides below do not leak into the
    # dictionary the caller passed in
    unet_parameters: Dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
    unet_parameters["n2v2"] = use_n2v2
    unet_parameters["conv_dims"] = 3 if "Z" in axes else 2
    unet_parameters["in_channels"] = n_channels
    unet_parameters["num_classes"] = n_channels

    unet_model = UNetModel(
        architecture=SupportedArchitecture.UNET.value,
        **unet_parameters,
    )

    # algorithm model
    algorithm_config = AlgorithmConfig(
        algorithm=SupportedAlgorithm.N2V.value,
        loss=SupportedLoss.N2V.value,
        model=unet_model,
    )

    # augmentations: normalization is always applied first
    transforms: List[Dict[str, Any]] = [
        {
            "name": SupportedTransform.NORMALIZE.value,
        },
    ]
    if use_augmentations:
        transforms.extend(
            [
                {
                    "name": SupportedTransform.NDFLIP.value,
                },
                {
                    "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
                },
            ]
        )

    # N2V pixel manipulation (covers N2V2 and structN2V); always applied last
    n2v_manipulation: Dict[str, Any] = {
        "name": SupportedTransform.N2V_MANIPULATE.value,
        "strategy": (
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        "roi_size": roi_size,
        "masked_pixel_percentage": masked_pixel_percentage,
        "struct_mask_axis": struct_n2v_axis,
        "struct_mask_span": struct_n2v_span,
    }
    transforms.append(n2v_manipulation)

    # data model
    data = DataConfig(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        transforms=transforms,
    )

    # training model ("none" is the user-facing spelling of "no logger")
    training = TrainingConfig(
        num_epochs=num_epochs,
        batch_size=batch_size,
        logger=None if logger == "none" else logger,
    )

    # create configuration
    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data,
        training_config=training,
    )
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
# TODO add tests
|
|
534
|
+
def create_inference_configuration(
    training_configuration: Configuration,
    tile_size: Optional[Tuple[int, ...]] = None,
    tile_overlap: Optional[Tuple[int, ...]] = None,
    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
    axes: Optional[str] = None,
    transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
    tta_transforms: bool = True,
    batch_size: Optional[int] = 1,
) -> InferenceConfig:
    """
    Create a configuration for inference from a training configuration.

    The normalization statistics (`mean` and `std`) are always taken from the
    training configuration. If not provided, `data_type` and `axes` are also taken
    from the training configuration. If `transforms` are not provided, only
    normalization is applied.

    Parameters
    ----------
    training_configuration : Configuration
        Configuration used for training.
    tile_size : Tuple[int, ...], optional
        Size of the tiles, by default None.
    tile_overlap : Tuple[int, ...], optional
        Overlap of the tiles, by default None.
    data_type : Literal["array", "tiff", "custom"], optional
        Type of the data, by default None (the training configuration's
        `data_type` is used).
    axes : str, optional
        Axes of the data, by default None (the training configuration's `axes`
        are used).
    transforms : List[Dict[str, Any]] or Compose, optional
        Transformations to apply to the data, by default None (normalization
        only).
    tta_transforms : bool, optional
        Whether to apply test-time augmentations, by default True.
    batch_size : int, optional
        Batch size, by default 1.

    Returns
    -------
    InferenceConfig
        Configuration for inference.

    Raises
    ------
    ValueError
        If the training configuration does not define both `mean` and `std`.
    """
    training_data = training_configuration.data_config

    # inference must reuse the training normalization statistics
    if training_data.mean is None or training_data.std is None:
        raise ValueError("Mean and std must be provided in the training configuration.")

    if transforms is None:
        transforms = [
            {
                "name": SupportedTransform.NORMALIZE.value,
            },
        ]

    return InferenceConfig(
        data_type=data_type or training_data.data_type,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes or training_data.axes,
        mean=training_data.mean,
        std=training_data.std,
        transforms=transforms,
        tta_transforms=tta_transforms,
        batch_size=batch_size,
    )
|