careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Algorithm configuration."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from pprint import pformat
|
|
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
|
|
|
17
19
|
training algorithm: which algorithm, loss function, model architecture, optimizer,
|
|
18
20
|
and learning rate scheduler to use.
|
|
19
21
|
|
|
20
|
-
Currently, we only support N2V and custom
|
|
21
|
-
compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
|
|
22
|
-
you to register your own architecture and select it using its name as
|
|
22
|
+
Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
|
|
23
|
+
only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
|
|
24
|
+
allows you to register your own architecture and select it using its name as
|
|
23
25
|
`name` in the custom pydantic model.
|
|
24
26
|
|
|
25
27
|
Attributes
|
|
@@ -132,21 +134,6 @@ class AlgorithmConfig(BaseModel):
|
|
|
132
134
|
"sure that `in_channels` and `num_classes` are the same."
|
|
133
135
|
)
|
|
134
136
|
|
|
135
|
-
# N2N
|
|
136
|
-
if self.algorithm == "n2n":
|
|
137
|
-
# n2n is only compatible with the UNet model
|
|
138
|
-
if not isinstance(self.model, UNetModel):
|
|
139
|
-
raise ValueError(
|
|
140
|
-
f"Model for algorithm {self.algorithm} must be a `UNetModel`."
|
|
141
|
-
)
|
|
142
|
-
|
|
143
|
-
# n2n requires the number of input and output channels to be the same
|
|
144
|
-
if self.model.in_channels != self.model.num_classes:
|
|
145
|
-
raise ValueError(
|
|
146
|
-
"N2N requires the same number of input and output channels. Make "
|
|
147
|
-
"sure that `in_channels` and `num_classes` are the same."
|
|
148
|
-
)
|
|
149
|
-
|
|
150
137
|
if self.algorithm == "care" or self.algorithm == "n2n":
|
|
151
138
|
if self.loss == "n2v":
|
|
152
139
|
raise ValueError("Supervised algorithms do not support loss `n2v`.")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Base model for the various CAREamics architectures."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict
|
|
2
4
|
|
|
3
5
|
from pydantic import BaseModel
|
|
@@ -16,6 +18,11 @@ class ArchitectureModel(BaseModel):
|
|
|
16
18
|
"""
|
|
17
19
|
Dump the model as a dictionary, ignoring the architecture keyword.
|
|
18
20
|
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
**kwargs : Any
|
|
24
|
+
Additional keyword arguments from Pydantic BaseModel model_dump method.
|
|
25
|
+
|
|
19
26
|
Returns
|
|
20
27
|
-------
|
|
21
28
|
dict[str, Any]
|
|
@@ -1,7 +1,9 @@
|
|
|
1
|
+
"""Custom architecture Pydantic model."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from pprint import pformat
|
|
4
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Literal
|
|
5
7
|
|
|
6
8
|
from pydantic import ConfigDict, field_validator, model_validator
|
|
7
9
|
from torch.nn import Module
|
|
@@ -84,6 +86,11 @@ class CustomModel(ArchitectureModel):
|
|
|
84
86
|
value : str
|
|
85
87
|
Name of the custom model as registered using the `@register_model`
|
|
86
88
|
decorator.
|
|
89
|
+
|
|
90
|
+
Returns
|
|
91
|
+
-------
|
|
92
|
+
str
|
|
93
|
+
The custom model name.
|
|
87
94
|
"""
|
|
88
95
|
# delegate error to get_custom_model
|
|
89
96
|
model = get_custom_model(value)
|
|
@@ -129,17 +136,17 @@ class CustomModel(ArchitectureModel):
|
|
|
129
136
|
"""
|
|
130
137
|
return pformat(self.model_dump())
|
|
131
138
|
|
|
132
|
-
def model_dump(self, **kwargs: Any) ->
|
|
139
|
+
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
133
140
|
"""Dump the model configuration.
|
|
134
141
|
|
|
135
142
|
Parameters
|
|
136
143
|
----------
|
|
137
|
-
kwargs : Any
|
|
144
|
+
**kwargs : Any
|
|
138
145
|
Additional keyword arguments from Pydantic BaseModel model_dump method.
|
|
139
146
|
|
|
140
147
|
Returns
|
|
141
148
|
-------
|
|
142
|
-
|
|
149
|
+
dict[str, Any]
|
|
143
150
|
Model configuration.
|
|
144
151
|
"""
|
|
145
152
|
model_dict = super().model_dump()
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Custom model registration utilities."""
|
|
2
|
+
|
|
1
3
|
from typing import Callable
|
|
2
4
|
|
|
3
5
|
from torch.nn import Module
|
|
@@ -53,7 +55,7 @@ def register_model(name: str) -> Callable:
|
|
|
53
55
|
Parameters
|
|
54
56
|
----------
|
|
55
57
|
model : Module
|
|
56
|
-
Module class to register
|
|
58
|
+
Module class to register.
|
|
57
59
|
|
|
58
60
|
Returns
|
|
59
61
|
-------
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Callback Pydantic models."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
@@ -13,13 +13,7 @@ from pydantic import (
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class CheckpointModel(BaseModel):
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
Parameters
|
|
19
|
-
----------
|
|
20
|
-
BaseModel : _type_
|
|
21
|
-
_description_
|
|
22
|
-
"""
|
|
16
|
+
"""Checkpoint saving callback Pydantic model."""
|
|
23
17
|
|
|
24
18
|
model_config = ConfigDict(
|
|
25
19
|
validate_assignment=True,
|
|
@@ -46,13 +40,7 @@ class CheckpointModel(BaseModel):
|
|
|
46
40
|
|
|
47
41
|
|
|
48
42
|
class EarlyStoppingModel(BaseModel):
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
Parameters
|
|
52
|
-
----------
|
|
53
|
-
BaseModel : _type_
|
|
54
|
-
_description_
|
|
55
|
-
"""
|
|
43
|
+
"""Early stopping callback Pydantic model."""
|
|
56
44
|
|
|
57
45
|
model_config = ConfigDict(
|
|
58
46
|
validate_assignment=True,
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Example of configurations."""
|
|
2
|
+
|
|
1
3
|
from .algorithm_model import AlgorithmConfig
|
|
2
4
|
from .architectures import UNetModel
|
|
3
5
|
from .configuration_model import Configuration
|
|
@@ -19,7 +21,7 @@ from .training_model import TrainingConfig
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def full_configuration_example() -> Configuration:
|
|
22
|
-
"""
|
|
24
|
+
"""Return a dictionnary representing a full configuration example.
|
|
23
25
|
|
|
24
26
|
Returns
|
|
25
27
|
-------
|
|
@@ -53,10 +55,7 @@ def full_configuration_example() -> Configuration:
|
|
|
53
55
|
axes="YX",
|
|
54
56
|
transforms=[
|
|
55
57
|
{
|
|
56
|
-
"name": SupportedTransform.
|
|
57
|
-
},
|
|
58
|
-
{
|
|
59
|
-
"name": SupportedTransform.NDFLIP.value,
|
|
58
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
60
59
|
},
|
|
61
60
|
{
|
|
62
61
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
from .algorithm_model import AlgorithmConfig
|
|
6
6
|
from .architectures import UNetModel
|
|
@@ -108,21 +108,14 @@ def _create_supervised_configuration(
|
|
|
108
108
|
if use_augmentations:
|
|
109
109
|
transforms: List[Dict[str, Any]] = [
|
|
110
110
|
{
|
|
111
|
-
"name": SupportedTransform.
|
|
112
|
-
},
|
|
113
|
-
{
|
|
114
|
-
"name": SupportedTransform.NDFLIP.value,
|
|
111
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
115
112
|
},
|
|
116
113
|
{
|
|
117
114
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
118
115
|
},
|
|
119
116
|
]
|
|
120
117
|
else:
|
|
121
|
-
transforms = [
|
|
122
|
-
{
|
|
123
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
124
|
-
},
|
|
125
|
-
]
|
|
118
|
+
transforms = []
|
|
126
119
|
|
|
127
120
|
# data model
|
|
128
121
|
data = DataConfig(
|
|
@@ -250,7 +243,8 @@ def create_n2n_configuration(
|
|
|
250
243
|
use_augmentations: bool = True,
|
|
251
244
|
independent_channels: bool = False,
|
|
252
245
|
loss: Literal["mae", "mse"] = "mae",
|
|
253
|
-
|
|
246
|
+
n_channels_in: int = 1,
|
|
247
|
+
n_channels_out: int = -1,
|
|
254
248
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
255
249
|
model_kwargs: Optional[dict] = None,
|
|
256
250
|
) -> Configuration:
|
|
@@ -260,10 +254,13 @@ def create_n2n_configuration(
|
|
|
260
254
|
If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
|
|
261
255
|
2.
|
|
262
256
|
|
|
263
|
-
If "C" is present in `axes`, then you need to set `
|
|
257
|
+
If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
|
|
264
258
|
channels. Likewise, if you set the number of channels, then "C" must be present in
|
|
265
259
|
`axes`.
|
|
266
260
|
|
|
261
|
+
To set the number of output channels, use the `n_channels_out` parameter. If it is
|
|
262
|
+
not specified, it will be assumed to be equal to `n_channels_in`.
|
|
263
|
+
|
|
267
264
|
By default, all channels are trained together. To train all channels independently,
|
|
268
265
|
set `independent_channels` to True.
|
|
269
266
|
|
|
@@ -290,8 +287,10 @@ def create_n2n_configuration(
|
|
|
290
287
|
Whether to train all channels independently, by default False.
|
|
291
288
|
loss : Literal["mae", "mse"], optional
|
|
292
289
|
Loss function to use, by default "mae".
|
|
293
|
-
|
|
294
|
-
Number of channels
|
|
290
|
+
n_channels_in : int, optional
|
|
291
|
+
Number of channels in, by default 1.
|
|
292
|
+
n_channels_out : int, optional
|
|
293
|
+
Number of channels out, by default -1.
|
|
295
294
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
296
295
|
Logger to use, by default "none".
|
|
297
296
|
model_kwargs : dict, optional
|
|
@@ -302,6 +301,9 @@ def create_n2n_configuration(
|
|
|
302
301
|
Configuration
|
|
303
302
|
Configuration for training Noise2Noise.
|
|
304
303
|
"""
|
|
304
|
+
if n_channels_out == -1:
|
|
305
|
+
n_channels_out = n_channels_in
|
|
306
|
+
|
|
305
307
|
return _create_supervised_configuration(
|
|
306
308
|
algorithm="n2n",
|
|
307
309
|
experiment_name=experiment_name,
|
|
@@ -313,8 +315,8 @@ def create_n2n_configuration(
|
|
|
313
315
|
use_augmentations=use_augmentations,
|
|
314
316
|
independent_channels=independent_channels,
|
|
315
317
|
loss=loss,
|
|
316
|
-
n_channels_in=
|
|
317
|
-
n_channels_out=
|
|
318
|
+
n_channels_in=n_channels_in,
|
|
319
|
+
n_channels_out=n_channels_out,
|
|
318
320
|
logger=logger,
|
|
319
321
|
model_kwargs=model_kwargs,
|
|
320
322
|
)
|
|
@@ -523,21 +525,14 @@ def create_n2v_configuration(
|
|
|
523
525
|
if use_augmentations:
|
|
524
526
|
transforms: List[Dict[str, Any]] = [
|
|
525
527
|
{
|
|
526
|
-
"name": SupportedTransform.
|
|
527
|
-
},
|
|
528
|
-
{
|
|
529
|
-
"name": SupportedTransform.NDFLIP.value,
|
|
528
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
530
529
|
},
|
|
531
530
|
{
|
|
532
531
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
533
532
|
},
|
|
534
533
|
]
|
|
535
534
|
else:
|
|
536
|
-
transforms = [
|
|
537
|
-
{
|
|
538
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
539
|
-
},
|
|
540
|
-
]
|
|
535
|
+
transforms = []
|
|
541
536
|
|
|
542
537
|
# n2v2 and structn2v
|
|
543
538
|
nv2_transform = {
|
|
@@ -587,7 +582,6 @@ def create_inference_configuration(
|
|
|
587
582
|
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
588
583
|
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
589
584
|
axes: Optional[str] = None,
|
|
590
|
-
transforms: Optional[Union[List[Dict[str, Any]]]] = None,
|
|
591
585
|
tta_transforms: bool = True,
|
|
592
586
|
batch_size: Optional[int] = 1,
|
|
593
587
|
) -> InferenceConfig:
|
|
@@ -595,7 +589,7 @@ def create_inference_configuration(
|
|
|
595
589
|
Create a configuration for inference with N2V.
|
|
596
590
|
|
|
597
591
|
If not provided, `data_type` and `axes` are taken from the training
|
|
598
|
-
configuration.
|
|
592
|
+
configuration.
|
|
599
593
|
|
|
600
594
|
Parameters
|
|
601
595
|
----------
|
|
@@ -609,8 +603,6 @@ def create_inference_configuration(
|
|
|
609
603
|
Type of the data, by default "tiff".
|
|
610
604
|
axes : str, optional
|
|
611
605
|
Axes of the data, by default "YX".
|
|
612
|
-
transforms : List[Dict[str, Any]], optional
|
|
613
|
-
Transformations to apply to the data, by default None.
|
|
614
606
|
tta_transforms : bool, optional
|
|
615
607
|
Whether to apply test-time augmentations, by default True.
|
|
616
608
|
batch_size : int, optional
|
|
@@ -621,17 +613,12 @@ def create_inference_configuration(
|
|
|
621
613
|
InferenceConfiguration
|
|
622
614
|
Configuration used to configure CAREamicsPredictData.
|
|
623
615
|
"""
|
|
624
|
-
if
|
|
616
|
+
if (
|
|
617
|
+
configuration.data_config.image_means is None
|
|
618
|
+
or configuration.data_config.image_stds is None
|
|
619
|
+
):
|
|
625
620
|
raise ValueError("Mean and std must be provided in the configuration.")
|
|
626
621
|
|
|
627
|
-
# minimum transform
|
|
628
|
-
if transforms is None:
|
|
629
|
-
transforms = [
|
|
630
|
-
{
|
|
631
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
632
|
-
},
|
|
633
|
-
]
|
|
634
|
-
|
|
635
622
|
# tile size for UNets
|
|
636
623
|
if tile_size is not None:
|
|
637
624
|
model = configuration.algorithm_config.model
|
|
@@ -659,9 +646,8 @@ def create_inference_configuration(
|
|
|
659
646
|
tile_size=tile_size,
|
|
660
647
|
tile_overlap=tile_overlap,
|
|
661
648
|
axes=axes or configuration.data_config.axes,
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
transforms=transforms,
|
|
649
|
+
image_means=configuration.data_config.image_means,
|
|
650
|
+
image_stds=configuration.data_config.image_stds,
|
|
665
651
|
tta_transforms=tta_transforms,
|
|
666
652
|
batch_size=batch_size,
|
|
667
653
|
)
|
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
import re
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from pprint import pformat
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import Literal, Union
|
|
9
9
|
|
|
10
10
|
import yaml
|
|
11
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
@@ -269,7 +269,7 @@ class Configuration(BaseModel):
|
|
|
269
269
|
"""
|
|
270
270
|
return pformat(self.model_dump())
|
|
271
271
|
|
|
272
|
-
def set_3D(self, is_3D: bool, axes: str, patch_size:
|
|
272
|
+
def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
|
|
273
273
|
"""
|
|
274
274
|
Set 3D flag and axes.
|
|
275
275
|
|
|
@@ -279,7 +279,7 @@ class Configuration(BaseModel):
|
|
|
279
279
|
Whether the algorithm is 3D or not.
|
|
280
280
|
axes : str
|
|
281
281
|
Axes of the data.
|
|
282
|
-
patch_size :
|
|
282
|
+
patch_size : list[int]
|
|
283
283
|
Patch size.
|
|
284
284
|
"""
|
|
285
285
|
# set the flag and axes (this will not trigger validation at the config level)
|
|
@@ -389,7 +389,7 @@ class Configuration(BaseModel):
|
|
|
389
389
|
|
|
390
390
|
return ""
|
|
391
391
|
|
|
392
|
-
def get_algorithm_citations(self) ->
|
|
392
|
+
def get_algorithm_citations(self) -> list[CiteEntry]:
|
|
393
393
|
"""
|
|
394
394
|
Return a list of citation entries of the current algorithm.
|
|
395
395
|
|
|
@@ -455,13 +455,13 @@ class Configuration(BaseModel):
|
|
|
455
455
|
|
|
456
456
|
return ""
|
|
457
457
|
|
|
458
|
-
def get_algorithm_keywords(self) ->
|
|
458
|
+
def get_algorithm_keywords(self) -> list[str]:
|
|
459
459
|
"""
|
|
460
460
|
Get algorithm keywords.
|
|
461
461
|
|
|
462
462
|
Returns
|
|
463
463
|
-------
|
|
464
|
-
|
|
464
|
+
list[str]
|
|
465
465
|
List of keywords.
|
|
466
466
|
"""
|
|
467
467
|
if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
|
|
@@ -491,8 +491,8 @@ class Configuration(BaseModel):
|
|
|
491
491
|
self,
|
|
492
492
|
exclude_defaults: bool = False,
|
|
493
493
|
exclude_none: bool = True,
|
|
494
|
-
**kwargs:
|
|
495
|
-
) ->
|
|
494
|
+
**kwargs: dict,
|
|
495
|
+
) -> dict:
|
|
496
496
|
"""
|
|
497
497
|
Override model_dump method in order to set default values.
|
|
498
498
|
|
|
@@ -503,7 +503,7 @@ class Configuration(BaseModel):
|
|
|
503
503
|
True.
|
|
504
504
|
exclude_none : bool, optional
|
|
505
505
|
Whether to exclude fields with None values or not, by default True.
|
|
506
|
-
**kwargs :
|
|
506
|
+
**kwargs : dict
|
|
507
507
|
Keyword arguments.
|
|
508
508
|
|
|
509
509
|
Returns
|
|
@@ -524,7 +524,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
|
|
|
524
524
|
|
|
525
525
|
Parameters
|
|
526
526
|
----------
|
|
527
|
-
path :
|
|
527
|
+
path : str or Path
|
|
528
528
|
Path to the configuration.
|
|
529
529
|
|
|
530
530
|
Returns
|
|
@@ -556,7 +556,7 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
|
|
|
556
556
|
----------
|
|
557
557
|
config : Configuration
|
|
558
558
|
Configuration to save.
|
|
559
|
-
path :
|
|
559
|
+
path : str or Path
|
|
560
560
|
Path to a existing folder in which to save the configuration or to an existing
|
|
561
561
|
configuration file.
|
|
562
562
|
|