careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/prediction_utils/create_pred_datamodule.py
ADDED
@@ -0,0 +1,185 @@
+"""Module containing functions to create `CAREamicsPredictData`."""
+
+from pathlib import Path
+from typing import Callable, Dict, Literal, Optional, Tuple, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from careamics.config import Configuration, create_inference_configuration
+from careamics.utils import check_path_exists
+
+from ..lightning_prediction_datamodule import CAREamicsPredictData
+
+
+def create_pred_datamodule(
+    source: Union[CAREamicsPredictData, Path, str, NDArray],
+    config: Configuration,
+    batch_size: Optional[int] = None,
+    tile_size: Optional[Tuple[int, ...]] = None,
+    tile_overlap: Tuple[int, ...] = (48, 48),
+    axes: Optional[str] = None,
+    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+    tta_transforms: bool = True,
+    dataloader_params: Optional[Dict] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+) -> CAREamicsPredictData:
+    """
+    Create a `CAREamicsPredictData` module.
+
+    Parameters
+    ----------
+    source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
+        Data to predict on.
+    config : Configuration
+        Global configuration.
+    batch_size : int, default=1
+        Batch size for prediction.
+    tile_size : tuple of int, optional
+        Size of the tiles to use for prediction.
+    tile_overlap : tuple of int, default=(48, 48)
+        Overlap between tiles.
+    axes : str, optional
+        Axes of the input data, by default None.
+    data_type : {"array", "tiff", "custom"}, optional
+        Type of the input data.
+    tta_transforms : bool, default=True
+        Whether to apply test-time augmentation.
+    dataloader_params : dict, optional
+        Parameters to pass to the dataloader.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, default=""
+        Filter for the file extension.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+
+    Raises
+    ------
+    ValueError
+        If the input is not a CAREamicsPredData instance, a path or a numpy array.
+    """
+    # Reuse batch size if not provided explicitly
+    if batch_size is None:
+        batch_size = config.data_config.batch_size
+
+    # create predict config, reuse training config if parameters missing
+    prediction_config = create_inference_configuration(
+        configuration=config,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        data_type=data_type,
+        axes=axes,
+        tta_transforms=tta_transforms,
+        batch_size=batch_size,
+    )
+
+    # remove batch from dataloader parameters (priority given to config)
+    if dataloader_params is None:
+        dataloader_params = {}
+    if "batch_size" in dataloader_params:
+        del dataloader_params["batch_size"]
+
+    if isinstance(source, CAREamicsPredictData):
+        pred_datamodule = source
+    elif isinstance(source, Path) or isinstance(source, str):
+        pred_datamodule = _create_from_path(
+            source=source,
+            pred_config=prediction_config,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
+        )
+    elif isinstance(source, np.ndarray):
+        pred_datamodule = _create_from_array(
+            source=source,
+            pred_config=prediction_config,
+            dataloader_params=dataloader_params,
+        )
+    else:
+        raise ValueError(
+            f"Invalid input. Expected a CAREamicsPredData instance, paths or "
+            f"NDArray (got {type(source)})."
+        )
+
+    return pred_datamodule
+
+
+def _create_from_path(
+    source: Union[Path, str],
+    pred_config: Configuration,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    dataloader_params: Optional[Dict] = None,
+    **kwargs,
+) -> CAREamicsPredictData:
+    """
+    Create `CAREamicsPredictData` from path.
+
+    Parameters
+    ----------
+    source : Path or str
+        Data to predict on.
+    pred_config : Configuration
+        Prediction configuration.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, default=""
+        Filter for the file extension.
+    dataloader_params : Optional[Dict], optional
+        Parameters to pass to the dataloader.
+    **kwargs
+        Unused.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+    """
+    source_path = check_path_exists(source)
+
+    datamodule = CAREamicsPredictData(
+        pred_config=pred_config,
+        pred_data=source_path,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
+    return datamodule
+
+
+def _create_from_array(
+    source: NDArray,
+    pred_config: Configuration,
+    dataloader_params: Optional[Dict] = None,
+    **kwargs,
+) -> CAREamicsPredictData:
+    """
+    Create `CAREamicsPredictData` from array.
+
+    Parameters
+    ----------
+    source : NDArray
+        Data to predict on.
+    pred_config : Configuration
+        Prediction configuration.
+    dataloader_params : Optional[Dict], optional
+        Parameters to pass to the dataloader.
+    **kwargs
+        Unused. Added for compatible function signature with `_create_from_path`.
+
+    Returns
+    -------
+    prediction datamodule: CAREamicsPredictData
+        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
+    """
+    datamodule = CAREamicsPredictData(
+        pred_config=pred_config,
+        pred_data=source,
+        dataloader_params=dataloader_params,
+    )
+    return datamodule
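For illustration (not taken from the package): a minimal sketch of calling the new helper, assuming a `Configuration` from training is available. The "preds/" directory and the tile parameters are placeholders, and the direct module import is used because the contents of the new `careamics/prediction_utils/__init__.py` are not shown here.

from careamics.config import Configuration
from careamics.prediction_utils.create_pred_datamodule import create_pred_datamodule


def build_tiled_pred_datamodule(config: Configuration):
    """Build a prediction datamodule for a hypothetical folder of TIFF files."""
    return create_pred_datamodule(
        source="preds/",          # placeholder input directory
        config=config,            # the Configuration used for training
        tile_size=(256, 256),     # enables tiled prediction
        tile_overlap=(48, 48),
        data_type="tiff",
        tta_transforms=True,
    )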
careamics/prediction_utils/prediction_outputs.py
ADDED
@@ -0,0 +1,165 @@
+"""Module containing functions to convert prediction outputs to desired form."""
+
+from typing import Any, List, Literal, Tuple, Union, overload
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ..config.tile_information import TileInformation
+from .stitch_prediction import stitch_prediction
+
+
+def convert_outputs(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], NDArray]:
+    """
+    Convert the outputs to the desired form.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    list of numpy.ndarray or numpy.ndarray
+        List of arrays with the axes SC(Z)YX. If there is only 1 output it will not
+        be in a list.
+    """
+    if len(predictions) == 0:
+        return predictions
+
+    # this layout is to stop mypy complaining
+    if tiled:
+        predictions_comb = combine_batches(predictions, tiled)
+        # remove sample dimension (always 1) `stitch_predict` func expects no S dim
+        tiles = [pred[0] for pred in predictions_comb[0]]
+        tile_infos = predictions_comb[1]
+        predictions_output = stitch_prediction(tiles, tile_infos)
+    else:
+        predictions_output = combine_batches(predictions, tiled)
+
+    # TODO: add this in? Returns output with same axes as input
+    # Won't work with tiling rn because stitch_prediction func removes S axis
+    # predictions = reshape(predictions, axes)
+    # At least make sure stitched prediction and non-tiled prediction have matching axes
+
+    # TODO: might want to remove this
+    if len(predictions_output) == 1:
+        return predictions_output[0]
+    return predictions_output
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[True]
+) -> Tuple[List[NDArray], List[TileInformation]]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Literal[False]
+) -> List[NDArray]: ...
+
+
+# for mypy
+@overload
+def combine_batches(  # numpydoc ignore=GL08
+    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
+
+
+def combine_batches(
+    predictions: List[Any], tiled: bool
+) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
+    """
+    If predictions are in batches, they will be combined.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    tiled : bool
+        Whether the predictions are tiled.
+
+    Returns
+    -------
+    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    if tiled:
+        return _combine_tiled_batches(predictions)
+    else:
+        return _combine_untiled_batches(predictions)
+
+
+def _combine_tiled_batches(
+    predictions: List[Tuple[NDArray, List[TileInformation]]]
+) -> Tuple[List[NDArray], List[TileInformation]]:
+    """
+    Combine batches from tiled output.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+
+    Returns
+    -------
+    tuple of (list of numpy.ndarray, list of TileInformation)
+        Combined batches.
+    """
+    # turn list of lists into single list
+    tile_infos = [
+        tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+    ]
+    prediction_tiles: List[NDArray] = _combine_untiled_batches(
+        [preds for preds, _ in predictions]
+    )
+    return prediction_tiles, tile_infos
+
+
+def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
+    """
+    Combine batches from un-tiled output.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+
+    Returns
+    -------
+    list of numpy.ndarray
+        Combined batches.
+    """
+    prediction_concat: NDArray = np.concatenate(predictions, axis=0)
+    prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
+    return prediction_split
+
+
+def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
+    """
+    Reshape predictions to have dimensions of input.
+
+    Parameters
+    ----------
+    predictions : list
+        Predictions that are output from `Trainer.predict`.
+    axes : str
+        Axes SC(Z)YX.
+
+    Returns
+    -------
+    List[NDArray]
+        Reshaped predictions.
+    """
+    if "C" not in axes:
+        predictions = [pred[:, 0] for pred in predictions]
+    if "S" not in axes:
+        predictions = [pred[0] for pred in predictions]
+    return predictions
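For illustration (not taken from the package): a sketch of the un-tiled path through `convert_outputs`, with made-up batch shapes. Batches from `Trainer.predict` are concatenated along the sample axis and split back into single-sample SC(Z)YX arrays.

import numpy as np

from careamics.prediction_utils.prediction_outputs import convert_outputs

# two fake "batches" with axes SCYX, as `Trainer.predict` might return them
batches = [np.zeros((2, 1, 64, 64)), np.zeros((1, 1, 64, 64))]

# un-tiled path: concatenate along S, then split into single-sample arrays
outputs = convert_outputs(batches, tiled=False)
print(len(outputs), outputs[0].shape)  # 3 (1, 1, 64, 64)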
careamics/prediction_utils/stitch_prediction.py
ADDED
@@ -0,0 +1,100 @@
+"""Prediction utility functions."""
+
+from typing import List
+
+import numpy as np
+
+from careamics.config.tile_information import TileInformation
+
+
+# TODO: why not allow input and output of torch.tensor ?
+def stitch_prediction(
+    tiles: List[np.ndarray],
+    tile_infos: List[TileInformation],
+) -> List[np.ndarray]:
+    """
+    Stitch tiles back together to form a full image(s).
+
+    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+    singleton dimension.
+
+    Parameters
+    ----------
+    tiles : list of numpy.ndarray
+        Cropped tiles and their respective stitching coordinates. Can contain tiles
+        from multiple images.
+    tile_infos : list of TileInformation
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    list of numpy.ndarray
+        Full image(s).
+    """
+    # Find where to split the lists so that only info from one image is contained.
+    # Do this by locating the last tiles of each image.
+    last_tiles = [tile_info.last_tile for tile_info in tile_infos]
+    last_tile_position = np.where(last_tiles)[0]
+    image_slices = [
+        slice(
+            None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1
+        )
+        for i in range(len(last_tile_position))
+    ]
+    image_predictions = []
+    # slice the lists and apply stitch_prediction_single to each in turn.
+    for image_slice in image_slices:
+        image_predictions.append(
+            stitch_prediction_single(tiles[image_slice], tile_infos[image_slice])
+        )
+    return image_predictions
+
+
+def stitch_prediction_single(
+    tiles: List[np.ndarray],
+    tile_infos: List[TileInformation],
+) -> np.ndarray:
+    """
+    Stitch tiles back together to form a full image.
+
+    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+    singleton dimension.
+
+    Parameters
+    ----------
+    tiles : list of numpy.ndarray
+        Cropped tiles and their respective stitching coordinates.
+    tile_infos : list of TileInformation
+        List of information and coordinates obtained from
+        `dataset.tiled_patching.extract_tiles`.
+
+    Returns
+    -------
+    numpy.ndarray
+        Full image.
+    """
+    # retrieve whole array size
+    input_shape = tile_infos[0].array_shape
+    predicted_image = np.zeros(input_shape, dtype=np.float32)
+
+    for tile, tile_info in zip(tiles, tile_infos):
+        n_channels = tile.shape[0]
+
+        # Compute coordinates for cropping predicted tile
+        slices = (slice(0, n_channels),) + tuple(
+            [slice(c[0], c[1]) for c in tile_info.overlap_crop_coords]
+        )
+
+        # Crop predicted tile according to overlap coordinates
+        cropped_tile = tile[slices]
+
+        # Insert cropped tile into predicted image using stitch coordinates
+        predicted_image[
+            (
+                ...,
+                *[slice(c[0], c[1]) for c in tile_info.stitch_coords],
+            )
+        ] = cropped_tile.astype(np.float32)
+
+    return predicted_image
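For illustration (not taken from the package): a numpy-only sketch of the stitching idea above, with invented coordinates. Each tile is cropped to its overlap-crop region and written into the output array at its stitch coordinates.

import numpy as np

image_shape = (1, 4, 8)  # C, Y, X
tiles = [np.ones((1, 4, 6)), 2 * np.ones((1, 4, 6))]  # two overlapping tiles
overlap_crop_coords = [((0, 4), (0, 4)), ((0, 4), (2, 6))]  # (Y, X) crop per tile
stitch_coords = [((0, 4), (0, 4)), ((0, 4), (4, 8))]  # (Y, X) placement per tile

stitched = np.zeros(image_shape, dtype=np.float32)
for tile, crop, place in zip(tiles, overlap_crop_coords, stitch_coords):
    # crop the tile to its non-overlapping core, then place it at its coordinates
    cropped = tile[(slice(None), *[slice(c0, c1) for c0, c1 in crop])]
    stitched[(..., *[slice(c0, c1) for c0, c1 in place])] = cropped

print(stitched[0])  # left half filled from tile 1, right half from tile 2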
careamics/transforms/__init__.py
CHANGED
@@ -3,7 +3,7 @@
 __all__ = [
     "get_all_transforms",
     "N2VManipulate",
-    "NDFlip",
+    "XYFlip",
     "XYRandomRotate90",
     "ImageRestorationTTA",
     "Denormalize",
@@ -14,7 +14,7 @@ __all__ = [

 from .compose import Compose, get_all_transforms
 from .n2v_manipulate import N2VManipulate
-from .nd_flip import NDFlip
 from .normalize import Denormalize, Normalize
 from .tta import ImageRestorationTTA
+from .xy_flip import XYFlip
 from .xy_random_rotate90 import XYRandomRotate90
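For illustration (not taken from the package): the rename removes `NDFlip` from the public API, so downstream imports need updating, roughly as follows.

# 0.1.0rc5 exposed the flip transform as NDFlip:
# from careamics.transforms import NDFlip
# 0.1.0rc7 exposes XYFlip instead:
from careamics.transforms import XYFlip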
careamics/transforms/compose.py
CHANGED
@@ -1,26 +1,26 @@
 """A class chaining transforms together."""

-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np

 from careamics.config.data_model import TRANSFORMS_UNION

 from .n2v_manipulate import N2VManipulate
-from .nd_flip import NDFlip
 from .normalize import Normalize
 from .transform import Transform
+from .xy_flip import XYFlip
 from .xy_random_rotate90 import XYRandomRotate90

 ALL_TRANSFORMS = {
     "Normalize": Normalize,
     "N2VManipulate": N2VManipulate,
-    "NDFlip": NDFlip,
+    "XYFlip": XYFlip,
     "XYRandomRotate90": XYRandomRotate90,
 }


-def get_all_transforms() -> dict:
+def get_all_transforms() -> Dict[str, type]:
     """Return all the transforms accepted by CAREamics.

     Returns
@@ -33,7 +33,19 @@ def get_all_transforms() -> dict:


 class Compose:
-    """A class chaining transforms together."""
+    """A class chaining transforms together.
+
+    Parameters
+    ----------
+    transform_list : List[TRANSFORMS_UNION]
+        A list of dictionaries where each dictionary contains the name of a
+        transform and its parameters.
+
+    Attributes
+    ----------
+    _callable_transforms : Callable
+        A callable that applies the transforms to the input data.
+    """

     def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
         """Instantiate a Compose object.
@@ -68,7 +80,21 @@ class Compose:

         def _chain(
             patch: np.ndarray, target: Optional[np.ndarray]
-        ) -> Tuple[np.ndarray,
+        ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+            """Chain transforms on the input data.
+
+            Parameters
+            ----------
+            patch : np.ndarray
+                Input data.
+            target : Optional[np.ndarray]
+                Target data, by default None.
+
+            Returns
+            -------
+            Tuple[np.ndarray, Optional[np.ndarray]]
+                The output of the transformations.
+            """
             params = (patch, target)

             for t in transforms:
@@ -88,7 +114,7 @@ class Compose:
         patch : np.ndarray
             The input data.
         target : Optional[np.ndarray], optional
-            Target data, by default None
+            Target data, by default None.

         Returns
         -------
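For illustration (not taken from the package): a small sketch of the re-typed `get_all_transforms`, which now advertises a `Dict[str, type]` mapping transform names to classes, matching the `ALL_TRANSFORMS` dict above.

from careamics.transforms import get_all_transforms

# mapping of transform names to classes, per ALL_TRANSFORMS
transforms = get_all_transforms()
print(sorted(transforms))
# ['N2VManipulate', 'Normalize', 'XYFlip', 'XYRandomRotate90']
xy_flip_cls = transforms["XYFlip"]  # the XYFlip class itself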
careamics/transforms/n2v_manipulate.py
CHANGED
@@ -1,3 +1,5 @@
+"""N2V manipulation transform."""
+
 from typing import Any, Literal, Optional, Tuple

 import numpy as np
@@ -17,10 +19,35 @@ class N2VManipulate(Transform):

     Parameters
     ----------
-
-
+    roi_size : int, optional
+        Size of the replacement area, by default 11.
+    masked_pixel_percentage : float, optional
+        Percentage of pixels to mask, by default 0.2.
+    strategy : Literal[ "uniform", "median" ], optional
+        Replacement strategy, uniform or median, by default uniform.
+    remove_center : bool, optional
+        Whether to remove central pixel from patch, by default True.
+    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
+        StructN2V mask axis, by default "none".
+    struct_mask_span : int, optional
+        StructN2V mask span, by default 5.
+    seed : Optional[int], optional
+        Random seed, by default None.
+
+    Attributes
+    ----------
+    masked_pixel_percentage : float
+        Percentage of pixels to mask.
     roi_size : int
-        Size of the
+        Size of the replacement area.
+    strategy : Literal[ "uniform", "median" ]
+        Replacement strategy, uniform or median.
+    remove_center : bool
+        Whether to remove central pixel from patch.
+    struct_mask : Optional[StructMaskParameters]
+        StructN2V mask parameters.
+    rng : Generator
+        Random number generator.
     """

     def __init__(
@@ -33,31 +60,31 @@ class N2VManipulate(Transform):
         remove_center: bool = True,
         struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
         struct_mask_span: int = 5,
-        seed: Optional[int] = None,
+        seed: Optional[int] = None,
     ):
         """Constructor.

         Parameters
         ----------
         roi_size : int, optional
-            Size of the replacement area, by default 11
+            Size of the replacement area, by default 11.
         masked_pixel_percentage : float, optional
-            Percentage of pixels to mask, by default 0.2
+            Percentage of pixels to mask, by default 0.2.
         strategy : Literal[ "uniform", "median" ], optional
-            Replacement strategy, uniform or median, by default uniform
+            Replacement strategy, uniform or median, by default uniform.
         remove_center : bool, optional
-            Whether to remove central pixel from patch, by default True
+            Whether to remove central pixel from patch, by default True.
         struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
-            StructN2V mask axis, by default "none"
+            StructN2V mask axis, by default "none".
         struct_mask_span : int, optional
-            StructN2V mask span, by default 5
+            StructN2V mask span, by default 5.
         seed : Optional[int], optional
-            Random seed, by default None
+            Random seed, by default None.
         """
         self.masked_pixel_percentage = masked_pixel_percentage
         self.roi_size = roi_size
         self.strategy = strategy
-        self.remove_center = remove_center
+        self.remove_center = remove_center  # TODO is this ever used?

         if struct_mask_axis == SupportedStructAxis.NONE:
             self.struct_mask: Optional[StructMaskParameters] = None
@@ -77,8 +104,17 @@ class N2VManipulate(Transform):

         Parameters
         ----------
-
-            Image
+        patch : np.ndarray
+            Image patch, 2D or 3D, shape C(Z)YX.
+        *args : Any
+            Additional arguments, unused.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray, np.ndarray]
+            Masked patch, original patch, and mask.
         """
         masked = np.zeros_like(patch)
         mask = np.zeros_like(patch)
@@ -91,6 +127,7 @@ class N2VManipulate(Transform):
                     subpatch_size=self.roi_size,
                     remove_center=self.remove_center,
                     struct_params=self.struct_mask,
+                    rng=self.rng,
                 )
         elif self.strategy == SupportedPixelManipulation.MEDIAN:
             # Iterate over the channels to apply manipulation separately
@@ -100,6 +137,7 @@ class N2VManipulate(Transform):
                     mask_pixel_percentage=self.masked_pixel_percentage,
                     subpatch_size=self.roi_size,
                     struct_params=self.struct_mask,
+                    rng=self.rng,
                 )
         else:
             raise ValueError(f"Unknown masking strategy ({self.strategy}).")