careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +8 -6
- careamics/careamist.py +18 -18
- careamics/config/__init__.py +12 -8
- careamics/config/algorithm_model.py +5 -5
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +187 -50
- careamics/config/configuration_model.py +8 -7
- careamics/config/data_model.py +3 -3
- careamics/config/inference_model.py +1 -1
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +1 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/nd_flip_model.py +1 -1
- careamics/config/transformations/normalize_model.py +1 -1
- careamics/config/transformations/xy_random_rotate90_model.py +1 -1
- careamics/dataset/in_memory_dataset.py +3 -3
- careamics/dataset/iterable_dataset.py +3 -3
- careamics/lightning_datamodule.py +103 -25
- careamics/lightning_module.py +6 -6
- careamics/lightning_prediction_datamodule.py +44 -38
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +6 -6
- careamics/model_io/model_io_utils.py +4 -4
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/RECORD +27 -26
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,11 +4,11 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
|
4
4
|
|
|
5
5
|
from albumentations import Compose
|
|
6
6
|
|
|
7
|
-
from .algorithm_model import
|
|
7
|
+
from .algorithm_model import AlgorithmConfig
|
|
8
8
|
from .architectures import UNetModel
|
|
9
9
|
from .configuration_model import Configuration
|
|
10
|
-
from .data_model import
|
|
11
|
-
from .inference_model import
|
|
10
|
+
from .data_model import DataConfig
|
|
11
|
+
from .inference_model import InferenceConfig
|
|
12
12
|
from .support import (
|
|
13
13
|
SupportedAlgorithm,
|
|
14
14
|
SupportedArchitecture,
|
|
@@ -16,10 +16,11 @@ from .support import (
|
|
|
16
16
|
SupportedPixelManipulation,
|
|
17
17
|
SupportedTransform,
|
|
18
18
|
)
|
|
19
|
-
from .training_model import
|
|
19
|
+
from .training_model import TrainingConfig
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def
|
|
22
|
+
def _create_supervised_configuration(
|
|
23
|
+
algorithm: Literal["care", "n2n"],
|
|
23
24
|
experiment_name: str,
|
|
24
25
|
data_type: Literal["array", "tiff", "custom"],
|
|
25
26
|
axes: str,
|
|
@@ -27,28 +28,18 @@ def create_n2n_configuration(
|
|
|
27
28
|
batch_size: int,
|
|
28
29
|
num_epochs: int,
|
|
29
30
|
use_augmentations: bool = True,
|
|
30
|
-
|
|
31
|
-
n_channels: int = 1,
|
|
31
|
+
loss: Literal["mae", "mse"] = "mae",
|
|
32
|
+
n_channels: int = -1,
|
|
32
33
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
33
34
|
model_kwargs: Optional[dict] = None,
|
|
34
35
|
) -> Configuration:
|
|
35
36
|
"""
|
|
36
|
-
Create a configuration for training
|
|
37
|
-
|
|
38
|
-
If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
|
|
39
|
-
2.
|
|
40
|
-
|
|
41
|
-
By setting `use_augmentations` to False, the only transformation applied will be
|
|
42
|
-
normalization and N2V manipulation.
|
|
43
|
-
|
|
44
|
-
The parameter `use_n2v2` overrides the corresponding `n2v2` that can be passed
|
|
45
|
-
in `model_kwargs`.
|
|
46
|
-
|
|
47
|
-
If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
|
|
48
|
-
will be applied to each manipulated pixel.
|
|
37
|
+
Create a configuration for training CARE or Noise2Noise.
|
|
49
38
|
|
|
50
39
|
Parameters
|
|
51
40
|
----------
|
|
41
|
+
algorithm : Literal["care", "n2n"]
|
|
42
|
+
Algorithm to use.
|
|
52
43
|
experiment_name : str
|
|
53
44
|
Name of the experiment.
|
|
54
45
|
data_type : Literal["array", "tiff", "custom"]
|
|
@@ -63,18 +54,10 @@ def create_n2n_configuration(
|
|
|
63
54
|
Number of epochs.
|
|
64
55
|
use_augmentations : bool, optional
|
|
65
56
|
Whether to use augmentations, by default True.
|
|
66
|
-
|
|
67
|
-
|
|
57
|
+
loss : Literal["mae", "mse"], optional
|
|
58
|
+
Loss function to use, by default "mae".
|
|
68
59
|
n_channels : int, optional
|
|
69
|
-
Number of channels (in and out), by default 1.
|
|
70
|
-
roi_size : int, optional
|
|
71
|
-
N2V pixel manipulation area, by default 11.
|
|
72
|
-
masked_pixel_percentage : float, optional
|
|
73
|
-
Percentage of pixels masked in each patch, by default 0.2.
|
|
74
|
-
struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
|
|
75
|
-
Axis along which to apply structN2V mask, by default "none".
|
|
76
|
-
struct_n2v_span : int, optional
|
|
77
|
-
Span of the structN2V mask, by default 5.
|
|
60
|
+
Number of channels (in and out), by default -1.
|
|
78
61
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
79
62
|
Logger to use, by default "none".
|
|
80
63
|
model_kwargs : dict, optional
|
|
@@ -83,12 +66,23 @@ def create_n2n_configuration(
|
|
|
83
66
|
Returns
|
|
84
67
|
-------
|
|
85
68
|
Configuration
|
|
86
|
-
Configuration for training
|
|
69
|
+
Configuration for training CARE or Noise2Noise.
|
|
87
70
|
"""
|
|
71
|
+
# if there are channels, we need to specify their number
|
|
72
|
+
if "C" in axes and n_channels == 1:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
f"Number of channels must be specified when using channels "
|
|
75
|
+
f"(got {n_channels} channel)."
|
|
76
|
+
)
|
|
77
|
+
elif "C" not in axes and n_channels > 1:
|
|
78
|
+
raise ValueError(
|
|
79
|
+
f"C is not present in the axes, but number of channels is specified "
|
|
80
|
+
f"(got {n_channels} channel)."
|
|
81
|
+
)
|
|
82
|
+
|
|
88
83
|
# model
|
|
89
84
|
if model_kwargs is None:
|
|
90
85
|
model_kwargs = {}
|
|
91
|
-
model_kwargs["n2v2"] = use_n2v2
|
|
92
86
|
model_kwargs["conv_dims"] = 3 if "Z" in axes else 2
|
|
93
87
|
model_kwargs["in_channels"] = n_channels
|
|
94
88
|
model_kwargs["num_classes"] = n_channels
|
|
@@ -99,9 +93,9 @@ def create_n2n_configuration(
|
|
|
99
93
|
)
|
|
100
94
|
|
|
101
95
|
# algorithm model
|
|
102
|
-
algorithm =
|
|
103
|
-
algorithm=
|
|
104
|
-
loss=
|
|
96
|
+
algorithm = AlgorithmConfig(
|
|
97
|
+
algorithm=algorithm,
|
|
98
|
+
loss=loss,
|
|
105
99
|
model=unet_model,
|
|
106
100
|
)
|
|
107
101
|
|
|
@@ -126,7 +120,7 @@ def create_n2n_configuration(
|
|
|
126
120
|
]
|
|
127
121
|
|
|
128
122
|
# data model
|
|
129
|
-
data =
|
|
123
|
+
data = DataConfig(
|
|
130
124
|
data_type=data_type,
|
|
131
125
|
axes=axes,
|
|
132
126
|
patch_size=patch_size,
|
|
@@ -135,7 +129,7 @@ def create_n2n_configuration(
|
|
|
135
129
|
)
|
|
136
130
|
|
|
137
131
|
# training model
|
|
138
|
-
training =
|
|
132
|
+
training = TrainingConfig(
|
|
139
133
|
num_epochs=num_epochs,
|
|
140
134
|
batch_size=batch_size,
|
|
141
135
|
logger=None if logger == "none" else logger,
|
|
@@ -152,6 +146,151 @@ def create_n2n_configuration(
|
|
|
152
146
|
return configuration
|
|
153
147
|
|
|
154
148
|
|
|
149
|
+
def create_care_configuration(
|
|
150
|
+
experiment_name: str,
|
|
151
|
+
data_type: Literal["array", "tiff", "custom"],
|
|
152
|
+
axes: str,
|
|
153
|
+
patch_size: List[int],
|
|
154
|
+
batch_size: int,
|
|
155
|
+
num_epochs: int,
|
|
156
|
+
use_augmentations: bool = True,
|
|
157
|
+
loss: Literal["mae", "mse"] = "mae",
|
|
158
|
+
n_channels: int = 1,
|
|
159
|
+
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
160
|
+
model_kwargs: Optional[dict] = None,
|
|
161
|
+
) -> Configuration:
|
|
162
|
+
"""
|
|
163
|
+
Create a configuration for training CARE.
|
|
164
|
+
|
|
165
|
+
If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
|
|
166
|
+
2.
|
|
167
|
+
|
|
168
|
+
If "C" is present in `axes`, then you need to set `n_channels` to the number of
|
|
169
|
+
channels. Likewise, if you set the number of channels, then "C" must be present in
|
|
170
|
+
`axes`.
|
|
171
|
+
|
|
172
|
+
By setting `use_augmentations` to False, the only transformation applied will be
|
|
173
|
+
normalization.
|
|
174
|
+
|
|
175
|
+
Parameters
|
|
176
|
+
----------
|
|
177
|
+
experiment_name : str
|
|
178
|
+
Name of the experiment.
|
|
179
|
+
data_type : Literal["array", "tiff", "custom"]
|
|
180
|
+
Type of the data.
|
|
181
|
+
axes : str
|
|
182
|
+
Axes of the data (e.g. SYX).
|
|
183
|
+
patch_size : List[int]
|
|
184
|
+
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
185
|
+
batch_size : int
|
|
186
|
+
Batch size.
|
|
187
|
+
num_epochs : int
|
|
188
|
+
Number of epochs.
|
|
189
|
+
use_augmentations : bool, optional
|
|
190
|
+
Whether to use augmentations, by default True.
|
|
191
|
+
loss : Literal["mae", "mse"], optional
|
|
192
|
+
Loss function to use, by default "mae".
|
|
193
|
+
n_channels : int, optional
|
|
194
|
+
Number of channels (in and out), by default 1.
|
|
195
|
+
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
196
|
+
Logger to use, by default "none".
|
|
197
|
+
model_kwargs : dict, optional
|
|
198
|
+
UNetModel parameters, by default {}.
|
|
199
|
+
|
|
200
|
+
Returns
|
|
201
|
+
-------
|
|
202
|
+
Configuration
|
|
203
|
+
Configuration for training CARE.
|
|
204
|
+
"""
|
|
205
|
+
return _create_supervised_configuration(
|
|
206
|
+
algorithm="care",
|
|
207
|
+
experiment_name=experiment_name,
|
|
208
|
+
data_type=data_type,
|
|
209
|
+
axes=axes,
|
|
210
|
+
patch_size=patch_size,
|
|
211
|
+
batch_size=batch_size,
|
|
212
|
+
num_epochs=num_epochs,
|
|
213
|
+
use_augmentations=use_augmentations,
|
|
214
|
+
loss=loss,
|
|
215
|
+
# TODO in the future we might support different in and out channels for CARE
|
|
216
|
+
n_channels=n_channels,
|
|
217
|
+
logger=logger,
|
|
218
|
+
model_kwargs=model_kwargs,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def create_n2n_configuration(
|
|
223
|
+
experiment_name: str,
|
|
224
|
+
data_type: Literal["array", "tiff", "custom"],
|
|
225
|
+
axes: str,
|
|
226
|
+
patch_size: List[int],
|
|
227
|
+
batch_size: int,
|
|
228
|
+
num_epochs: int,
|
|
229
|
+
use_augmentations: bool = True,
|
|
230
|
+
loss: Literal["mae", "mse"] = "mae",
|
|
231
|
+
n_channels: int = 1,
|
|
232
|
+
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
233
|
+
model_kwargs: Optional[dict] = None,
|
|
234
|
+
) -> Configuration:
|
|
235
|
+
"""
|
|
236
|
+
Create a configuration for training Noise2Noise.
|
|
237
|
+
|
|
238
|
+
If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
|
|
239
|
+
2.
|
|
240
|
+
|
|
241
|
+
If "C" is present in `axes`, then you need to set `n_channels` to the number of
|
|
242
|
+
channels. Likewise, if you set the number of channels, then "C" must be present in
|
|
243
|
+
`axes`.
|
|
244
|
+
|
|
245
|
+
By setting `use_augmentations` to False, the only transformation applied will be
|
|
246
|
+
normalization.
|
|
247
|
+
|
|
248
|
+
Parameters
|
|
249
|
+
----------
|
|
250
|
+
experiment_name : str
|
|
251
|
+
Name of the experiment.
|
|
252
|
+
data_type : Literal["array", "tiff", "custom"]
|
|
253
|
+
Type of the data.
|
|
254
|
+
axes : str
|
|
255
|
+
Axes of the data (e.g. SYX).
|
|
256
|
+
patch_size : List[int]
|
|
257
|
+
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
258
|
+
batch_size : int
|
|
259
|
+
Batch size.
|
|
260
|
+
num_epochs : int
|
|
261
|
+
Number of epochs.
|
|
262
|
+
use_augmentations : bool, optional
|
|
263
|
+
Whether to use augmentations, by default True.
|
|
264
|
+
loss : Literal["mae", "mse"], optional
|
|
265
|
+
Loss function to use, by default "mae".
|
|
266
|
+
n_channels : int, optional
|
|
267
|
+
Number of channels (in and out), by default 1.
|
|
268
|
+
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
269
|
+
Logger to use, by default "none".
|
|
270
|
+
model_kwargs : dict, optional
|
|
271
|
+
UNetModel parameters, by default {}.
|
|
272
|
+
|
|
273
|
+
Returns
|
|
274
|
+
-------
|
|
275
|
+
Configuration
|
|
276
|
+
Configuration for training Noise2Noise.
|
|
277
|
+
"""
|
|
278
|
+
return _create_supervised_configuration(
|
|
279
|
+
algorithm="n2n",
|
|
280
|
+
experiment_name=experiment_name,
|
|
281
|
+
data_type=data_type,
|
|
282
|
+
axes=axes,
|
|
283
|
+
patch_size=patch_size,
|
|
284
|
+
batch_size=batch_size,
|
|
285
|
+
num_epochs=num_epochs,
|
|
286
|
+
use_augmentations=use_augmentations,
|
|
287
|
+
loss=loss,
|
|
288
|
+
n_channels=n_channels,
|
|
289
|
+
logger=logger,
|
|
290
|
+
model_kwargs=model_kwargs,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
|
|
155
294
|
def create_n2v_configuration(
|
|
156
295
|
experiment_name: str,
|
|
157
296
|
data_type: Literal["array", "tiff", "custom"],
|
|
@@ -161,7 +300,7 @@ def create_n2v_configuration(
|
|
|
161
300
|
num_epochs: int,
|
|
162
301
|
use_augmentations: bool = True,
|
|
163
302
|
use_n2v2: bool = False,
|
|
164
|
-
n_channels: int =
|
|
303
|
+
n_channels: int = 1,
|
|
165
304
|
roi_size: int = 11,
|
|
166
305
|
masked_pixel_percentage: float = 0.2,
|
|
167
306
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
@@ -170,7 +309,7 @@ def create_n2v_configuration(
|
|
|
170
309
|
model_kwargs: Optional[dict] = None,
|
|
171
310
|
) -> Configuration:
|
|
172
311
|
"""
|
|
173
|
-
Create a configuration for training
|
|
312
|
+
Create a configuration for training Noise2Void.
|
|
174
313
|
|
|
175
314
|
N2V uses a UNet model to denoise images in a self-supervised manner. To use its
|
|
176
315
|
variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
|
|
@@ -220,7 +359,7 @@ def create_n2v_configuration(
|
|
|
220
359
|
use_n2v2 : bool, optional
|
|
221
360
|
Whether to use N2V2, by default False.
|
|
222
361
|
n_channels : int, optional
|
|
223
|
-
Number of channels (in and out), by default
|
|
362
|
+
Number of channels (in and out), by default 1.
|
|
224
363
|
roi_size : int, optional
|
|
225
364
|
N2V pixel manipulation area, by default 11.
|
|
226
365
|
masked_pixel_percentage : float, optional
|
|
@@ -300,18 +439,16 @@ def create_n2v_configuration(
|
|
|
300
439
|
... )
|
|
301
440
|
"""
|
|
302
441
|
# if there are channels, we need to specify their number
|
|
303
|
-
if "C" in axes and n_channels ==
|
|
442
|
+
if "C" in axes and n_channels == 1:
|
|
304
443
|
raise ValueError(
|
|
305
444
|
f"Number of channels must be specified when using channels "
|
|
306
445
|
f"(got {n_channels} channel)."
|
|
307
446
|
)
|
|
308
|
-
elif "C" not in axes and n_channels
|
|
447
|
+
elif "C" not in axes and n_channels > 1:
|
|
309
448
|
raise ValueError(
|
|
310
449
|
f"C is not present in the axes, but number of channels is specified "
|
|
311
450
|
f"(got {n_channels} channel)."
|
|
312
451
|
)
|
|
313
|
-
elif n_channels == -1:
|
|
314
|
-
n_channels = 1
|
|
315
452
|
|
|
316
453
|
# model
|
|
317
454
|
if model_kwargs is None:
|
|
@@ -327,7 +464,7 @@ def create_n2v_configuration(
|
|
|
327
464
|
)
|
|
328
465
|
|
|
329
466
|
# algorithm model
|
|
330
|
-
algorithm =
|
|
467
|
+
algorithm = AlgorithmConfig(
|
|
331
468
|
algorithm=SupportedAlgorithm.N2V.value,
|
|
332
469
|
loss=SupportedLoss.N2V.value,
|
|
333
470
|
model=unet_model,
|
|
@@ -367,7 +504,7 @@ def create_n2v_configuration(
|
|
|
367
504
|
transforms.append(nv2_transform)
|
|
368
505
|
|
|
369
506
|
# data model
|
|
370
|
-
data =
|
|
507
|
+
data = DataConfig(
|
|
371
508
|
data_type=data_type,
|
|
372
509
|
axes=axes,
|
|
373
510
|
patch_size=patch_size,
|
|
@@ -376,7 +513,7 @@ def create_n2v_configuration(
|
|
|
376
513
|
)
|
|
377
514
|
|
|
378
515
|
# training model
|
|
379
|
-
training =
|
|
516
|
+
training = TrainingConfig(
|
|
380
517
|
num_epochs=num_epochs,
|
|
381
518
|
batch_size=batch_size,
|
|
382
519
|
logger=None if logger == "none" else logger,
|
|
@@ -403,7 +540,7 @@ def create_inference_configuration(
|
|
|
403
540
|
transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None,
|
|
404
541
|
tta_transforms: bool = True,
|
|
405
542
|
batch_size: Optional[int] = 1,
|
|
406
|
-
) ->
|
|
543
|
+
) -> InferenceConfig:
|
|
407
544
|
"""
|
|
408
545
|
Create a configuration for inference with N2V.
|
|
409
546
|
|
|
@@ -447,7 +584,7 @@ def create_inference_configuration(
|
|
|
447
584
|
},
|
|
448
585
|
]
|
|
449
586
|
|
|
450
|
-
return
|
|
587
|
+
return InferenceConfig(
|
|
451
588
|
data_type=data_type or training_configuration.data_config.data_type,
|
|
452
589
|
tile_size=tile_size,
|
|
453
590
|
tile_overlap=tile_overlap,
|
|
@@ -11,8 +11,8 @@ from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
|
11
11
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
12
12
|
from typing_extensions import Self
|
|
13
13
|
|
|
14
|
-
from .algorithm_model import
|
|
15
|
-
from .data_model import
|
|
14
|
+
from .algorithm_model import AlgorithmConfig
|
|
15
|
+
from .data_model import DataConfig
|
|
16
16
|
from .references import (
|
|
17
17
|
CARE,
|
|
18
18
|
CUSTOM,
|
|
@@ -34,7 +34,7 @@ from .references import (
|
|
|
34
34
|
StructN2VRef,
|
|
35
35
|
)
|
|
36
36
|
from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
|
|
37
|
-
from .training_model import
|
|
37
|
+
from .training_model import TrainingConfig
|
|
38
38
|
from .transformations.n2v_manipulate_model import (
|
|
39
39
|
N2VManipulateModel,
|
|
40
40
|
)
|
|
@@ -156,9 +156,10 @@ class Configuration(BaseModel):
|
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
# Sub-configurations
|
|
159
|
-
algorithm_config:
|
|
160
|
-
|
|
161
|
-
|
|
159
|
+
algorithm_config: AlgorithmConfig
|
|
160
|
+
|
|
161
|
+
data_config: DataConfig
|
|
162
|
+
training_config: TrainingConfig
|
|
162
163
|
|
|
163
164
|
@field_validator("experiment_name")
|
|
164
165
|
@classmethod
|
|
@@ -591,6 +592,6 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
|
|
|
591
592
|
# save configuration as dictionary to yaml
|
|
592
593
|
with open(config_path, "w") as f:
|
|
593
594
|
# dump configuration
|
|
594
|
-
yaml.dump(config.model_dump(), f, default_flow_style=False)
|
|
595
|
+
yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
|
|
595
596
|
|
|
596
597
|
return config_path
|
careamics/config/data_model.py
CHANGED
|
@@ -33,7 +33,7 @@ TRANSFORMS_UNION = Annotated[
|
|
|
33
33
|
]
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
class
|
|
36
|
+
class DataConfig(BaseModel):
|
|
37
37
|
"""
|
|
38
38
|
Data configuration.
|
|
39
39
|
|
|
@@ -45,7 +45,7 @@ class DataModel(BaseModel):
|
|
|
45
45
|
--------
|
|
46
46
|
Minimum example:
|
|
47
47
|
|
|
48
|
-
>>> data =
|
|
48
|
+
>>> data = DataConfig(
|
|
49
49
|
... data_type="array", # defined in SupportedData
|
|
50
50
|
... patch_size=[128, 128],
|
|
51
51
|
... batch_size=4,
|
|
@@ -58,7 +58,7 @@ class DataModel(BaseModel):
|
|
|
58
58
|
One can pass also a list of transformations, by keyword, using the
|
|
59
59
|
SupportedTransform or the name of an Albumentation transform:
|
|
60
60
|
>>> from careamics.config.support import SupportedTransform
|
|
61
|
-
>>> data =
|
|
61
|
+
>>> data = DataConfig(
|
|
62
62
|
... data_type="tiff",
|
|
63
63
|
... patch_size=[128, 128],
|
|
64
64
|
... batch_size=4,
|
|
@@ -14,7 +14,7 @@ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
|
14
14
|
TRANSFORMS_UNION = Union[NormalizeModel]
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
class
|
|
17
|
+
class InferenceConfig(BaseModel):
|
|
18
18
|
"""Configuration class for the prediction model."""
|
|
19
19
|
|
|
20
20
|
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
@@ -15,7 +15,7 @@ class SupportedOptimizer(str, BaseEnum):
|
|
|
15
15
|
# ASGD = "ASGD"
|
|
16
16
|
# Adadelta = "Adadelta"
|
|
17
17
|
# Adagrad = "Adagrad"
|
|
18
|
-
|
|
18
|
+
ADAM = "Adam"
|
|
19
19
|
# AdamW = "AdamW"
|
|
20
20
|
# Adamax = "Adamax"
|
|
21
21
|
# LBFGS = "LBFGS"
|
|
@@ -50,6 +50,6 @@ class SupportedScheduler(str, BaseEnum):
|
|
|
50
50
|
# MultiplicativeLR = "MultiplicativeLR"
|
|
51
51
|
# OneCycleLR = "OneCycleLR"
|
|
52
52
|
# PolynomialLR = "PolynomialLR"
|
|
53
|
-
|
|
53
|
+
REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
|
|
54
54
|
# SequentialLR = "SequentialLR"
|
|
55
|
-
|
|
55
|
+
STEP_LR = "StepLR"
|
|
@@ -30,7 +30,7 @@ class N2VManipulateModel(TransformModel):
|
|
|
30
30
|
validate_assignment=True,
|
|
31
31
|
)
|
|
32
32
|
|
|
33
|
-
name: Literal["N2VManipulate"]
|
|
33
|
+
name: Literal["N2VManipulate"] = "N2VManipulate"
|
|
34
34
|
roi_size: int = Field(default=11, ge=3, le=21)
|
|
35
35
|
masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
|
|
36
36
|
strategy: Literal["uniform", "median"] = Field(default="uniform")
|
|
@@ -26,7 +26,7 @@ class NDFlipModel(TransformModel):
|
|
|
26
26
|
validate_assignment=True,
|
|
27
27
|
)
|
|
28
28
|
|
|
29
|
-
name: Literal["NDFlip"]
|
|
29
|
+
name: Literal["NDFlip"] = "NDFlip"
|
|
30
30
|
p: float = Field(default=0.5, ge=0.0, le=1.0)
|
|
31
31
|
is_3D: bool = Field(default=False)
|
|
32
32
|
flip_z: bool = Field(default=True)
|
|
@@ -24,6 +24,6 @@ class XYRandomRotate90Model(TransformModel):
|
|
|
24
24
|
validate_assignment=True,
|
|
25
25
|
)
|
|
26
26
|
|
|
27
|
-
name: Literal["XYRandomRotate90"]
|
|
27
|
+
name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
|
|
28
28
|
p: float = Field(default=0.5, ge=0.0, le=1.0)
|
|
29
29
|
is_3D: bool = Field(default=False)
|
|
@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from torch.utils.data import Dataset
|
|
10
10
|
|
|
11
|
-
from ..config import
|
|
11
|
+
from ..config import DataConfig, InferenceConfig
|
|
12
12
|
from ..config.tile_information import TileInformation
|
|
13
13
|
from ..utils.logging import get_logger
|
|
14
14
|
from .dataset_utils import read_tiff, reshape_array
|
|
@@ -29,7 +29,7 @@ class InMemoryDataset(Dataset):
|
|
|
29
29
|
|
|
30
30
|
def __init__(
|
|
31
31
|
self,
|
|
32
|
-
data_config:
|
|
32
|
+
data_config: DataConfig,
|
|
33
33
|
inputs: Union[np.ndarray, List[Path]],
|
|
34
34
|
data_target: Optional[Union[np.ndarray, List[Path]]] = None,
|
|
35
35
|
read_source_func: Callable = read_tiff,
|
|
@@ -279,7 +279,7 @@ class InMemoryPredictionDataset(Dataset):
|
|
|
279
279
|
|
|
280
280
|
def __init__(
|
|
281
281
|
self,
|
|
282
|
-
prediction_config:
|
|
282
|
+
prediction_config: InferenceConfig,
|
|
283
283
|
inputs: np.ndarray,
|
|
284
284
|
data_target: Optional[np.ndarray] = None,
|
|
285
285
|
read_source_func: Optional[Callable] = read_tiff,
|
|
@@ -7,7 +7,7 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
from torch.utils.data import IterableDataset, get_worker_info
|
|
9
9
|
|
|
10
|
-
from ..config import
|
|
10
|
+
from ..config import DataConfig, InferenceConfig
|
|
11
11
|
from ..config.tile_information import TileInformation
|
|
12
12
|
from ..utils.logging import get_logger
|
|
13
13
|
from .dataset_utils import read_tiff, reshape_array
|
|
@@ -46,7 +46,7 @@ class PathIterableDataset(IterableDataset):
|
|
|
46
46
|
|
|
47
47
|
def __init__(
|
|
48
48
|
self,
|
|
49
|
-
data_config: Union[
|
|
49
|
+
data_config: Union[DataConfig, InferenceConfig],
|
|
50
50
|
src_files: List[Path],
|
|
51
51
|
target_files: Optional[List[Path]] = None,
|
|
52
52
|
read_source_func: Callable = read_tiff,
|
|
@@ -346,7 +346,7 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
346
346
|
|
|
347
347
|
def __init__(
|
|
348
348
|
self,
|
|
349
|
-
prediction_config:
|
|
349
|
+
prediction_config: InferenceConfig,
|
|
350
350
|
src_files: List[Path],
|
|
351
351
|
read_source_func: Callable = read_tiff,
|
|
352
352
|
**kwargs: Any,
|