careamics: careamics-0.0.5-py3-none-any.whl → careamics-0.0.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic; see the package's registry page for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +26 -26
- careamics/losses/__init__.py +3 -3
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/lvae.py +0 -3
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -9,7 +9,7 @@ import pytorch_lightning as L
|
|
|
9
9
|
from numpy.typing import NDArray
|
|
10
10
|
from torch.utils.data import DataLoader, IterableDataset
|
|
11
11
|
|
|
12
|
-
from careamics.config import DataConfig
|
|
12
|
+
from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
|
|
13
13
|
from careamics.config.support import SupportedData
|
|
14
14
|
from careamics.config.transformations import TransformModel
|
|
15
15
|
from careamics.dataset.dataset_utils import (
|
|
@@ -119,7 +119,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
119
119
|
|
|
120
120
|
def __init__(
|
|
121
121
|
self,
|
|
122
|
-
data_config:
|
|
122
|
+
data_config: GeneralDataConfig,
|
|
123
123
|
train_data: Union[Path, str, NDArray],
|
|
124
124
|
val_data: Optional[Union[Path, str, NDArray]] = None,
|
|
125
125
|
train_data_target: Optional[Union[Path, str, NDArray]] = None,
|
|
@@ -219,7 +219,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
219
219
|
)
|
|
220
220
|
|
|
221
221
|
# configuration
|
|
222
|
-
self.data_config:
|
|
222
|
+
self.data_config: GeneralDataConfig = data_config
|
|
223
223
|
self.data_type: str = data_config.data_type
|
|
224
224
|
self.batch_size: int = data_config.batch_size
|
|
225
225
|
self.use_in_memory: bool = use_in_memory
|
|
@@ -502,12 +502,23 @@ def create_train_datamodule(
|
|
|
502
502
|
"""Create a TrainDataModule.
|
|
503
503
|
|
|
504
504
|
This function is used to explicitly pass the parameters usually contained in a
|
|
505
|
-
`
|
|
505
|
+
`GenericDataConfig` to a TrainDataModule.
|
|
506
506
|
|
|
507
507
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
508
508
|
parameters passed to the datamodule are consistent with the model's requirements and
|
|
509
509
|
are coherent.
|
|
510
510
|
|
|
511
|
+
By default, the train DataModule will be set for Noise2Void if no target data is
|
|
512
|
+
provided. That means that it will add a `N2VManipulateModel` transformation to the
|
|
513
|
+
list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
|
|
514
|
+
pixel manipulation. If you pass a training target data, the default behaviour is to
|
|
515
|
+
train a supervised model. It will use the default XY flip and rotation
|
|
516
|
+
augmentations.
|
|
517
|
+
|
|
518
|
+
To use a different set of transformations, you can pass a list of transforms to
|
|
519
|
+
`transforms`. Note that if you intend to use Noise2Void, you should add
|
|
520
|
+
`N2VManipulateModel` as the last transform in the list of transformations.
|
|
521
|
+
|
|
511
522
|
The data module can be used with Path, str or numpy arrays. In the case of
|
|
512
523
|
numpy arrays, it loads and computes all the patches in memory. For Path and str
|
|
513
524
|
inputs, it calculates the total file size and estimate whether it can fit in
|
|
@@ -518,11 +529,6 @@ def create_train_datamodule(
|
|
|
518
529
|
To use array data, set `data_type` to `array` and pass a numpy array to
|
|
519
530
|
`train_data`.
|
|
520
531
|
|
|
521
|
-
In particular, N2V requires a specific transformation (N2V manipulates), which is
|
|
522
|
-
not compatible with supervised training. The default transformations applied to the
|
|
523
|
-
training patches are defined in `careamics.config.data_model`. To use different
|
|
524
|
-
transformations, pass a list of transforms. See examples for more details.
|
|
525
|
-
|
|
526
532
|
By default, CAREamics only supports types defined in
|
|
527
533
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
528
534
|
`data_type` to `custom` and provide a function that returns a numpy array from a
|
|
@@ -627,12 +633,12 @@ def create_train_datamodule(
|
|
|
627
633
|
transforms:
|
|
628
634
|
>>> import numpy as np
|
|
629
635
|
>>> from careamics.lightning import create_train_datamodule
|
|
636
|
+
>>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
|
|
630
637
|
>>> from careamics.config.support import SupportedTransform
|
|
631
638
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
632
639
|
>>> my_transforms = [
|
|
633
|
-
...
|
|
634
|
-
...
|
|
635
|
-
... }
|
|
640
|
+
... XYFlipModel(flip_y=False),
|
|
641
|
+
... N2VManipulateModel()
|
|
636
642
|
... ]
|
|
637
643
|
>>> data_module = create_train_datamodule(
|
|
638
644
|
... train_data=my_array,
|
|
@@ -659,21 +665,15 @@ def create_train_datamodule(
|
|
|
659
665
|
if transforms is not None:
|
|
660
666
|
data_dict["transforms"] = transforms
|
|
661
667
|
|
|
662
|
-
#
|
|
663
|
-
|
|
668
|
+
# TODO not compatible with HDN, consider adding an argument for n2v/hdn
|
|
669
|
+
if train_target_data is None:
|
|
670
|
+
data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
|
|
671
|
+
assert isinstance(data_config, N2VDataConfig)
|
|
664
672
|
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
data_config.set_N2V2(use_n2v2)
|
|
670
|
-
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
671
|
-
else:
|
|
672
|
-
raise ValueError(
|
|
673
|
-
"Cannot have both supervised training (target data) and "
|
|
674
|
-
"N2V manipulation in the transforms. Pass a list of transforms "
|
|
675
|
-
"that is compatible with your supervised training."
|
|
676
|
-
)
|
|
673
|
+
data_config.set_n2v2(use_n2v2)
|
|
674
|
+
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
675
|
+
else:
|
|
676
|
+
data_config = DataConfig(**data_dict)
|
|
677
677
|
|
|
678
678
|
# sanity check on the dataloader parameters
|
|
679
679
|
if "batch_size" in dataloader_params:
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"denoisplit_loss",
|
|
5
|
+
"denoisplit_musplit_loss",
|
|
4
6
|
"loss_factory",
|
|
5
7
|
"mae_loss",
|
|
6
8
|
"mse_loss",
|
|
7
|
-
"n2v_loss",
|
|
8
|
-
"denoisplit_loss",
|
|
9
9
|
"musplit_loss",
|
|
10
|
-
"
|
|
10
|
+
"n2v_loss",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
careamics/model_io/__init__.py
CHANGED
|
@@ -55,7 +55,7 @@ def readme_factory(
|
|
|
55
55
|
readme.touch()
|
|
56
56
|
|
|
57
57
|
# algorithm pretty name
|
|
58
|
-
algorithm_flavour = config.
|
|
58
|
+
algorithm_flavour = config.get_algorithm_friendly_name()
|
|
59
59
|
algorithm_pretty_name = algorithm_flavour + " - CAREamics"
|
|
60
60
|
|
|
61
61
|
description = [f"# {algorithm_pretty_name}\n\n"]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Module use to build BMZ model description."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Optional, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from bioimageio.spec._internal.io import resolve_and_extract
|
|
@@ -28,17 +28,17 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
28
28
|
WeightsDescr,
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
-
from careamics.config import Configuration,
|
|
31
|
+
from careamics.config import Configuration, GeneralDataConfig
|
|
32
32
|
|
|
33
33
|
from ._readme_factory import readme_factory
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def _create_axes(
|
|
37
37
|
array: np.ndarray,
|
|
38
|
-
data_config:
|
|
39
|
-
channel_names: Optional[
|
|
38
|
+
data_config: GeneralDataConfig,
|
|
39
|
+
channel_names: Optional[list[str]] = None,
|
|
40
40
|
is_input: bool = True,
|
|
41
|
-
) ->
|
|
41
|
+
) -> list[AxisBase]:
|
|
42
42
|
"""Create axes description.
|
|
43
43
|
|
|
44
44
|
Array shape is expected to be SC(Z)YX.
|
|
@@ -49,15 +49,15 @@ def _create_axes(
|
|
|
49
49
|
Array.
|
|
50
50
|
data_config : DataModel
|
|
51
51
|
CAREamics data configuration.
|
|
52
|
-
channel_names : Optional[
|
|
52
|
+
channel_names : Optional[list[str]], optional
|
|
53
53
|
Channel names, by default None.
|
|
54
54
|
is_input : bool, optional
|
|
55
55
|
Whether the axes are input axes, by default True.
|
|
56
56
|
|
|
57
57
|
Returns
|
|
58
58
|
-------
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
list[AxisBase]
|
|
60
|
+
list of axes description.
|
|
61
61
|
|
|
62
62
|
Raises
|
|
63
63
|
------
|
|
@@ -102,11 +102,11 @@ def _create_axes(
|
|
|
102
102
|
def _create_inputs_ouputs(
|
|
103
103
|
input_array: np.ndarray,
|
|
104
104
|
output_array: np.ndarray,
|
|
105
|
-
data_config:
|
|
105
|
+
data_config: GeneralDataConfig,
|
|
106
106
|
input_path: Union[Path, str],
|
|
107
107
|
output_path: Union[Path, str],
|
|
108
|
-
channel_names: Optional[
|
|
109
|
-
) ->
|
|
108
|
+
channel_names: Optional[list[str]] = None,
|
|
109
|
+
) -> tuple[InputTensorDescr, OutputTensorDescr]:
|
|
110
110
|
"""Create input and output tensor description.
|
|
111
111
|
|
|
112
112
|
Input and output paths must point to a `.npy` file.
|
|
@@ -123,12 +123,12 @@ def _create_inputs_ouputs(
|
|
|
123
123
|
Path to input .npy file.
|
|
124
124
|
output_path : Union[Path, str]
|
|
125
125
|
Path to output .npy file.
|
|
126
|
-
channel_names : Optional[
|
|
126
|
+
channel_names : Optional[list[str]], optional
|
|
127
127
|
Channel names, by default None.
|
|
128
128
|
|
|
129
129
|
Returns
|
|
130
130
|
-------
|
|
131
|
-
|
|
131
|
+
tuple[InputTensorDescr, OutputTensorDescr]
|
|
132
132
|
Input and output tensor descriptions.
|
|
133
133
|
"""
|
|
134
134
|
input_axes = _create_axes(input_array, data_config, channel_names)
|
|
@@ -188,7 +188,7 @@ def create_model_description(
|
|
|
188
188
|
name: str,
|
|
189
189
|
general_description: str,
|
|
190
190
|
data_description: str,
|
|
191
|
-
authors:
|
|
191
|
+
authors: list[Author],
|
|
192
192
|
inputs: Union[Path, str],
|
|
193
193
|
outputs: Union[Path, str],
|
|
194
194
|
weights_path: Union[Path, str],
|
|
@@ -197,7 +197,7 @@ def create_model_description(
|
|
|
197
197
|
config_path: Union[Path, str],
|
|
198
198
|
env_path: Union[Path, str],
|
|
199
199
|
covers: list[Union[Path, str]],
|
|
200
|
-
channel_names: Optional[
|
|
200
|
+
channel_names: Optional[list[str]] = None,
|
|
201
201
|
model_version: str = "0.1.0",
|
|
202
202
|
) -> ModelDescr:
|
|
203
203
|
"""Create model description.
|
|
@@ -212,7 +212,7 @@ def create_model_description(
|
|
|
212
212
|
General description of the model.
|
|
213
213
|
data_description : str
|
|
214
214
|
Description of the data the model was trained on.
|
|
215
|
-
authors :
|
|
215
|
+
authors : list[Author]
|
|
216
216
|
Authors of the model.
|
|
217
217
|
inputs : Union[Path, str]
|
|
218
218
|
Path to input .npy file.
|
|
@@ -230,7 +230,7 @@ def create_model_description(
|
|
|
230
230
|
Path to environment file.
|
|
231
231
|
covers : list of pathlib.Path or str
|
|
232
232
|
Paths to cover images.
|
|
233
|
-
channel_names : Optional[
|
|
233
|
+
channel_names : Optional[list[str]], optional
|
|
234
234
|
Channel names, by default None.
|
|
235
235
|
model_version : str, default "0.1.0"
|
|
236
236
|
Model version.
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import tempfile
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Optional, Union
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pkg_resources
|
|
@@ -87,11 +87,11 @@ def export_to_bmz(
|
|
|
87
87
|
model_name: str,
|
|
88
88
|
general_description: str,
|
|
89
89
|
data_description: str,
|
|
90
|
-
authors:
|
|
90
|
+
authors: list[dict],
|
|
91
91
|
input_array: np.ndarray,
|
|
92
92
|
output_array: np.ndarray,
|
|
93
93
|
covers: Optional[list[Union[Path, str]]] = None,
|
|
94
|
-
channel_names: Optional[
|
|
94
|
+
channel_names: Optional[list[str]] = None,
|
|
95
95
|
model_version: str = "0.1.0",
|
|
96
96
|
) -> None:
|
|
97
97
|
"""Export the model to BioImage Model Zoo format.
|
|
@@ -115,7 +115,7 @@ def export_to_bmz(
|
|
|
115
115
|
General description of the model.
|
|
116
116
|
data_description : str
|
|
117
117
|
Description of the data the model was trained on.
|
|
118
|
-
authors :
|
|
118
|
+
authors : list[dict]
|
|
119
119
|
Authors of the model.
|
|
120
120
|
input_array : np.ndarray
|
|
121
121
|
Input array, should not have been normalized.
|
|
@@ -123,24 +123,13 @@ def export_to_bmz(
|
|
|
123
123
|
Output array, should have been denormalized.
|
|
124
124
|
covers : list of pathlib.Path or str, default=None
|
|
125
125
|
Paths to the cover images.
|
|
126
|
-
channel_names : Optional[
|
|
126
|
+
channel_names : Optional[list[str]], optional
|
|
127
127
|
Channel names, by default None.
|
|
128
128
|
model_version : str, default="0.1.0"
|
|
129
129
|
Model version.
|
|
130
|
-
|
|
131
|
-
Raises
|
|
132
|
-
------
|
|
133
|
-
ValueError
|
|
134
|
-
If the model is a Custom model.
|
|
135
130
|
"""
|
|
136
131
|
path_to_archive = Path(path_to_archive)
|
|
137
132
|
|
|
138
|
-
# method is not compatible with Custom models
|
|
139
|
-
if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
|
|
140
|
-
raise ValueError(
|
|
141
|
-
"Exporting Custom models to BioImage Model Zoo format is not supported."
|
|
142
|
-
)
|
|
143
|
-
|
|
144
133
|
if path_to_archive.suffix != ".zip":
|
|
145
134
|
raise ValueError(
|
|
146
135
|
f"Path to archive must point to a zip file, got {path_to_archive}."
|
|
@@ -212,7 +201,7 @@ def export_to_bmz(
|
|
|
212
201
|
|
|
213
202
|
def load_from_bmz(
|
|
214
203
|
path: Union[Path, str, HttpUrl]
|
|
215
|
-
) ->
|
|
204
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
216
205
|
"""Load a model from a BioImage Model Zoo archive.
|
|
217
206
|
|
|
218
207
|
Parameters
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Utility functions to load pretrained models."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from careamics.config import Configuration
|
|
8
|
+
from careamics.config import Configuration, configuration_factory
|
|
9
9
|
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
10
10
|
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
11
|
from careamics.utils import check_path_exists
|
|
@@ -13,7 +13,7 @@ from careamics.utils import check_path_exists
|
|
|
13
13
|
|
|
14
14
|
def load_pretrained(
|
|
15
15
|
path: Union[Path, str]
|
|
16
|
-
) ->
|
|
16
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
17
17
|
"""
|
|
18
18
|
Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
|
|
19
19
|
|
|
@@ -26,8 +26,8 @@ def load_pretrained(
|
|
|
26
26
|
|
|
27
27
|
Returns
|
|
28
28
|
-------
|
|
29
|
-
|
|
30
|
-
|
|
29
|
+
tuple[CAREamicsKiln, Configuration]
|
|
30
|
+
tuple of CAREamics model and its configuration.
|
|
31
31
|
|
|
32
32
|
Raises
|
|
33
33
|
------
|
|
@@ -48,7 +48,7 @@ def load_pretrained(
|
|
|
48
48
|
|
|
49
49
|
def _load_checkpoint(
|
|
50
50
|
path: Union[Path, str]
|
|
51
|
-
) ->
|
|
51
|
+
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
52
52
|
"""
|
|
53
53
|
Load a model from a checkpoint and return both model and configuration.
|
|
54
54
|
|
|
@@ -59,8 +59,8 @@ def _load_checkpoint(
|
|
|
59
59
|
|
|
60
60
|
Returns
|
|
61
61
|
-------
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
tuple[CAREamicsKiln, Configuration]
|
|
63
|
+
tuple of CAREamics model and its configuration.
|
|
64
64
|
|
|
65
65
|
Raises
|
|
66
66
|
------
|
|
@@ -92,4 +92,4 @@ def _load_checkpoint(
|
|
|
92
92
|
f"{cfg_dict['algorithm_config']['model']['architecture']}"
|
|
93
93
|
)
|
|
94
94
|
|
|
95
|
-
return model,
|
|
95
|
+
return model, configuration_factory(cfg_dict)
|
careamics/models/layers.py
CHANGED
|
@@ -4,7 +4,7 @@ Layer module.
|
|
|
4
4
|
This submodule contains layers used in the CAREamics models.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import Optional, Union
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
import torch.nn as nn
|
|
@@ -157,22 +157,22 @@ class Conv_Block(nn.Module):
|
|
|
157
157
|
|
|
158
158
|
|
|
159
159
|
def _unpack_kernel_size(
|
|
160
|
-
kernel_size: Union[
|
|
161
|
-
) ->
|
|
160
|
+
kernel_size: Union[tuple[int, ...], int], dim: int
|
|
161
|
+
) -> tuple[int, ...]:
|
|
162
162
|
"""Unpack kernel_size to a tuple of ints.
|
|
163
163
|
|
|
164
164
|
Inspired by Kornia implementation. TODO: link
|
|
165
165
|
|
|
166
166
|
Parameters
|
|
167
167
|
----------
|
|
168
|
-
kernel_size : Union[
|
|
168
|
+
kernel_size : Union[tuple[int, ...], int]
|
|
169
169
|
Kernel size.
|
|
170
170
|
dim : int
|
|
171
171
|
Number of dimensions.
|
|
172
172
|
|
|
173
173
|
Returns
|
|
174
174
|
-------
|
|
175
|
-
|
|
175
|
+
tuple[int, ...]
|
|
176
176
|
Kernel size tuple.
|
|
177
177
|
"""
|
|
178
178
|
if isinstance(kernel_size, int):
|
|
@@ -183,20 +183,20 @@ def _unpack_kernel_size(
|
|
|
183
183
|
|
|
184
184
|
|
|
185
185
|
def _compute_zero_padding(
|
|
186
|
-
kernel_size: Union[
|
|
187
|
-
) ->
|
|
186
|
+
kernel_size: Union[tuple[int, ...], int], dim: int
|
|
187
|
+
) -> tuple[int, ...]:
|
|
188
188
|
"""Utility function that computes zero padding tuple.
|
|
189
189
|
|
|
190
190
|
Parameters
|
|
191
191
|
----------
|
|
192
|
-
kernel_size : Union[
|
|
192
|
+
kernel_size : Union[tuple[int, ...], int]
|
|
193
193
|
Kernel size.
|
|
194
194
|
dim : int
|
|
195
195
|
Number of dimensions.
|
|
196
196
|
|
|
197
197
|
Returns
|
|
198
198
|
-------
|
|
199
|
-
|
|
199
|
+
tuple[int, ...]
|
|
200
200
|
Zero padding tuple.
|
|
201
201
|
"""
|
|
202
202
|
kernel_dims = _unpack_kernel_size(kernel_size, dim)
|
|
@@ -245,8 +245,8 @@ def get_pascal_kernel_1d(
|
|
|
245
245
|
>>> get_pascal_kernel_1d(6)
|
|
246
246
|
tensor([ 1., 5., 10., 10., 5., 1.])
|
|
247
247
|
"""
|
|
248
|
-
pre:
|
|
249
|
-
cur:
|
|
248
|
+
pre: list[float] = []
|
|
249
|
+
cur: list[float] = []
|
|
250
250
|
for i in range(kernel_size):
|
|
251
251
|
cur = [1.0] * (i + 1)
|
|
252
252
|
|
|
@@ -266,7 +266,7 @@ def get_pascal_kernel_1d(
|
|
|
266
266
|
|
|
267
267
|
|
|
268
268
|
def _get_pascal_kernel_nd(
|
|
269
|
-
kernel_size: Union[
|
|
269
|
+
kernel_size: Union[tuple[int, int], int],
|
|
270
270
|
norm: bool = True,
|
|
271
271
|
dim: int = 2,
|
|
272
272
|
*,
|
|
@@ -282,7 +282,7 @@ def _get_pascal_kernel_nd(
|
|
|
282
282
|
|
|
283
283
|
Parameters
|
|
284
284
|
----------
|
|
285
|
-
kernel_size : Union[
|
|
285
|
+
kernel_size : Union[tuple[int, int], int]
|
|
286
286
|
Kernel size for the pascal kernel.
|
|
287
287
|
norm : bool
|
|
288
288
|
Normalize the kernel, by default True.
|
|
@@ -420,7 +420,7 @@ class MaxBlurPool(nn.Module):
|
|
|
420
420
|
----------
|
|
421
421
|
dim : int
|
|
422
422
|
Toggles between 2D and 3D.
|
|
423
|
-
kernel_size : Union[
|
|
423
|
+
kernel_size : Union[tuple[int, int], int]
|
|
424
424
|
Kernel size for max pooling.
|
|
425
425
|
stride : int
|
|
426
426
|
Stride for pooling.
|
|
@@ -433,7 +433,7 @@ class MaxBlurPool(nn.Module):
|
|
|
433
433
|
def __init__(
|
|
434
434
|
self,
|
|
435
435
|
dim: int,
|
|
436
|
-
kernel_size: Union[
|
|
436
|
+
kernel_size: Union[tuple[int, int], int],
|
|
437
437
|
stride: int = 2,
|
|
438
438
|
max_pool_size: int = 2,
|
|
439
439
|
ceil_mode: bool = False,
|
|
@@ -444,7 +444,7 @@ class MaxBlurPool(nn.Module):
|
|
|
444
444
|
----------
|
|
445
445
|
dim : int
|
|
446
446
|
Dimension of the convolution.
|
|
447
|
-
kernel_size : Union[
|
|
447
|
+
kernel_size : Union[tuple[int, int], int]
|
|
448
448
|
Kernel size for max pooling.
|
|
449
449
|
stride : int, optional
|
|
450
450
|
Stride, by default 2.
|
careamics/models/lvae/lvae.py
CHANGED
|
@@ -12,8 +12,6 @@ import numpy as np
|
|
|
12
12
|
import torch
|
|
13
13
|
import torch.nn as nn
|
|
14
14
|
|
|
15
|
-
from careamics.config.architectures import register_model
|
|
16
|
-
|
|
17
15
|
from ..activation import get_activation
|
|
18
16
|
from .layers import (
|
|
19
17
|
BottomUpDeterministicResBlock,
|
|
@@ -25,7 +23,6 @@ from .layers import (
|
|
|
25
23
|
from .utils import Interpolate, ModelType, crop_img_tensor
|
|
26
24
|
|
|
27
25
|
|
|
28
|
-
@register_model("LVAE")
|
|
29
26
|
class LadderVAE(nn.Module):
|
|
30
27
|
"""
|
|
31
28
|
Constructor.
|
|
@@ -1,8 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Model factory.
|
|
3
|
-
|
|
4
|
-
Model creation factory functions.
|
|
5
|
-
"""
|
|
1
|
+
"""Model creation factory functions."""
|
|
6
2
|
|
|
7
3
|
from __future__ import annotations
|
|
8
4
|
|
|
@@ -10,10 +6,6 @@ from typing import TYPE_CHECKING, Union
|
|
|
10
6
|
|
|
11
7
|
import torch
|
|
12
8
|
|
|
13
|
-
from careamics.config.architectures import (
|
|
14
|
-
CustomModel,
|
|
15
|
-
get_custom_model,
|
|
16
|
-
)
|
|
17
9
|
from careamics.config.support import SupportedArchitecture
|
|
18
10
|
from careamics.models.lvae import LadderVAE as LVAE
|
|
19
11
|
from careamics.models.unet import UNet
|
|
@@ -21,7 +13,6 @@ from careamics.utils import get_logger
|
|
|
21
13
|
|
|
22
14
|
if TYPE_CHECKING:
|
|
23
15
|
from careamics.config.architectures import (
|
|
24
|
-
CustomModel,
|
|
25
16
|
LVAEModel,
|
|
26
17
|
UNetModel,
|
|
27
18
|
)
|
|
@@ -31,7 +22,7 @@ logger = get_logger(__name__)
|
|
|
31
22
|
|
|
32
23
|
|
|
33
24
|
def model_factory(
|
|
34
|
-
model_configuration: Union[UNetModel, LVAEModel
|
|
25
|
+
model_configuration: Union[UNetModel, LVAEModel],
|
|
35
26
|
) -> torch.nn.Module:
|
|
36
27
|
"""
|
|
37
28
|
Deep learning model factory.
|
|
@@ -57,10 +48,6 @@ def model_factory(
|
|
|
57
48
|
return UNet(**model_configuration.model_dump())
|
|
58
49
|
elif model_configuration.architecture == SupportedArchitecture.LVAE:
|
|
59
50
|
return LVAE(**model_configuration.model_dump())
|
|
60
|
-
elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
|
|
61
|
-
assert isinstance(model_configuration, CustomModel)
|
|
62
|
-
model = get_custom_model(model_configuration.name)
|
|
63
|
-
return model(**model_configuration.model_dump())
|
|
64
51
|
else:
|
|
65
52
|
raise NotImplementedError(
|
|
66
53
|
f"Model {model_configuration.architecture} is not implemented or unknown."
|
careamics/models/unet.py
CHANGED
|
@@ -4,7 +4,7 @@ UNet model.
|
|
|
4
4
|
A UNet encoder, decoder and complete model.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Union
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
import torch.nn as nn
|
|
@@ -104,7 +104,7 @@ class UnetEncoder(nn.Module):
|
|
|
104
104
|
encoder_blocks.append(self.pooling)
|
|
105
105
|
self.encoder_blocks = nn.ModuleList(encoder_blocks)
|
|
106
106
|
|
|
107
|
-
def forward(self, x: torch.Tensor) ->
|
|
107
|
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
|
108
108
|
"""
|
|
109
109
|
Forward pass.
|
|
110
110
|
|
|
@@ -115,7 +115,7 @@ class UnetEncoder(nn.Module):
|
|
|
115
115
|
|
|
116
116
|
Returns
|
|
117
117
|
-------
|
|
118
|
-
|
|
118
|
+
list[torch.Tensor]
|
|
119
119
|
Output of each encoder block (skip connections) and final output of the
|
|
120
120
|
encoder.
|
|
121
121
|
"""
|
|
@@ -202,7 +202,7 @@ class UnetDecoder(nn.Module):
|
|
|
202
202
|
groups=self.groups,
|
|
203
203
|
)
|
|
204
204
|
|
|
205
|
-
decoder_blocks:
|
|
205
|
+
decoder_blocks: list[nn.Module] = []
|
|
206
206
|
for n in range(depth):
|
|
207
207
|
decoder_blocks.append(upsampling)
|
|
208
208
|
in_channels = (num_channels_init * 2 ** (depth - n)) * groups
|
|
@@ -230,7 +230,7 @@ class UnetDecoder(nn.Module):
|
|
|
230
230
|
|
|
231
231
|
Parameters
|
|
232
232
|
----------
|
|
233
|
-
*features :
|
|
233
|
+
*features : list[torch.Tensor]
|
|
234
234
|
List containing the output of each encoder block(skip connections) and final
|
|
235
235
|
output of the encoder.
|
|
236
236
|
|
|
@@ -240,7 +240,7 @@ class UnetDecoder(nn.Module):
|
|
|
240
240
|
Output of the decoder.
|
|
241
241
|
"""
|
|
242
242
|
x: torch.Tensor = features[0]
|
|
243
|
-
skip_connections:
|
|
243
|
+
skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
244
244
|
|
|
245
245
|
x = self.bottleneck(x)
|
|
246
246
|
|
|
@@ -289,10 +289,10 @@ class UnetDecoder(nn.Module):
|
|
|
289
289
|
m = A.shape[1] // groups
|
|
290
290
|
n = B.shape[1] // groups
|
|
291
291
|
|
|
292
|
-
A_groups:
|
|
292
|
+
A_groups: list[torch.Tensor] = [
|
|
293
293
|
A[:, i * m : (i + 1) * m] for i in range(groups)
|
|
294
294
|
]
|
|
295
|
-
B_groups:
|
|
295
|
+
B_groups: list[torch.Tensor] = [
|
|
296
296
|
B[:, i * n : (i + 1) * n] for i in range(groups)
|
|
297
297
|
]
|
|
298
298
|
|