careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,233 @@
1
+ """Function to export to the BioImage Model Zoo format."""
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import pkg_resources
9
+ from bioimageio.core import load_description, test_model
10
+ from bioimageio.spec import ValidationSummary, save_bioimageio_package
11
+ from torch import __version__, load, save
12
+
13
+ from careamics.config import Configuration, load_configuration, save_configuration
14
+ from careamics.config.support import SupportedArchitecture
15
+ from careamics.lightning.lightning_module import CAREamicsModule
16
+
17
+ from .bioimage import (
18
+ create_env_text,
19
+ create_model_description,
20
+ extract_model_path,
21
+ get_unzip_path,
22
+ )
23
+
24
+
25
+ def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
26
+ """
27
+ Export the model state dictionary to a file.
28
+
29
+ Parameters
30
+ ----------
31
+ model : CAREamicsKiln
32
+ CAREamics model to export.
33
+ path : Union[Path, str]
34
+ Path to the file where to save the model state dictionary.
35
+
36
+ Returns
37
+ -------
38
+ Path
39
+ Path to the saved model state dictionary.
40
+ """
41
+ path = Path(path)
42
+
43
+ # make sure it has the correct suffix
44
+ if path.suffix not in ".pth":
45
+ path = path.with_suffix(".pth")
46
+
47
+ # save model state dictionary
48
+ # we save through the torch model itself to avoid the initial "model." in the
49
+ # layers naming, which is incompatible with the way the BMZ load torch state dicts
50
+ save(model.model.state_dict(), path)
51
+
52
+ return path
53
+
54
+
55
+ def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
56
+ """
57
+ Load a model from a state dictionary.
58
+
59
+ Parameters
60
+ ----------
61
+ model : CAREamicsKiln
62
+ CAREamics model to be updated with the weights.
63
+ path : Union[Path, str]
64
+ Path to the model state dictionary.
65
+ """
66
+ path = Path(path)
67
+
68
+ # load model state dictionary
69
+ # same as in _export_state_dict, we load through the torch model to be compatible
70
+ # witht bioimageio.core expectations for a torch state dict
71
+ state_dict = load(path)
72
+ model.model.load_state_dict(state_dict)
73
+
74
+
75
# TODO break down in subfunctions
def export_to_bmz(
    model: CAREamicsModule,
    config: Configuration,
    path_to_archive: Union[Path, str],
    model_name: str,
    general_description: str,
    authors: List[dict],
    input_array: np.ndarray,
    output_array: np.ndarray,
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> None:
    """Export the model to BioImage Model Zoo format.

    Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C.

    `model_name` should consist of letters, numbers, dashes, underscores and parentheses
    only.

    The environment file, the input/output arrays, the configuration and the
    weights are first written to a temporary directory. The resulting model
    description is tested with `test_model`, and the archive is only written to
    `path_to_archive` if that test does not fail. Missing parent directories of
    `path_to_archive` are created.

    Parameters
    ----------
    model : CAREamicsModule
        CAREamics model to export.
    config : Configuration
        Model configuration.
    path_to_archive : Union[Path, str]
        Path to the output file, must have a `.zip` suffix.
    model_name : str
        Model name.
    general_description : str
        General description of the model.
    authors : List[dict]
        Authors of the model.
    input_array : np.ndarray
        Input array, should not have been normalized.
    output_array : np.ndarray
        Output array, should have been denormalized.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Raises
    ------
    ValueError
        If the model is a Custom model.
    ValueError
        If `path_to_archive` does not point to a `.zip` file.
    ValueError
        If the generated model description fails the model test.
    """
    path_to_archive = Path(path_to_archive)

    # method is not compatible with Custom models
    if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
        raise ValueError(
            "Exporting Custom models to BioImage Model Zoo format is not supported."
        )

    if path_to_archive.suffix != ".zip":
        raise ValueError(
            f"Path to archive must point to a zip file, got {path_to_archive}."
        )

    # `exist_ok=True` already makes this a no-op when the folder exists
    path_to_archive.parent.mkdir(parents=True, exist_ok=True)

    # versions
    pytorch_version = __version__
    careamics_version = pkg_resources.get_distribution("careamics").version

    # save files in temporary folder
    with tempfile.TemporaryDirectory() as tmpdirname:
        temp_path = Path(tmpdirname)

        # create environment file
        # TODO move in bioimage module
        env_path = temp_path / "environment.yml"
        env_path.write_text(create_env_text(pytorch_version))

        # export inputs and outputs
        inputs = temp_path / "inputs.npy"
        np.save(inputs, input_array)
        outputs = temp_path / "outputs.npy"
        np.save(outputs, output_array)

        # export configuration
        config_path = save_configuration(config, temp_path)

        # export model state dictionary
        weight_path = _export_state_dict(model, temp_path / "weights.pth")

        # create model description
        model_description = create_model_description(
            config=config,
            name=model_name,
            general_description=general_description,
            authors=authors,
            inputs=inputs,
            outputs=outputs,
            weights_path=weight_path,
            torch_version=pytorch_version,
            careamics_version=careamics_version,
            config_path=config_path,
            env_path=env_path,
            channel_names=channel_names,
            data_description=data_description,
        )

        # test model description before writing anything to `path_to_archive`
        summary: ValidationSummary = test_model(model_description, decimal=1)
        if summary.status == "failed":
            raise ValueError(f"Model description test failed: {summary}")

        # save bmz model (inside the `with` block, while the temporary files
        # referenced by the description still exist)
        save_bioimageio_package(model_description, output_path=path_to_archive)
188
+
189
+
190
def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
    """Load a model from a BioImage Model Zoo archive.

    The weights and configuration are read from the unzipped folder returned by
    `get_unzip_path(path)`. That folder is created next to the archive when the
    description is loaded.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the BioImage Model Zoo archive.

    Returns
    -------
    Tuple[CAREamicsModule, Configuration]
        CAREamics model and configuration.

    Raises
    ------
    ValueError
        If the path is not a zip file.
    """
    path = Path(path)

    if path.suffix != ".zip":
        raise ValueError(f"Path must be a bioimage.io zip file, got {path}.")

    # load description, this creates an unzipped folder next to the archive
    model_desc = load_description(path)

    # extract relative paths
    weights_path, config_path = extract_model_path(model_desc)

    # create folder path and absolute paths
    unzip_path = get_unzip_path(path)
    weights_path = unzip_path / weights_path
    config_path = unzip_path / config_path

    # load configuration
    config = load_configuration(config_path)

    # create careamics lightning module
    model = CAREamicsModule(algorithm_config=config.algorithm_config)

    # load model state dictionary
    _load_state_dict(model, weights_path)

    return model, config
@@ -0,0 +1,83 @@
1
+ """Utility functions to load pretrained models."""
2
+
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ import torch
7
+
8
+ from careamics.config import Configuration
9
+ from careamics.lightning.lightning_module import CAREamicsModule
10
+ from careamics.model_io.bmz_io import load_from_bmz
11
+ from careamics.utils import check_path_exists
12
+
13
+
14
def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
    """
    Load a pretrained model from a checkpoint or a BioImage Model Zoo model.

    Expected formats are .ckpt (checkpoint) or .zip (BioImage Model Zoo archive)
    files. The loader is chosen from the file suffix.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the pretrained model.

    Returns
    -------
    Tuple[CAREamicsModule, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the model format is not supported.
    """
    path = check_path_exists(path)

    if path.suffix == ".ckpt":
        return _load_checkpoint(path)
    elif path.suffix == ".zip":
        return load_from_bmz(path)
    else:
        raise ValueError(
            f"Invalid model format. Expected .ckpt or .zip, got {path.suffix}."
        )
45
+
46
+
47
+ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
48
+ """
49
+ Load a model from a checkpoint and return both model and configuration.
50
+
51
+ Parameters
52
+ ----------
53
+ path : Union[Path, str]
54
+ Path to the checkpoint.
55
+
56
+ Returns
57
+ -------
58
+ Tuple[CAREamicsKiln, Configuration]
59
+ Tuple of CAREamics model and its configuration.
60
+
61
+ Raises
62
+ ------
63
+ ValueError
64
+ If the checkpoint file does not contain hyper parameters (configuration).
65
+ """
66
+ # load checkpoint
67
+ # here we might run into issues between devices
68
+ # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
69
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
70
+ checkpoint: dict = torch.load(path, map_location=device)
71
+
72
+ # attempt to load configuration
73
+ try:
74
+ cfg_dict = checkpoint["hyper_parameters"]
75
+ except KeyError as e:
76
+ raise ValueError(
77
+ f"Invalid checkpoint file. No `hyper_parameters` found in the "
78
+ f"checkpoint: {checkpoint.keys()}"
79
+ ) from e
80
+
81
+ model = CAREamicsModule.load_from_checkpoint(path)
82
+
83
+ return model, Configuration(**cfg_dict)
@@ -0,0 +1,7 @@
1
+ """Models package."""
2
+
3
+ __all__ = ["model_factory", "UNet"]
4
+
5
+
6
+ from .model_factory import model_factory
7
+ from .unet import UNet as UNet
@@ -0,0 +1,37 @@
1
+ """Activations for CAREamics models."""
2
+
3
+ from typing import Callable, Union
4
+
5
+ import torch.nn as nn
6
+
7
+ from ..config.support import SupportedActivation
8
+
9
+
10
def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
    """
    Instantiate the torch activation layer matching `activation`.

    Parameters
    ----------
    activation : Union[SupportedActivation, str]
        Requested activation. It is compared with `==` against the members of
        `SupportedActivation`.

    Returns
    -------
    Callable
        A newly created `torch.nn` activation module. `SupportedActivation.NONE`
        yields `nn.Identity`, and softmax is applied over dimension 1.

    Raises
    ------
    ValueError
        If `activation` matches none of the supported activations.
    """
    # checked in order with `==` (rather than a dict lookup) so that the
    # comparison semantics are exactly those of the enum's equality
    candidates = (
        (SupportedActivation.RELU, nn.ReLU),
        (SupportedActivation.LEAKYRELU, nn.LeakyReLU),
        (SupportedActivation.TANH, nn.Tanh),
        (SupportedActivation.SIGMOID, nn.Sigmoid),
        (SupportedActivation.SOFTMAX, lambda: nn.Softmax(dim=1)),
        (SupportedActivation.NONE, nn.Identity),
    )
    for member, make_layer in candidates:
        if activation == member:
            return make_layer()

    raise ValueError(f"Activation {activation} not supported.")