careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. See this release's page on the package registry for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/config/data_model.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Data configuration."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
from pprint import pformat
|
|
5
6
|
from typing import Any, List, Literal, Optional, Union
|
|
6
7
|
|
|
7
|
-
from albumentations import Compose
|
|
8
8
|
from pydantic import (
|
|
9
9
|
BaseModel,
|
|
10
10
|
ConfigDict,
|
|
@@ -17,14 +17,14 @@ from typing_extensions import Annotated, Self
|
|
|
17
17
|
|
|
18
18
|
from .support import SupportedTransform
|
|
19
19
|
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
20
|
-
from .transformations.nd_flip_model import NDFlipModel
|
|
21
20
|
from .transformations.normalize_model import NormalizeModel
|
|
21
|
+
from .transformations.xy_flip_model import XYFlipModel
|
|
22
22
|
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
23
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
24
|
|
|
25
25
|
TRANSFORMS_UNION = Annotated[
|
|
26
26
|
Union[
|
|
27
|
-
|
|
27
|
+
XYFlipModel,
|
|
28
28
|
XYRandomRotate90Model,
|
|
29
29
|
NormalizeModel,
|
|
30
30
|
N2VManipulateModel,
|
|
@@ -41,6 +41,8 @@ class DataConfig(BaseModel):
|
|
|
41
41
|
and then the mean (if they were both `None` before) will raise a validation error.
|
|
42
42
|
Prefer instead `set_mean_and_std` to set both at once.
|
|
43
43
|
|
|
44
|
+
All supported transforms are defined in the SupportedTransform enum.
|
|
45
|
+
|
|
44
46
|
Examples
|
|
45
47
|
--------
|
|
46
48
|
Minimum example:
|
|
@@ -56,7 +58,7 @@ class DataConfig(BaseModel):
|
|
|
56
58
|
>>> data.set_mean_and_std(mean=214.3, std=84.5)
|
|
57
59
|
|
|
58
60
|
One can pass also a list of transformations, by keyword, using the
|
|
59
|
-
SupportedTransform
|
|
61
|
+
SupportedTransform value:
|
|
60
62
|
>>> from careamics.config.support import SupportedTransform
|
|
61
63
|
>>> data = DataConfig(
|
|
62
64
|
... data_type="tiff",
|
|
@@ -70,9 +72,7 @@ class DataConfig(BaseModel):
|
|
|
70
72
|
... "std": 47.2,
|
|
71
73
|
... },
|
|
72
74
|
... {
|
|
73
|
-
... "name": "
|
|
74
|
-
... "is_3D": True,
|
|
75
|
-
... "flip_z": True,
|
|
75
|
+
... "name": "XYFlip",
|
|
76
76
|
... }
|
|
77
77
|
... ]
|
|
78
78
|
... )
|
|
@@ -81,7 +81,6 @@ class DataConfig(BaseModel):
|
|
|
81
81
|
# Pydantic class configuration
|
|
82
82
|
model_config = ConfigDict(
|
|
83
83
|
validate_assignment=True,
|
|
84
|
-
arbitrary_types_allowed=True, # Allow Compose declaration
|
|
85
84
|
)
|
|
86
85
|
|
|
87
86
|
# Dataset configuration
|
|
@@ -94,13 +93,13 @@ class DataConfig(BaseModel):
|
|
|
94
93
|
mean: Optional[float] = None
|
|
95
94
|
std: Optional[float] = None
|
|
96
95
|
|
|
97
|
-
transforms:
|
|
96
|
+
transforms: List[TRANSFORMS_UNION] = Field(
|
|
98
97
|
default=[
|
|
99
98
|
{
|
|
100
99
|
"name": SupportedTransform.NORMALIZE.value,
|
|
101
100
|
},
|
|
102
101
|
{
|
|
103
|
-
"name": SupportedTransform.
|
|
102
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
104
103
|
},
|
|
105
104
|
{
|
|
106
105
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -181,19 +180,19 @@ class DataConfig(BaseModel):
|
|
|
181
180
|
@field_validator("transforms")
|
|
182
181
|
@classmethod
|
|
183
182
|
def validate_prediction_transforms(
|
|
184
|
-
cls, transforms:
|
|
185
|
-
) ->
|
|
183
|
+
cls, transforms: List[TRANSFORMS_UNION]
|
|
184
|
+
) -> List[TRANSFORMS_UNION]:
|
|
186
185
|
"""
|
|
187
186
|
Validate N2VManipulate transform position in the transform list.
|
|
188
187
|
|
|
189
188
|
Parameters
|
|
190
189
|
----------
|
|
191
|
-
transforms :
|
|
190
|
+
transforms : List[Transformations_Union]
|
|
192
191
|
Transforms.
|
|
193
192
|
|
|
194
193
|
Returns
|
|
195
194
|
-------
|
|
196
|
-
|
|
195
|
+
List[TRANSFORMS_UNION]
|
|
197
196
|
Validated transforms.
|
|
198
197
|
|
|
199
198
|
Raises
|
|
@@ -201,23 +200,22 @@ class DataConfig(BaseModel):
|
|
|
201
200
|
ValueError
|
|
202
201
|
If multiple instances of N2VManipulate are found.
|
|
203
202
|
"""
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
transforms.append(transform)
|
|
203
|
+
transform_list = [t.name for t in transforms]
|
|
204
|
+
|
|
205
|
+
if SupportedTransform.N2V_MANIPULATE in transform_list:
|
|
206
|
+
# multiple N2V_MANIPULATE
|
|
207
|
+
if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"Multiple instances of "
|
|
210
|
+
f"{SupportedTransform.N2V_MANIPULATE} transforms "
|
|
211
|
+
f"are not allowed."
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
# N2V_MANIPULATE not the last transform
|
|
215
|
+
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
|
|
216
|
+
index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
|
|
217
|
+
transform = transforms.pop(index)
|
|
218
|
+
transforms.append(transform)
|
|
221
219
|
|
|
222
220
|
return transforms
|
|
223
221
|
|
|
@@ -254,13 +252,12 @@ class DataConfig(BaseModel):
|
|
|
254
252
|
Self
|
|
255
253
|
Data model with mean and std added to the Normalize transform.
|
|
256
254
|
"""
|
|
257
|
-
if self.mean is not None
|
|
255
|
+
if self.mean is not None and self.std is not None:
|
|
258
256
|
# search in the transforms for Normalize and update parameters
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
transform.std = self.std
|
|
257
|
+
for transform in self.transforms:
|
|
258
|
+
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
259
|
+
transform.mean = self.mean
|
|
260
|
+
transform.std = self.std
|
|
264
261
|
|
|
265
262
|
return self
|
|
266
263
|
|
|
@@ -286,13 +283,6 @@ class DataConfig(BaseModel):
|
|
|
286
283
|
f"({self.axes})."
|
|
287
284
|
)
|
|
288
285
|
|
|
289
|
-
if self.has_transform_list():
|
|
290
|
-
for transform in self.transforms:
|
|
291
|
-
if transform.name == SupportedTransform.NDFLIP:
|
|
292
|
-
transform.is_3D = True
|
|
293
|
-
elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
|
|
294
|
-
transform.is_3D = True
|
|
295
|
-
|
|
296
286
|
else:
|
|
297
287
|
if len(self.patch_size) != 2:
|
|
298
288
|
raise ValueError(
|
|
@@ -300,13 +290,6 @@ class DataConfig(BaseModel):
|
|
|
300
290
|
f"({self.axes})."
|
|
301
291
|
)
|
|
302
292
|
|
|
303
|
-
if self.has_transform_list():
|
|
304
|
-
for transform in self.transforms:
|
|
305
|
-
if transform.name == SupportedTransform.NDFLIP:
|
|
306
|
-
transform.is_3D = False
|
|
307
|
-
elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
|
|
308
|
-
transform.is_3D = False
|
|
309
|
-
|
|
310
293
|
return self
|
|
311
294
|
|
|
312
295
|
def __str__(self) -> str:
|
|
@@ -332,84 +315,31 @@ class DataConfig(BaseModel):
|
|
|
332
315
|
self.__dict__.update(kwargs)
|
|
333
316
|
self.__class__.model_validate(self.__dict__)
|
|
334
317
|
|
|
335
|
-
def has_transform_list(self) -> bool:
|
|
336
|
-
"""
|
|
337
|
-
Check if the transforms are a list, as opposed to a Compose object.
|
|
338
|
-
|
|
339
|
-
Returns
|
|
340
|
-
-------
|
|
341
|
-
bool
|
|
342
|
-
True if the transforms are a list, False otherwise.
|
|
343
|
-
"""
|
|
344
|
-
return isinstance(self.transforms, list)
|
|
345
|
-
|
|
346
318
|
def has_n2v_manipulate(self) -> bool:
|
|
347
319
|
"""
|
|
348
320
|
Check if the transforms contain N2VManipulate.
|
|
349
321
|
|
|
350
|
-
Use `has_transform_list` to check if the transforms are a list.
|
|
351
|
-
|
|
352
322
|
Returns
|
|
353
323
|
-------
|
|
354
324
|
bool
|
|
355
325
|
True if the transforms contain N2VManipulate, False otherwise.
|
|
356
|
-
|
|
357
|
-
Raises
|
|
358
|
-
------
|
|
359
|
-
ValueError
|
|
360
|
-
If the transforms are a Compose object.
|
|
361
326
|
"""
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
)
|
|
367
|
-
else:
|
|
368
|
-
raise ValueError(
|
|
369
|
-
"Checking for N2VManipulate with Compose transforms is not allowed. "
|
|
370
|
-
"Check directly in the Compose."
|
|
371
|
-
)
|
|
327
|
+
return any(
|
|
328
|
+
transform.name == SupportedTransform.N2V_MANIPULATE.value
|
|
329
|
+
for transform in self.transforms
|
|
330
|
+
)
|
|
372
331
|
|
|
373
332
|
def add_n2v_manipulate(self) -> None:
|
|
374
|
-
"""
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
Raises
|
|
380
|
-
------
|
|
381
|
-
ValueError
|
|
382
|
-
If the transforms are a Compose object.
|
|
383
|
-
"""
|
|
384
|
-
if self.has_transform_list():
|
|
385
|
-
if not self.has_n2v_manipulate():
|
|
386
|
-
self.transforms.append(
|
|
387
|
-
N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
|
|
388
|
-
)
|
|
389
|
-
else:
|
|
390
|
-
raise ValueError(
|
|
391
|
-
"Adding N2VManipulate with Compose transforms is not allowed. Add "
|
|
392
|
-
"N2VManipulate directly to the transform in the Compose."
|
|
333
|
+
"""Add N2VManipulate to the transforms."""
|
|
334
|
+
if not self.has_n2v_manipulate():
|
|
335
|
+
self.transforms.append(
|
|
336
|
+
N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
|
|
393
337
|
)
|
|
394
338
|
|
|
395
339
|
def remove_n2v_manipulate(self) -> None:
|
|
396
|
-
"""
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
Use `has_transform_list` to check if the transforms are a list.
|
|
400
|
-
|
|
401
|
-
Raises
|
|
402
|
-
------
|
|
403
|
-
ValueError
|
|
404
|
-
If the transforms are a Compose object.
|
|
405
|
-
"""
|
|
406
|
-
if self.has_transform_list() and self.has_n2v_manipulate():
|
|
340
|
+
"""Remove N2VManipulate from the transforms."""
|
|
341
|
+
if self.has_n2v_manipulate():
|
|
407
342
|
self.transforms.pop(-1)
|
|
408
|
-
else:
|
|
409
|
-
raise ValueError(
|
|
410
|
-
"Removing N2VManipulate with Compose transforms is not allowed. Remove "
|
|
411
|
-
"N2VManipulate directly from the transform in the Compose."
|
|
412
|
-
)
|
|
413
343
|
|
|
414
344
|
def set_mean_and_std(self, mean: float, std: float) -> None:
|
|
415
345
|
"""
|
|
@@ -427,18 +357,6 @@ class DataConfig(BaseModel):
|
|
|
427
357
|
"""
|
|
428
358
|
self._update(mean=mean, std=std)
|
|
429
359
|
|
|
430
|
-
# search in the transforms for Normalize and update parameters
|
|
431
|
-
if self.has_transform_list():
|
|
432
|
-
for transform in self.transforms:
|
|
433
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
434
|
-
transform.mean = mean
|
|
435
|
-
transform.std = std
|
|
436
|
-
else:
|
|
437
|
-
raise ValueError(
|
|
438
|
-
"Setting mean and std with Compose transforms is not allowed. Add "
|
|
439
|
-
"mean and std parameters directly to the transform in the Compose."
|
|
440
|
-
)
|
|
441
|
-
|
|
442
360
|
def set_3D(self, axes: str, patch_size: List[int]) -> None:
|
|
443
361
|
"""
|
|
444
362
|
Set 3D parameters.
|
|
@@ -465,8 +383,6 @@ class DataConfig(BaseModel):
|
|
|
465
383
|
------
|
|
466
384
|
ValueError
|
|
467
385
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
468
|
-
ValueError
|
|
469
|
-
If the transforms are a Compose object.
|
|
470
386
|
"""
|
|
471
387
|
if use_n2v2:
|
|
472
388
|
self.set_N2V2_strategy("median")
|
|
@@ -486,28 +402,19 @@ class DataConfig(BaseModel):
|
|
|
486
402
|
------
|
|
487
403
|
ValueError
|
|
488
404
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
489
|
-
ValueError
|
|
490
|
-
If the transforms are a Compose object.
|
|
491
405
|
"""
|
|
492
|
-
|
|
493
|
-
found_n2v = False
|
|
494
|
-
|
|
495
|
-
for transform in self.transforms:
|
|
496
|
-
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
497
|
-
transform.strategy = strategy
|
|
498
|
-
found_n2v = True
|
|
406
|
+
found_n2v = False
|
|
499
407
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
f"({transforms})."
|
|
505
|
-
)
|
|
408
|
+
for transform in self.transforms:
|
|
409
|
+
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
410
|
+
transform.strategy = strategy
|
|
411
|
+
found_n2v = True
|
|
506
412
|
|
|
507
|
-
|
|
413
|
+
if not found_n2v:
|
|
414
|
+
transforms = [t.name for t in self.transforms]
|
|
508
415
|
raise ValueError(
|
|
509
|
-
"
|
|
510
|
-
"
|
|
416
|
+
f"N2V_Manipulate transform not found in the transforms list "
|
|
417
|
+
f"({transforms})."
|
|
511
418
|
)
|
|
512
419
|
|
|
513
420
|
def set_structN2V_mask(
|
|
@@ -529,27 +436,18 @@ class DataConfig(BaseModel):
|
|
|
529
436
|
------
|
|
530
437
|
ValueError
|
|
531
438
|
If the N2V pixel manipulate transform is not found in the transforms.
|
|
532
|
-
ValueError
|
|
533
|
-
If the transforms are a Compose object.
|
|
534
439
|
"""
|
|
535
|
-
|
|
536
|
-
found_n2v = False
|
|
537
|
-
|
|
538
|
-
for transform in self.transforms:
|
|
539
|
-
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
540
|
-
transform.struct_mask_axis = mask_axis
|
|
541
|
-
transform.struct_mask_span = mask_span
|
|
542
|
-
found_n2v = True
|
|
440
|
+
found_n2v = False
|
|
543
441
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
)
|
|
442
|
+
for transform in self.transforms:
|
|
443
|
+
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
444
|
+
transform.struct_mask_axis = mask_axis
|
|
445
|
+
transform.struct_mask_span = mask_span
|
|
446
|
+
found_n2v = True
|
|
550
447
|
|
|
551
|
-
|
|
448
|
+
if not found_n2v:
|
|
449
|
+
transforms = [t.name for t in self.transforms]
|
|
552
450
|
raise ValueError(
|
|
553
|
-
"
|
|
554
|
-
"
|
|
451
|
+
f"N2V pixel manipulate transform not found in the transforms "
|
|
452
|
+
f"({transforms})."
|
|
555
453
|
)
|
|
@@ -1,18 +1,14 @@
|
|
|
1
1
|
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
|
|
2
3
|
from __future__ import annotations
|
|
3
4
|
|
|
4
5
|
from typing import Any, List, Literal, Optional, Union
|
|
5
6
|
|
|
6
|
-
from albumentations import Compose
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
8
|
from typing_extensions import Self
|
|
9
9
|
|
|
10
|
-
from .support import SupportedTransform
|
|
11
|
-
from .transformations.normalize_model import NormalizeModel
|
|
12
10
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
13
11
|
|
|
14
|
-
TRANSFORMS_UNION = Union[NormalizeModel]
|
|
15
|
-
|
|
16
12
|
|
|
17
13
|
class InferenceConfig(BaseModel):
|
|
18
14
|
"""Configuration class for the prediction model."""
|
|
@@ -33,15 +29,6 @@ class InferenceConfig(BaseModel):
|
|
|
33
29
|
mean: float
|
|
34
30
|
std: float = Field(..., ge=0.0)
|
|
35
31
|
|
|
36
|
-
transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
|
|
37
|
-
default=[
|
|
38
|
-
{
|
|
39
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
40
|
-
},
|
|
41
|
-
],
|
|
42
|
-
validate_default=True,
|
|
43
|
-
)
|
|
44
|
-
|
|
45
32
|
# only default TTAs are supported for now
|
|
46
33
|
tta_transforms: bool = Field(default=True)
|
|
47
34
|
|
|
@@ -51,22 +38,22 @@ class InferenceConfig(BaseModel):
|
|
|
51
38
|
@field_validator("tile_overlap")
|
|
52
39
|
@classmethod
|
|
53
40
|
def all_elements_non_zero_even(
|
|
54
|
-
cls,
|
|
41
|
+
cls, tile_overlap: Optional[Union[List[int]]]
|
|
55
42
|
) -> Optional[Union[List[int]]]:
|
|
56
43
|
"""
|
|
57
|
-
Validate
|
|
44
|
+
Validate tile overlap.
|
|
58
45
|
|
|
59
|
-
|
|
46
|
+
Overlaps must be non-zero, positive and even.
|
|
60
47
|
|
|
61
48
|
Parameters
|
|
62
49
|
----------
|
|
63
|
-
|
|
50
|
+
tile_overlap : Optional[Union[List[int]]]
|
|
64
51
|
Patch size.
|
|
65
52
|
|
|
66
53
|
Returns
|
|
67
54
|
-------
|
|
68
55
|
Optional[Union[List[int]]]
|
|
69
|
-
Validated
|
|
56
|
+
Validated tile overlap.
|
|
70
57
|
|
|
71
58
|
Raises
|
|
72
59
|
------
|
|
@@ -75,8 +62,8 @@ class InferenceConfig(BaseModel):
|
|
|
75
62
|
ValueError
|
|
76
63
|
If the patch size is not even.
|
|
77
64
|
"""
|
|
78
|
-
if
|
|
79
|
-
for dim in
|
|
65
|
+
if tile_overlap is not None:
|
|
66
|
+
for dim in tile_overlap:
|
|
80
67
|
if dim < 1:
|
|
81
68
|
raise ValueError(
|
|
82
69
|
f"Patch size must be non-zero positive (got {dim})."
|
|
@@ -85,7 +72,7 @@ class InferenceConfig(BaseModel):
|
|
|
85
72
|
if dim % 2 != 0:
|
|
86
73
|
raise ValueError(f"Patch size must be even (got {dim}).")
|
|
87
74
|
|
|
88
|
-
return
|
|
75
|
+
return tile_overlap
|
|
89
76
|
|
|
90
77
|
@field_validator("tile_size")
|
|
91
78
|
@classmethod
|
|
@@ -149,39 +136,6 @@ class InferenceConfig(BaseModel):
|
|
|
149
136
|
|
|
150
137
|
return axes
|
|
151
138
|
|
|
152
|
-
@field_validator("transforms")
|
|
153
|
-
@classmethod
|
|
154
|
-
def validate_transforms(
|
|
155
|
-
cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
|
|
156
|
-
) -> Union[List[TRANSFORMS_UNION], Compose]:
|
|
157
|
-
"""
|
|
158
|
-
Validate that transforms do not have N2V pixel manipulate transforms.
|
|
159
|
-
|
|
160
|
-
Parameters
|
|
161
|
-
----------
|
|
162
|
-
transforms : Union[List[TransformModel], Compose]
|
|
163
|
-
Transforms.
|
|
164
|
-
|
|
165
|
-
Returns
|
|
166
|
-
-------
|
|
167
|
-
Union[List[Transformations_Union], Compose]
|
|
168
|
-
Validated transforms.
|
|
169
|
-
|
|
170
|
-
Raises
|
|
171
|
-
------
|
|
172
|
-
ValueError
|
|
173
|
-
If transforms contain N2V pixel manipulate transforms.
|
|
174
|
-
"""
|
|
175
|
-
if not isinstance(transforms, Compose) and transforms is not None:
|
|
176
|
-
for transform in transforms:
|
|
177
|
-
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
178
|
-
raise ValueError(
|
|
179
|
-
"N2V_Manipulate transform is not allowed in "
|
|
180
|
-
"prediction transforms."
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
return transforms
|
|
184
|
-
|
|
185
139
|
@model_validator(mode="after")
|
|
186
140
|
def validate_dimensions(self: Self) -> Self:
|
|
187
141
|
"""
|
|
@@ -235,26 +189,6 @@ class InferenceConfig(BaseModel):
|
|
|
235
189
|
|
|
236
190
|
return self
|
|
237
191
|
|
|
238
|
-
@model_validator(mode="after")
|
|
239
|
-
def add_std_and_mean_to_normalize(self: Self) -> Self:
|
|
240
|
-
"""
|
|
241
|
-
Add mean and std to the Normalize transform if it is present.
|
|
242
|
-
|
|
243
|
-
Returns
|
|
244
|
-
-------
|
|
245
|
-
Self
|
|
246
|
-
Inference model with mean and std added to the Normalize transform.
|
|
247
|
-
"""
|
|
248
|
-
if self.mean is not None or self.std is not None:
|
|
249
|
-
# search in the transforms for Normalize and update parameters
|
|
250
|
-
if not isinstance(self.transforms, Compose):
|
|
251
|
-
for transform in self.transforms:
|
|
252
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
253
|
-
transform.mean = self.mean
|
|
254
|
-
transform.std = self.std
|
|
255
|
-
|
|
256
|
-
return self
|
|
257
|
-
|
|
258
192
|
def _update(self, **kwargs: Any) -> None:
|
|
259
193
|
"""
|
|
260
194
|
Update multiple arguments at once.
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Optimizers and schedulers Pydantic models."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from typing import Dict, Literal
|
|
@@ -19,8 +21,7 @@ from .support import SupportedOptimizer
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class OptimizerModel(BaseModel):
|
|
22
|
-
"""
|
|
23
|
-
Torch optimizer.
|
|
24
|
+
"""Torch optimizer Pydantic model.
|
|
24
25
|
|
|
25
26
|
Only parameters supported by the corresponding torch optimizer will be taken
|
|
26
27
|
into account. For more details, check:
|
|
@@ -115,8 +116,7 @@ class OptimizerModel(BaseModel):
|
|
|
115
116
|
|
|
116
117
|
|
|
117
118
|
class LrSchedulerModel(BaseModel):
|
|
118
|
-
"""
|
|
119
|
-
Torch learning rate scheduler.
|
|
119
|
+
"""Torch learning rate scheduler Pydantic model.
|
|
120
120
|
|
|
121
121
|
Only parameters supported by the corresponding torch lr scheduler will be taken
|
|
122
122
|
into account. For more details, check:
|
|
@@ -14,7 +14,6 @@ __all__ = [
|
|
|
14
14
|
"SupportedPixelManipulation",
|
|
15
15
|
"SupportedTransform",
|
|
16
16
|
"SupportedData",
|
|
17
|
-
"SupportedExtractionStrategy",
|
|
18
17
|
"SupportedStructAxis",
|
|
19
18
|
"SupportedLogger",
|
|
20
19
|
]
|
|
@@ -24,7 +23,6 @@ from .supported_activations import SupportedActivation
|
|
|
24
23
|
from .supported_algorithms import SupportedAlgorithm
|
|
25
24
|
from .supported_architectures import SupportedArchitecture
|
|
26
25
|
from .supported_data import SupportedData
|
|
27
|
-
from .supported_extraction_strategies import SupportedExtractionStrategy
|
|
28
26
|
from .supported_loggers import SupportedLogger
|
|
29
27
|
from .supported_losses import SupportedLoss
|
|
30
28
|
from .supported_optimizers import SupportedOptimizer, SupportedScheduler
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Algorithms supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from careamics.utils import BaseEnum
|
|
@@ -10,9 +12,9 @@ class SupportedAlgorithm(str, BaseEnum):
|
|
|
10
12
|
"""
|
|
11
13
|
|
|
12
14
|
N2V = "n2v"
|
|
13
|
-
CUSTOM = "custom"
|
|
14
15
|
CARE = "care"
|
|
15
16
|
N2N = "n2n"
|
|
17
|
+
CUSTOM = "custom"
|
|
16
18
|
# PN2V = "pn2v"
|
|
17
19
|
# HDN = "hdn"
|
|
18
20
|
# SEG = "segmentation"
|
|
@@ -1,15 +1,15 @@
|
|
|
1
|
+
"""Pixel manipulation methods supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from careamics.utils import BaseEnum
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class SupportedPixelManipulation(str, BaseEnum):
|
|
5
|
-
"""
|
|
7
|
+
"""Supported Noise2Void pixel manipulations.
|
|
6
8
|
|
|
7
9
|
- Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor
|
|
8
10
|
pixel value.
|
|
9
11
|
- Median: Replace masked pixel value by the mean of the neighborhood.
|
|
10
12
|
"""
|
|
11
13
|
|
|
12
|
-
# TODO docs
|
|
13
|
-
|
|
14
14
|
UNIFORM = "uniform"
|
|
15
15
|
MEDIAN = "median"
|