careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff shows the changes between publicly available package versions released to a supported registry, as they appear in the public registry; it is provided for informational purposes only.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +38 -0
- careamics/config/algorithms/n2n_algorithm_model.py +30 -0
- careamics/config/algorithms/n2v_algorithm_model.py +29 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +91 -186
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +5 -4
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +3 -3
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +36 -45
- careamics/losses/__init__.py +3 -3
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -4
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
- careamics-0.0.7.dist-info/RECORD +178 -0
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.5.dist-info/RECORD +0 -171
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
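The file list above reflects a reorganisation of the careamics.config package: algorithm models move into careamics/config/algorithms/, data models into careamics/config/data/, per-algorithm configurations (care_configuration.py, n2n_configuration.py, n2v_configuration.py) are added, and configuration_factory.py is renamed to configuration_factories.py. A minimal sketch of the new import paths, taken directly from the import statements visible in the hunks below (whether these names are also re-exported at the careamics.config package level is not shown in this diff):

from careamics.config.algorithms import N2NAlgorithm
from careamics.config.configuration import Configuration
from careamics.config.data import DataConfig
from careamics.config.data.data_model import GeneralDataConfig
from careamics.config.transformations import N2V_TRANSFORMS_UNION, N2VManipulateModel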
careamics/config/{data_model.py → data/data_model.py}

@@ -2,8 +2,10 @@

 from __future__ import annotations

+from collections.abc import Sequence
 from pprint import pformat
-from typing import Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Optional, Union
+from warnings import warn

 import numpy as np
 from numpy.typing import NDArray
@@ -15,11 +17,10 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from typing_extensions import
+from typing_extensions import Self

-from
-from
-from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
+from ..transformations import N2V_TRANSFORMS_UNION, XYFlipModel, XYRandomRotate90Model
+from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2


 def np_float_to_scientific_str(x: float) -> str:
@@ -45,47 +46,8 @@ Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type
 """Annotated float type, used to serialize floats to strings."""


-class DataConfig(BaseModel):
-    """
-    Data configuration.
-
-    If std is specified, mean must be specified as well. Note that setting the std first
-    and then the mean (if they were both `None` before) will raise a validation error.
-    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
-    to be lists of floats, one for each channel. For supervised tasks, the mean and std
-    of the target could be different from the input data.
-
-    All supported transforms are defined in the SupportedTransform enum.
-
-    Examples
-    --------
-    Minimum example:
-
-    >>> data = DataConfig(
-    ...     data_type="array", # defined in SupportedData
-    ...     patch_size=[128, 128],
-    ...     batch_size=4,
-    ...     axes="YX"
-    ... )
-
-    To change the image_means and image_stds of the data:
-    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
-
-    One can pass also a list of transformations, by keyword, using the
-    SupportedTransform value:
-    >>> from careamics.config.support import SupportedTransform
-    >>> data = DataConfig(
-    ...     data_type="tiff",
-    ...     patch_size=[128, 128],
-    ...     batch_size=4,
-    ...     axes="YX",
-    ...     transforms=[
-    ...         {
-    ...             "name": "XYFlip",
-    ...         }
-    ...     ]
-    ... )
-    """
+class GeneralDataConfig(BaseModel):
+    """General data configuration."""

     # Pydantic class configuration
     model_config = ConfigDict(
@@ -126,25 +88,26 @@ class DataConfig(BaseModel):
     """Standard deviations of the target data across channels, used for
     normalization."""

-
+    # defining as Sequence allows assigning subclasses of TransformModel without mypy
+    # complaining, this is important for instance to differentiate N2VDataConfig and
+    # DataConfig
+    transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
         default=[
-
-
-            },
-            {
-                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
-            },
-            {
-                "name": SupportedTransform.N2V_MANIPULATE.value,
-            },
+            XYFlipModel(),
+            XYRandomRotate90Model(),
         ],
         validate_default=True,
     )
     """List of transformations to apply to the data, available transforms are defined
-    in SupportedTransform.
+    in SupportedTransform."""
+
+    train_dataloader_params: dict[str, Any] = Field(
+        default={"shuffle": True}, validate_default=True
+    )
+    """Dictionary of PyTorch training dataloader parameters."""

-
-    """Dictionary of PyTorch dataloader parameters."""
+    val_dataloader_params: dict[str, Any] = Field(default={})
+    """Dictionary of PyTorch validation dataloader parameters."""

     @field_validator("patch_size")
     @classmethod
@@ -210,47 +173,44 @@ class DataConfig(BaseModel):

         return axes

-    @field_validator("
+    @field_validator("train_dataloader_params")
     @classmethod
-    def
-        cls,
-    ) ->
+    def shuffle_train_dataloader(
+        cls, train_dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
         """
-        Validate
+        Validate that "shuffle" is included in the training dataloader params.
+
+        A warning will be raised if `shuffle=False`.

         Parameters
         ----------
-
-
+        train_dataloader_params : dict of {str: Any}
+            The training dataloader parameters.

         Returns
         -------
-
-
+        dict of {str: Any}
+            The validated training dataloader parameters.

         Raises
         ------
         ValueError
-            If
+            If "shuffle" is not included in the training dataloader params.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
-        transform = transforms.pop(index)
-        transforms.append(transform)
-
-        return transforms
+        if "shuffle" not in train_dataloader_params:
+            raise ValueError(
+                "Value for 'shuffle' was not included in the `train_dataloader_params`."
+            )
+        elif ("shuffle" in train_dataloader_params) and (
+            not train_dataloader_params["shuffle"]
+        ):
+            warn(
+                "Dataloader parameters include `shuffle=False`, this will be passed to "
+                "the training dataloader and may result in bad results.",
+                stacklevel=1,
+            )
+        return train_dataloader_params

     @model_validator(mode="after")
     def std_only_with_mean(self: Self) -> Self:
@@ -350,32 +310,6 @@ class DataConfig(BaseModel):
         self.__dict__.update(kwargs)
         self.__class__.model_validate(self.__dict__)

-    def has_n2v_manipulate(self) -> bool:
-        """
-        Check if the transforms contain N2VManipulate.
-
-        Returns
-        -------
-        bool
-            True if the transforms contain N2VManipulate, False otherwise.
-        """
-        return any(
-            transform.name == SupportedTransform.N2V_MANIPULATE.value
-            for transform in self.transforms
-        )
-
-    def add_n2v_manipulate(self) -> None:
-        """Add N2VManipulate to the transforms."""
-        if not self.has_n2v_manipulate():
-            self.transforms.append(
-                N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
-            )
-
-    def remove_n2v_manipulate(self) -> None:
-        """Remove N2VManipulate from the transforms."""
-        if self.has_n2v_manipulate():
-            self.transforms.pop(-1)
-
     def set_means_and_stds(
         self,
         image_means: Union[NDArray, tuple, list, None],
@@ -430,84 +364,55 @@ class DataConfig(BaseModel):
         """
         self._update(axes=axes, patch_size=patch_size)

-    def set_N2V2(self, use_n2v2: bool) -> None:
-        """
-        Set N2V2.
-
-        Parameters
-        ----------
-        use_n2v2 : bool
-            Whether to use N2V2.

-
-
-
-            If the N2V pixel manipulate transform is not found in the transforms.
-        """
-        if use_n2v2:
-            self.set_N2V2_strategy("median")
-        else:
-            self.set_N2V2_strategy("uniform")
-
-    def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
-        """
-        Set N2V2 strategy.
-
-        Parameters
-        ----------
-        strategy : Literal["uniform", "median"]
-            Strategy to use for N2V2.
-
-        Raises
-        ------
-        ValueError
-            If the N2V pixel manipulate transform is not found in the transforms.
-        """
-        found_n2v = False
-
-        for transform in self.transforms:
-            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
-                transform.strategy = strategy
-                found_n2v = True
+class DataConfig(GeneralDataConfig):
+    """
+    Data configuration.

-
-
-
-
-
-            )
+    If std is specified, mean must be specified as well. Note that setting the std first
+    and then the mean (if they were both `None` before) will raise a validation error.
+    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
+    to be lists of floats, one for each channel. For supervised tasks, the mean and std
+    of the target could be different from the input data.

-
-        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
-    ) -> None:
-        """
-        Set structN2V mask parameters.
+    All supported transforms are defined in the SupportedTransform enum.

-
+    Examples
+    --------
+    Minimum example:

-
-
-
-
-
-
+    >>> data = DataConfig(
+    ...     data_type="array", # defined in SupportedData
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX"
+    ... )

-
-
-        ValueError
-            If the N2V pixel manipulate transform is not found in the transforms.
-        """
-        found_n2v = False
+    To change the image_means and image_stds of the data:
+    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

-
-
-
-
-
+    One can pass also a list of transformations, by keyword, using the
+    SupportedTransform value:
+    >>> from careamics.config.support import SupportedTransform
+    >>> data = DataConfig(
+    ...     data_type="tiff",
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX",
+    ...     transforms=[
+    ...         {
+    ...             "name": "XYFlip",
+    ...         }
+    ...     ]
+    ... )
+    """

-
-
-
-
-
-
+    transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
+        default=[
+            XYFlipModel(),
+            XYRandomRotate90Model(),
+        ],
+        validate_default=True,
+    )
+    """List of transformations to apply to the data, available transforms are defined
+    in SupportedTransform. This excludes N2V specific transformations."""
careamics/config/data/n2v_data_model.py (new file)

@@ -0,0 +1,193 @@
+"""Noise2Void specific data configuration model."""
+
+from collections.abc import Sequence
+from typing import Literal
+
+from pydantic import Field, field_validator
+
+from careamics.config.data.data_model import GeneralDataConfig
+from careamics.config.support import SupportedTransform
+from careamics.config.transformations import (
+    N2V_TRANSFORMS_UNION,
+    N2VManipulateModel,
+    XYFlipModel,
+    XYRandomRotate90Model,
+)
+
+
+class N2VDataConfig(GeneralDataConfig):
+    """N2V specific data configuration model."""
+
+    transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
+        default=[XYFlipModel(), XYRandomRotate90Model(), N2VManipulateModel()],
+        validate_default=True,
+    )
+    """N2V compatible transforms. N2VManpulate should be the last transform."""
+
+    @field_validator("transforms")
+    @classmethod
+    def validate_transforms(
+        cls, transforms: list[N2V_TRANSFORMS_UNION]
+    ) -> list[N2V_TRANSFORMS_UNION]:
+        """
+        Validate N2VManipulate transform position in the transform list.
+
+        Parameters
+        ----------
+        transforms : list of transforms compatible with N2V
+            Transforms.
+
+        Returns
+        -------
+        list of transforms
+            Validated transforms.
+
+        Raises
+        ------
+        ValueError
+            If multiple instances of N2VManipulate are found or if it is not the last
+            transform.
+        """
+        transform_list = [t.name for t in transforms]
+
+        if SupportedTransform.N2V_MANIPULATE in transform_list:
+            # multiple N2V_MANIPULATE
+            if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
+                raise ValueError(
+                    f"Multiple instances of "
+                    f"{SupportedTransform.N2V_MANIPULATE} transforms "
+                    f"are not allowed."
+                )
+
+            # N2V_MANIPULATE not the last transform
+            elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
+                raise ValueError(
+                    f"{SupportedTransform.N2V_MANIPULATE} transform "
+                    f"should be the last transform."
+                )
+
+        else:
+            raise ValueError(
+                f"{SupportedTransform.N2V_MANIPULATE} transform "
+                f"is required for N2V training."
+            )
+
+        return transforms
+
+    def set_n2v2(self, use_n2v2: bool) -> None:
+        """
+        Set the N2V transform to the N2V2 version.
+
+        Parameters
+        ----------
+        use_n2v2 : bool
+            Whether to use N2V2.
+
+        Raises
+        ------
+        ValueError
+            If the N2V pixel manipulate transform is not found in the transforms.
+        """
+        if use_n2v2:
+            self.set_masking_strategy("median")
+        else:
+            self.set_masking_strategy("uniform")
+
+    def set_masking_strategy(self, strategy: Literal["uniform", "median"]) -> None:
+        """
+        Set masking strategy.
+
+        Parameters
+        ----------
+        strategy : "uniform" or "median"
+            Strategy to use for N2V2.
+
+        Raises
+        ------
+        ValueError
+            If the N2V pixel manipulate transform is not found in the transforms.
+        """
+        found_n2v = False
+
+        for transform in self.transforms:
+            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
+                transform.strategy = strategy
+                found_n2v = True
+
+        if not found_n2v:
+            transforms = [t.name for t in self.transforms]
+            raise ValueError(
+                f"N2V_Manipulate transform not found in the transforms list "
+                f"({transforms})."
+            )
+
+    def get_masking_strategy(self) -> Literal["uniform", "median"]:
+        """
+        Get N2V2 strategy.
+
+        Returns
+        -------
+        "uniform" or "median"
+            Strategy used for N2V2.
+        """
+        for transform in self.transforms:
+            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
+                return transform.strategy
+
+        raise ValueError(
+            f"{SupportedTransform.N2V_MANIPULATE} transform "
+            f"is required for N2V training."
+        )
+
+    def set_structN2V_mask(
+        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
+    ) -> None:
+        """
+        Set structN2V mask parameters.
+
+        Setting `mask_axis` to `none` will disable structN2V.
+
+        Parameters
+        ----------
+        mask_axis : Literal["horizontal", "vertical", "none"]
+            Axis along which to apply the mask. `none` will disable structN2V.
+        mask_span : int
+            Total span of the mask in pixels.
+
+        Raises
+        ------
+        ValueError
+            If the N2V pixel manipulate transform is not found in the transforms.
+        """
+        found_n2v = False
+
+        for transform in self.transforms:
+            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
+                transform.struct_mask_axis = mask_axis
+                transform.struct_mask_span = mask_span
+                found_n2v = True
+
+        if not found_n2v:
+            transforms = [t.name for t in self.transforms]
+            raise ValueError(
+                f"N2V pixel manipulate transform not found in the transforms "
+                f"({transforms})."
+            )
+
+    def is_using_struct_n2v(self) -> bool:
+        """
+        Check if structN2V is enabled.
+
+        Returns
+        -------
+        bool
+            Whether structN2V is enabled or not.
+        """
+        for transform in self.transforms:
+            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
+                return transform.struct_mask_axis != "none"
+
+        raise ValueError(
+            f"N2V pixel manipulate transform not found in the transforms "
+            f"({self.transforms})."
+        )
careamics/config/likelihood_model.py

@@ -1,11 +1,10 @@
 """Likelihood model."""

-from typing import Literal, Optional, Union
+from typing import Annotated, Literal, Optional, Union

 import numpy as np
 import torch
 from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
-from typing_extensions import Annotated

 from careamics.models.lvae.noise_models import (
     GaussianMixtureNoiseModel,
careamics/config/n2n_configuration.py (new file)

@@ -0,0 +1,101 @@
+"""N2N configuration."""
+
+from bioimageio.spec.generic.v0_3 import CiteEntry
+
+from careamics.config.algorithms import N2NAlgorithm
+from careamics.config.configuration import Configuration
+from careamics.config.data import DataConfig
+
+N2N = "Noise2Noise"
+
+N2N_DESCRIPTION = (
+    "Noise2Noise is a deep-learning-based algorithm that uses a U-Net "
+    "architecture to restore images. Noise2Noise is a self-supervised "
+    "algorithm that requires only noisy images to train the network. "
+    "The algorithm learns to predict the clean image from the noisy "
+    "image. Noise2Noise is particularly useful when clean images are "
+    "not available for training."
+)
+
+N2N_REF = CiteEntry(
+    text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., "
+    'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration '
+    'without clean data". arXiv preprint arXiv:1803.04189.',
+    doi="10.48550/arXiv.1803.04189",
+)
+
+
+class N2NConfiguration(Configuration):
+    """Noise2Noise configuration."""
+
+    algorithm_config: N2NAlgorithm
+    """Algorithm configuration."""
+
+    data_config: DataConfig
+    """Data configuration."""
+
+    def get_algorithm_friendly_name(self) -> str:
+        """
+        Get the algorithm friendly name.
+
+        Returns
+        -------
+        str
+            Friendly name of the algorithm.
+        """
+        return N2N
+
+    def get_algorithm_keywords(self) -> list[str]:
+        """
+        Get algorithm keywords.
+
+        Returns
+        -------
+        list[str]
+            List of keywords.
+        """
+        return [
+            "restoration",
+            "UNet",
+            "3D" if "Z" in self.data_config.axes else "2D",
+            "CAREamics",
+            "pytorch",
+            N2N,
+        ]
+
+    def get_algorithm_references(self) -> str:
+        """
+        Get the algorithm references.
+
+        This is used to generate the README of the BioImage Model Zoo export.
+
+        Returns
+        -------
+        str
+            Algorithm references.
+        """
+        return N2N_REF.text + " doi: " + N2N_REF.doi
+
+    def get_algorithm_citations(self) -> list[CiteEntry]:
+        """
+        Return a list of citation entries of the current algorithm.
+
+        This is used to generate the model description for the BioImage Model Zoo.
+
+        Returns
+        -------
+        List[CiteEntry]
+            List of citation entries.
+        """
+        return [N2N_REF]
+
+    def get_algorithm_description(self) -> str:
+        """
+        Get the algorithm description.
+
+        Returns
+        -------
+        str
+            Algorithm description.
+        """
+        return N2N_DESCRIPTION