careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +8 -6
- careamics/careamist.py +30 -29
- careamics/config/__init__.py +12 -9
- careamics/config/algorithm_model.py +5 -5
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/callback_model.py +1 -0
- careamics/config/configuration_example.py +87 -0
- careamics/config/configuration_factory.py +285 -78
- careamics/config/configuration_model.py +22 -23
- careamics/config/data_model.py +62 -160
- careamics/config/inference_model.py +20 -21
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/supported_extraction_strategies.py +1 -0
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +2 -1
- careamics/config/transformations/n2v_manipulate_model.py +2 -1
- careamics/config/transformations/nd_flip_model.py +7 -12
- careamics/config/transformations/normalize_model.py +2 -1
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_random_rotate90_model.py +7 -9
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +1 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +1 -0
- careamics/dataset/in_memory_dataset.py +17 -48
- careamics/dataset/iterable_dataset.py +16 -71
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +1 -0
- careamics/dataset/patching/sequential_patching.py +6 -6
- careamics/dataset/patching/tiled_patching.py +10 -6
- careamics/lightning_datamodule.py +123 -49
- careamics/lightning_module.py +7 -7
- careamics/lightning_prediction_datamodule.py +59 -48
- careamics/losses/__init__.py +0 -1
- careamics/losses/loss_factory.py +1 -0
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +4 -3
- careamics/model_io/bmz_io.py +8 -7
- careamics/model_io/model_io_utils.py +4 -4
- careamics/models/layers.py +1 -0
- careamics/models/model_factory.py +1 -0
- careamics/models/unet.py +91 -17
- careamics/prediction/stitch_prediction.py +1 -0
- careamics/transforms/__init__.py +2 -23
- careamics/transforms/compose.py +98 -0
- careamics/transforms/n2v_manipulate.py +18 -23
- careamics/transforms/nd_flip.py +38 -64
- careamics/transforms/normalize.py +45 -34
- careamics/transforms/pixel_manipulation.py +2 -2
- careamics/transforms/transform.py +33 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_random_rotate90.py +41 -68
- careamics/utils/__init__.py +0 -1
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
- careamics-0.1.0rc5.dist-info/RECORD +111 -0
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics-0.1.0rc3.dist-info/RECORD +0 -109
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0

careamics/config/configuration_factory.py +285 -78

@@ -2,13 +2,11 @@
 
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
-from
-
-from .algorithm_model import AlgorithmModel
+from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
-from .data_model import
-from .inference_model import
+from .data_model import DataConfig
+from .inference_model import InferenceConfig
 from .support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -16,10 +14,11 @@ from .support import (
     SupportedPixelManipulation,
     SupportedTransform,
 )
-from .training_model import
+from .training_model import TrainingConfig
 
 
-def
+def _create_supervised_configuration(
+    algorithm: Literal["care", "n2n"],
     experiment_name: str,
     data_type: Literal["array", "tiff", "custom"],
     axes: str,
@@ -27,28 +26,20 @@ def create_n2n_configuration(
     batch_size: int,
     num_epochs: int,
     use_augmentations: bool = True,
-
-
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels_in: int = 1,
+    n_channels_out: int = 1,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_kwargs: Optional[dict] = None,
 ) -> Configuration:
     """
-    Create a configuration for training
-
-    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
-    2.
-
-    By setting `use_augmentations` to False, the only transformation applied will be
-    normalization and N2V manipulation.
-
-    The parameter `use_n2v2` overrides the corresponding `n2v2` that can be passed
-    in `model_kwargs`.
-
-    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
-    will be applied to each manipulated pixel.
+    Create a configuration for training CARE or Noise2Noise.
 
     Parameters
     ----------
+    algorithm : Literal["care", "n2n"]
+        Algorithm to use.
     experiment_name : str
         Name of the experiment.
     data_type : Literal["array", "tiff", "custom"]
@@ -63,18 +54,14 @@ def create_n2n_configuration(
         Number of epochs.
     use_augmentations : bool, optional
         Whether to use augmentations, by default True.
-
-        Whether to
-
-
-
-
-
-
-    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
-        Axis along which to apply structN2V mask, by default "none".
-    struct_n2v_span : int, optional
-        Span of the structN2V mask, by default 5.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default 1.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_kwargs : dict, optional
@@ -83,15 +70,27 @@ def create_n2n_configuration(
     Returns
     -------
     Configuration
-        Configuration for training
+        Configuration for training CARE or Noise2Noise.
     """
+    # if there are channels, we need to specify their number
+    if "C" in axes and n_channels_in == 1:
+        raise ValueError(
+            f"Number of channels in must be specified when using channels "
+            f"(got {n_channels_in} channel)."
+        )
+    elif "C" not in axes and n_channels_in > 1:
+        raise ValueError(
+            f"C is not present in the axes, but number of channels is specified "
+            f"(got {n_channels_in} channels)."
+        )
+
     # model
     if model_kwargs is None:
         model_kwargs = {}
-    model_kwargs["n2v2"] = use_n2v2
     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
-    model_kwargs["in_channels"] =
-    model_kwargs["num_classes"] =
+    model_kwargs["in_channels"] = n_channels_in
+    model_kwargs["num_classes"] = n_channels_out
+    model_kwargs["independent_channels"] = independent_channels
 
     unet_model = UNetModel(
         architecture=SupportedArchitecture.UNET.value,
@@ -99,9 +98,9 @@ def create_n2n_configuration(
     )
 
     # algorithm model
-    algorithm =
-        algorithm=
-        loss=
+    algorithm = AlgorithmConfig(
+        algorithm=algorithm,
+        loss=loss,
         model=unet_model,
     )
 
@@ -126,7 +125,7 @@ def create_n2n_configuration(
     ]
 
     # data model
-    data =
+    data = DataConfig(
         data_type=data_type,
         axes=axes,
         patch_size=patch_size,
@@ -135,7 +134,7 @@ def create_n2n_configuration(
     )
 
     # training model
-    training =
+    training = TrainingConfig(
         num_epochs=num_epochs,
         batch_size=batch_size,
         logger=None if logger == "none" else logger,
@@ -152,6 +151,175 @@ def create_n2n_configuration(
     return configuration
 
 
+def create_care_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training CARE.
+
+    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
+    2.
+
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
+    channels. Likewise, if you set the number of channels, then "C" must be present in
+    `axes`.
+
+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
+    By default, all channels are trained together. To train all channels independently,
+    set `independent_channels` to True.
+
+    By setting `use_augmentations` to False, the only transformation applied will be
+    normalization.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training CARE.
+    """
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
+    return _create_supervised_configuration(
+        algorithm="care",
+        experiment_name=experiment_name,
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        use_augmentations=use_augmentations,
+        independent_channels=independent_channels,
+        loss=loss,
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
+        logger=logger,
+        model_kwargs=model_kwargs,
+    )
+
+
+def create_n2n_configuration(
+    experiment_name: str,
+    data_type: Literal["array", "tiff", "custom"],
+    axes: str,
+    patch_size: List[int],
+    batch_size: int,
+    num_epochs: int,
+    use_augmentations: bool = True,
+    independent_channels: bool = False,
+    loss: Literal["mae", "mse"] = "mae",
+    n_channels: int = 1,
+    logger: Literal["wandb", "tensorboard", "none"] = "none",
+    model_kwargs: Optional[dict] = None,
+) -> Configuration:
+    """
+    Create a configuration for training Noise2Noise.
+
+    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
+    2.
+
+    If "C" is present in `axes`, then you need to set `n_channels` to the number of
+    channels. Likewise, if you set the number of channels, then "C" must be present in
+    `axes`.
+
+    By default, all channels are trained together. To train all channels independently,
+    set `independent_channels` to True.
+
+    By setting `use_augmentations` to False, the only transformation applied will be
+    normalization.
+
+    Parameters
+    ----------
+    experiment_name : str
+        Name of the experiment.
+    data_type : Literal["array", "tiff", "custom"]
+        Type of the data.
+    axes : str
+        Axes of the data (e.g. SYX).
+    patch_size : List[int]
+        Size of the patches along the spatial dimensions (e.g. [64, 64]).
+    batch_size : int
+        Batch size.
+    num_epochs : int
+        Number of epochs.
+    use_augmentations : bool, optional
+        Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels independently, by default False.
+    loss : Literal["mae", "mse"], optional
+        Loss function to use, by default "mae".
+    n_channels : int, optional
+        Number of channels (in and out), by default 1.
+    logger : Literal["wandb", "tensorboard", "none"], optional
+        Logger to use, by default "none".
+    model_kwargs : dict, optional
+        UNetModel parameters, by default {}.
+
+    Returns
+    -------
+    Configuration
+        Configuration for training Noise2Noise.
+    """
+    return _create_supervised_configuration(
+        algorithm="n2n",
+        experiment_name=experiment_name,
+        data_type=data_type,
+        axes=axes,
+        patch_size=patch_size,
+        batch_size=batch_size,
+        num_epochs=num_epochs,
+        use_augmentations=use_augmentations,
+        independent_channels=independent_channels,
+        loss=loss,
+        n_channels_in=n_channels,
+        n_channels_out=n_channels,
+        logger=logger,
+        model_kwargs=model_kwargs,
+    )
+
+
 def create_n2v_configuration(
     experiment_name: str,
     data_type: Literal["array", "tiff", "custom"],
@@ -160,8 +328,9 @@ def create_n2v_configuration(
     batch_size: int,
     num_epochs: int,
     use_augmentations: bool = True,
+    independent_channels: bool = True,
     use_n2v2: bool = False,
-    n_channels: int =
+    n_channels: int = 1,
     roi_size: int = 11,
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -170,7 +339,7 @@ def create_n2v_configuration(
     model_kwargs: Optional[dict] = None,
 ) -> Configuration:
     """
-    Create a configuration for training
+    Create a configuration for training Noise2Void.
 
     N2V uses a UNet model to denoise images in a self-supervised manner. To use its
     variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
@@ -181,11 +350,14 @@ def create_n2v_configuration(
     or horizontal correlations are present in the noise; it applies an additional mask
     to the manipulated pixel neighbors.
 
+    If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
+    2.
+
     If "C" is present in `axes`, then you need to set `n_channels` to the number of
     channels.
 
-
-
+    By default, all channels are trained independently. To train all channels together,
+    set `independent_channels` to False.
 
     By setting `use_augmentations` to False, the only transformations applied will be
     normalization and N2V manipulation.
@@ -217,10 +389,12 @@ def create_n2v_configuration(
         Number of epochs.
     use_augmentations : bool, optional
         Whether to use augmentations, by default True.
+    independent_channels : bool, optional
+        Whether to train all channels together, by default True.
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
     n_channels : int, optional
-        Number of channels (in and out), by default
+        Number of channels (in and out), by default 1.
     roi_size : int, optional
         N2V pixel manipulation area, by default 11.
     masked_pixel_percentage : float, optional
@@ -275,8 +449,20 @@ def create_n2v_configuration(
     ...     struct_n2v_span=7
     ... )
 
-    If you are training multiple channels
-    of channels:
+    If you are training multiple channels independently, then you need to specify the
+    number of channels:
+    >>> config = create_n2v_configuration(
+    ...     experiment_name="n2v_experiment",
+    ...     data_type="array",
+    ...     axes="YXC",
+    ...     patch_size=[64, 64],
+    ...     batch_size=32,
+    ...     num_epochs=100,
+    ...     n_channels=3
+    ... )
+
+    If instead you want to train multiple channels together, you need to turn off the
+    `independent_channels` parameter:
     >>> config = create_n2v_configuration(
     ...     experiment_name="n2v_experiment",
     ...     data_type="array",
@@ -284,6 +470,7 @@ def create_n2v_configuration(
     ...     patch_size=[64, 64],
     ...     batch_size=32,
     ...     num_epochs=100,
+    ...     independent_channels=False,
     ...     n_channels=3
     ... )
 
@@ -300,18 +487,16 @@ def create_n2v_configuration(
     ... )
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels ==
+    if "C" in axes and n_channels == 1:
         raise ValueError(
             f"Number of channels must be specified when using channels "
             f"(got {n_channels} channel)."
         )
-    elif "C" not in axes and n_channels
+    elif "C" not in axes and n_channels > 1:
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels} channel)."
         )
-    elif n_channels == -1:
-        n_channels = 1
 
     # model
     if model_kwargs is None:
@@ -320,6 +505,7 @@ def create_n2v_configuration(
     model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
     model_kwargs["in_channels"] = n_channels
     model_kwargs["num_classes"] = n_channels
+    model_kwargs["independent_channels"] = independent_channels
 
     unet_model = UNetModel(
         architecture=SupportedArchitecture.UNET.value,
@@ -327,7 +513,7 @@ def create_n2v_configuration(
     )
 
     # algorithm model
-    algorithm =
+    algorithm = AlgorithmConfig(
         algorithm=SupportedAlgorithm.N2V.value,
         loss=SupportedLoss.N2V.value,
         model=unet_model,
@@ -356,9 +542,11 @@ def create_n2v_configuration(
     # n2v2 and structn2v
     nv2_transform = {
         "name": SupportedTransform.N2V_MANIPULATE.value,
-        "strategy":
-
-
+        "strategy": (
+            SupportedPixelManipulation.MEDIAN.value
+            if use_n2v2
+            else SupportedPixelManipulation.UNIFORM.value
+        ),
         "roi_size": roi_size,
         "masked_pixel_percentage": masked_pixel_percentage,
         "struct_mask_axis": struct_n2v_axis,
@@ -367,7 +555,7 @@ def create_n2v_configuration(
     transforms.append(nv2_transform)
 
     # data model
-    data =
+    data = DataConfig(
         data_type=data_type,
         axes=axes,
         patch_size=patch_size,
@@ -376,7 +564,7 @@ def create_n2v_configuration(
     )
 
     # training model
-    training =
+    training = TrainingConfig(
         num_epochs=num_epochs,
         batch_size=batch_size,
         logger=None if logger == "none" else logger,
@@ -393,17 +581,16 @@ def create_n2v_configuration(
     return configuration
 
 
-# TODO add tests
 def create_inference_configuration(
-
+    configuration: Configuration,
     tile_size: Optional[Tuple[int, ...]] = None,
     tile_overlap: Optional[Tuple[int, ...]] = None,
     data_type: Optional[Literal["array", "tiff", "custom"]] = None,
     axes: Optional[str] = None,
-    transforms: Optional[Union[List[Dict[str, Any]]
+    transforms: Optional[Union[List[Dict[str, Any]]]] = None,
    tta_transforms: bool = True,
     batch_size: Optional[int] = 1,
-) ->
+) -> InferenceConfig:
     """
     Create a configuration for inference with N2V.
 
@@ -412,8 +599,8 @@ def create_inference_configuration(
 
     Parameters
     ----------
-
-
+    configuration : Configuration
+        Global configuration.
     tile_size : Tuple[int, ...], optional
         Size of the tiles.
     tile_overlap : Tuple[int, ...], optional
@@ -422,7 +609,7 @@ def create_inference_configuration(
         Type of the data, by default "tiff".
     axes : str, optional
         Axes of the data, by default "YX".
-    transforms : List[Dict[str, Any]]
+    transforms : List[Dict[str, Any]], optional
         Transformations to apply to the data, by default None.
     tta_transforms : bool, optional
         Whether to apply test-time augmentations, by default True.
@@ -432,14 +619,12 @@ def create_inference_configuration(
     Returns
     -------
     InferenceConfiguration
-        Configuration
+        Configuration used to configure CAREamicsPredictData.
     """
-    if
-
-        or training_configuration.data_config.std is None
-    ):
-        raise ValueError("Mean and std must be provided in the training configuration.")
+    if configuration.data_config.mean is None or configuration.data_config.std is None:
+        raise ValueError("Mean and std must be provided in the configuration.")
 
+    # minimum transform
     if transforms is None:
         transforms = [
             {
@@ -447,13 +632,35 @@ def create_inference_configuration(
             },
         ]
 
-
-
+    # tile size for UNets
+    if tile_size is not None:
+        model = configuration.algorithm_config.model
+
+        if model.architecture == SupportedArchitecture.UNET.value:
+            # tile size must be equal to k*2^n, where n is the number of pooling layers
+            # (equal to the depth) and k is an integer
+            depth = model.depth
+            tile_increment = 2**depth
+
+            for i, t in enumerate(tile_size):
+                if t % tile_increment != 0:
+                    raise ValueError(
+                        f"Tile size must be divisible by {tile_increment} along all "
+                        f"axes (got {t} for axis {i}). If your image size is smaller "
+                        f"along one axis (e.g. Z), consider padding the image."
+                    )
+
+        # tile overlaps must be specified
+        if tile_overlap is None:
+            raise ValueError("Tile overlap must be specified.")
+
+    return InferenceConfig(
+        data_type=data_type or configuration.data_config.data_type,
         tile_size=tile_size,
         tile_overlap=tile_overlap,
-        axes=axes or
-        mean=
-        std=
+        axes=axes or configuration.data_config.axes,
+        mean=configuration.data_config.mean,
+        std=configuration.data_config.std,
        transforms=transforms,
        tta_transforms=tta_transforms,
        batch_size=batch_size,
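
The reworked factory API above can be exercised as in the following sketch. This example is not part of the diff: it assumes the functions are imported directly from careamics.config.configuration_factory (the module shown above; the package may also re-export them elsewhere) and it reuses the argument values from the docstring examples in the diff.

# Usage sketch (assumed import path; mirrors the docstring examples above).
from careamics.config.configuration_factory import (
    create_care_configuration,
    create_n2v_configuration,
)

# CARE: 3 input channels mapped to a single output channel.
# "C" must appear in `axes` whenever n_channels_in > 1.
care_config = create_care_configuration(
    experiment_name="care_experiment",
    data_type="array",
    axes="YXC",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
    n_channels_in=3,
    n_channels_out=1,
)

# N2V: train 3 channels together instead of independently
# (independent_channels defaults to True for N2V).
n2v_config = create_n2v_configuration(
    experiment_name="n2v_experiment",
    data_type="array",
    axes="YXC",
    patch_size=[64, 64],
    batch_size=32,
    num_epochs=100,
    independent_channels=False,
    n_channels=3,
)

For prediction, the new create_inference_configuration additionally validates UNet tile sizes: each tile dimension must be divisible by 2**depth (for example, a depth-2 UNet requires multiples of 4), and tile_overlap must be given whenever tile_size is.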
careamics/config/configuration_model.py +22 -23

@@ -1,4 +1,5 @@
 """Pydantic CAREamics configuration."""
+
 from __future__ import annotations
 
 import re
@@ -11,8 +12,8 @@ from bioimageio.spec.generic.v0_3 import CiteEntry
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from .algorithm_model import
-from .data_model import
+from .algorithm_model import AlgorithmConfig
+from .data_model import DataConfig
 from .references import (
     CARE,
     CUSTOM,
@@ -34,7 +35,7 @@ from .references import (
     StructN2VRef,
 )
 from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
-from .training_model import
+from .training_model import TrainingConfig
 from .transformations.n2v_manipulate_model import (
     N2VManipulateModel,
 )
@@ -156,9 +157,10 @@ class Configuration(BaseModel):
     )
 
     # Sub-configurations
-    algorithm_config:
-
-
+    algorithm_config: AlgorithmConfig
+
+    data_config: DataConfig
+    training_config: TrainingConfig
 
     @field_validator("experiment_name")
     @classmethod
@@ -237,25 +239,22 @@ class Configuration(BaseModel):
             Validated configuration.
         """
         if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
-            #
-            if self.data_config.
-
-
-
-                N2VManipulateModel(
-                    name=SupportedTransform.N2V_MANIPULATE.value,
-                )
+            # missing N2V_MANIPULATE
+            if not self.data_config.has_n2v_manipulate():
+                self.data_config.transforms.append(
+                    N2VManipulateModel(
+                        name=SupportedTransform.N2V_MANIPULATE.value,
                    )
+                )
 
-
-
-
-
+            median = SupportedPixelManipulation.MEDIAN.value
+            uniform = SupportedPixelManipulation.UNIFORM.value
+            strategy = median if self.algorithm_config.model.n2v2 else uniform
+            self.data_config.set_N2V2_strategy(strategy)
         else:
-            #
-            if self.data_config.
-
-                self.data_config.remove_n2v_manipulate()
+            # remove N2V manipulate if present
+            if self.data_config.has_n2v_manipulate():
+                self.data_config.remove_n2v_manipulate()
 
         return self
 
@@ -591,6 +590,6 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
     # save configuration as dictionary to yaml
     with open(config_path, "w") as f:
         # dump configuration
-        yaml.dump(config.model_dump(), f, default_flow_style=False)
+        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
 
     return config_path