careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Module containing convenience function to create `WriteStrategy`."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from careamics.config.support import SupportedData
|
|
6
|
+
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
7
|
+
|
|
8
|
+
from .write_strategy import CacheTiles, WriteImage, WriteStrategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: Optional[WriteFunc] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> WriteStrategy:
    """
    Create a write strategy from convenient parameters.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing predictions: `WriteImage` when `tiled` is False,
        otherwise the tiled strategy chosen by `_create_tiled_write_strategy`.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` or `write_extension` has not
        been given.
    NotImplementedError
        If `tiled=True` and `write_type="zarr"`.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    # Avoid a shared mutable default: each strategy gets its own dict.
    if write_func_kwargs is None:
        write_func_kwargs = {}

    write_strategy: WriteStrategy
    if not tiled:
        # Untiled predictions are whole images and can be written directly.
        write_func = select_write_func(write_type=write_type, write_func=write_func)
        write_extension = select_write_extension(
            write_type=write_type, write_extension=write_extension
        )
        write_strategy = WriteImage(
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )
    else:
        # select CacheTiles or WriteTilesZarr (when implemented)
        write_strategy = _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )

    return write_strategy
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: Optional[WriteFunc],
    write_extension: Optional[str],
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Create a tiled write strategy.

    Currently this is always `CacheTiles`, which caches tiles until a whole image
    has been predicted. A `WriteTilesZarr` strategy, for writing tiles directly to
    disk, is planned but not implemented yet.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    ValueError
        If `write_type="custom"` but `write_func` or `write_extension` has not
        been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    # TODO: once `WriteTilesZarr` exists, return it here for
    # `write_type == SupportedData.ZARR` instead of raising.
    if write_type == "zarr":
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    # Every other type caches tiles until the full image can be written.
    write_func = select_write_func(write_type=write_type, write_func=write_func)
    write_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CacheTiles(
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def select_write_func(
    write_type: SupportedWriteType, write_func: Optional[WriteFunc] = None
) -> WriteFunc:
    """
    Return a function to write images.

    If `write_type` is "custom" then `write_func`, otherwise the known write function
    is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type == SupportedData.CUSTOM:
        # A custom type has no registered writer, so the caller must supply one.
        if write_func is None:
            raise ValueError(
                "A save function must be provided for custom data types."
                # TODO: link to how save functions should be implemented
            )
        return write_func

    # Known types: look up the writer; any user-supplied `write_func` is ignored.
    return get_write_func(write_type)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def select_write_extension(
    write_type: SupportedWriteType, write_extension: Optional[str] = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension`, otherwise the known
    write extension is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.
    """
    write_type_: SupportedData = SupportedData(write_type)  # new variable for mypy
    if write_type_ == SupportedData.CUSTOM:
        # A custom type has no known extension, so the caller must supply one.
        if write_extension is None:
            raise ValueError(
                "A save extension must be provided for custom data types."
            )
        return write_extension

    # kind of a weird pattern -> reason to move get_extension from SupportedData
    return write_type_.get_extension(write_type_)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Progressbar callback."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Dict, Union
|
|
5
|
+
|
|
6
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
7
|
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training, validation and testing steps.

    Customises the tqdm bars created by `TQDMProgressBar` (all of them write to
    `sys.stdout`) and displays the trainer's progress-bar metrics as-is.
    """

    def init_train_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during training.

        Returns
        -------
        tqdm
            A tqdm bar that is kept on screen once training finishes.
        """
        bar = tqdm(
            desc="Training",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            # Use the overall average rate rather than a moving average.
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during validation.

        Returns
        -------
        tqdm
            A tqdm bar that is removed from the screen when validation ends.
        """
        # The main (training) progress bar doesn't exist in `trainer.validate()`.
        # Query the trainer state instead of `self.train_progress_bar`: in
        # Lightning 2.x that property raises `TypeError` when the bar is unset.
        has_main_bar = self.trainer.state.fn != "validate"
        bar = tqdm(
            desc="Validating",
            # Sit one line below the training bar when it is present.
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """Create the tqdm bar shown during testing.

        Returns
        -------
        tqdm
            A fixed-width (100 columns) tqdm bar kept on screen once testing ends.
        """
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )
        return bar

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
        """Return the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of `trainer.progress_bar_metrics`; nothing else is added.
        """
        pbar_metrics = trainer.progress_bar_metrics
        return {**pbar_metrics}
|