careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +80 -44
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -2
- careamics/config/configuration_factory.py +4 -16
- careamics/config/data_model.py +10 -14
- careamics/config/inference_model.py +0 -65
- careamics/config/optimizer_models.py +4 -4
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/conftest.py +12 -0
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +71 -32
- careamics/dataset/iterable_dataset.py +155 -68
- careamics/dataset/patching/patching.py +56 -15
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/tiled_patching.py +3 -1
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +45 -19
- careamics/lightning_module.py +8 -2
- careamics/lightning_prediction_datamodule.py +3 -13
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/bmz_io.py +3 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction/stitch_prediction.py +2 -6
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +49 -13
- careamics/transforms/normalize.py +55 -3
- careamics/transforms/pixel_manipulation.py +5 -5
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -67
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0

careamics/dataset/patching/sequential_patching.py
CHANGED

@@ -1,3 +1,5 @@
+"""Sequential patching functions."""
+
 from typing import List, Optional, Tuple, Union
 
 import numpy as np

@@ -14,14 +16,14 @@ def _compute_number_of_patches(
 
     Parameters
     ----------
-
+    arr_shape : Tuple[int, ...]
         Shape of the input array.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]
         Shape of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Number of patches in each dimension.
     """
     if len(arr_shape) != len(patch_sizes):

@@ -55,14 +57,14 @@ def _compute_overlap(
 
     Parameters
     ----------
-
+    arr_shape : Tuple[int, ...]
         Input array shape.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]]
         Size of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Overlap between patches in each dimension.
     """
     n_patches = _compute_number_of_patches(arr_shape, patch_sizes)

@@ -123,6 +125,8 @@ def _compute_patch_views(
         Steps between views.
     output_shape : Tuple[int]
         Shape of the output array.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------

@@ -161,11 +165,13 @@ def extract_patches_sequential(
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------
-
-
+    Tuple[np.ndarray, Optional[np.ndarray]]
+        Patches.
     """
     is_3d_patch = len(patch_size) == 3
 
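The hunks above are docstring-only, but they pin down the helper contracts: `_compute_number_of_patches(arr_shape, patch_sizes)` expects shapes of equal length and returns per-dimension patch counts, and `extract_patches_sequential` now documents a `Tuple[np.ndarray, Optional[np.ndarray]]` return. A standalone sketch of the first contract, as an illustration rather than the careamics implementation:

import numpy as np
from typing import Tuple

def number_of_patches(arr_shape: Tuple[int, ...], patch_sizes: Tuple[int, ...]) -> Tuple[int, ...]:
    # Illustration only (not the careamics code): how many patches of each size
    # are needed to cover each dimension, mirroring the documented length check.
    if len(arr_shape) != len(patch_sizes):
        raise ValueError("arr_shape and patch_sizes must have the same length")
    return tuple(int(np.ceil(s / p)) for s, p in zip(arr_shape, patch_sizes))

print(number_of_patches((64, 64), (16, 16)))  # (4, 4)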

careamics/dataset/patching/tiled_patching.py
CHANGED

@@ -1,3 +1,5 @@
+"""Tiled patching utilities."""
+
 import itertools
 from typing import Generator, List, Tuple, Union
 

@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
 
 def _compute_crop_and_stitch_coords_1d(
     axis_size: int, tile_size: int, overlap: int
-) -> Tuple[List[Tuple[int,
+) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
    """
    Compute the coordinates of each tile along an axis, given the overlap.
 
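The return annotation fixed above spells out that `_compute_crop_and_stitch_coords_1d` yields three lists of integer pairs per axis, presumably (start, stop) coordinates. The sketch below illustrates overlapping 1D tile coordinates in general; it is not the careamics implementation, and clamping the last tile to the axis end is an assumption:

from typing import List, Tuple

def tile_coords_1d(axis_size: int, tile_size: int, overlap: int) -> List[Tuple[int, int]]:
    # Illustration only: advance by (tile_size - overlap) and clamp the last
    # tile to the end of the axis, similar in spirit to the private helper above.
    step = tile_size - overlap
    coords = []
    start = 0
    while start + tile_size < axis_size:
        coords.append((start, start + tile_size))
        start += step
    coords.append((max(axis_size - tile_size, 0), axis_size))
    return coords

print(tile_coords_1d(axis_size=100, tile_size=48, overlap=16))
# [(0, 48), (32, 80), (52, 100)]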

careamics/lightning_datamodule.py
CHANGED

@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
         Batch size.
     use_in_memory : bool
         Whether to use in memory dataset if possible.
-    train_data : Union[Path,
+    train_data : Union[Path, np.ndarray]
         Training data.
-    val_data : Optional[Union[Path,
+    val_data : Optional[Union[Path, np.ndarray]]
         Validation data.
-    train_data_target : Optional[Union[Path,
+    train_data_target : Optional[Union[Path, np.ndarray]]
         Training target data.
-    val_data_target : Optional[Union[Path,
+    val_data_target : Optional[Union[Path, np.ndarray]]
         Validation target data.
     val_percentage : float
         Percentage of the training data to use for validation, if no validation data is

@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
         )
 
         # configuration
-        self.data_config = data_config
-        self.data_type = data_config.data_type
-        self.batch_size = data_config.batch_size
-        self.use_in_memory = use_in_memory
+        self.data_config: DataConfig = data_config
+        self.data_type: str = data_config.data_type
+        self.batch_size: int = data_config.batch_size
+        self.use_in_memory: bool = use_in_memory
+
+        # data: make data Path or np.ndarray, use type annotations for mypy
+        self.train_data: Union[Path, np.ndarray] = (
+            Path(train_data) if isinstance(train_data, str) else train_data
+        )
+
+        self.val_data: Union[Path, np.ndarray] = (
+            Path(val_data) if isinstance(val_data, str) else val_data
+        )
+
+        self.train_data_target: Union[Path, np.ndarray] = (
+            Path(train_data_target)
+            if isinstance(train_data_target, str)
+            else train_data_target
+        )
 
-
-
-
+        self.val_data_target: Union[Path, np.ndarray] = (
+            Path(val_data_target)
+            if isinstance(val_data_target, str)
+            else val_data_target
+        )
 
-
-        self.val_data_target = val_data_target
+        # validation split
         self.val_percentage = val_percentage
         self.val_minimum_split = val_minimum_split
 

@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
         elif data_config.data_type != SupportedData.ARRAY:
             self.read_source_func = get_read_func(data_config.data_type)
 
-        self.extension_filter = extension_filter
+        self.extension_filter: str = extension_filter
 
         # Pytorch dataloader parameters
-        self.dataloader_params = (
+        self.dataloader_params: Dict[str, Any] = (
             data_config.dataloader_params if data_config.dataloader_params else {}
         )
 

@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
         """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
+            # mypy checks
+            assert isinstance(self.train_data, np.ndarray)
+            if self.train_data_target is not None:
+                assert isinstance(self.train_data_target, np.ndarray)
+
            # train dataset
            self.train_dataset: DatasetType = InMemoryDataset(
                data_config=self.data_config,
                inputs=self.train_data,
-
+                input_target=self.train_data_target,
            )
 
            # validation dataset
            if self.val_data is not None:
+                # mypy checks
+                assert isinstance(self.val_data, np.ndarray)
+                if self.val_data_target is not None:
+                    assert isinstance(self.val_data_target, np.ndarray)
+
                # create its own dataset
                self.val_dataset: DatasetType = InMemoryDataset(
                    data_config=self.data_config,
                    inputs=self.val_data,
-
+                    input_target=self.val_data_target,
                )
            else:
                # extract validation from the training patches

@@ -341,7 +367,7 @@ class CAREamicsTrainData(L.LightningDataModule):
            self.train_dataset = InMemoryDataset(
                data_config=self.data_config,
                inputs=self.train_files,
-
+                input_target=(
                    self.train_target_files if self.train_data_target else None
                ),
                read_source_func=self.read_source_func,

@@ -352,7 +378,7 @@ class CAREamicsTrainData(L.LightningDataModule):
            self.val_dataset = InMemoryDataset(
                data_config=self.data_config,
                inputs=self.val_files,
-
+                input_target=(
                    self.val_target_files if self.val_data_target else None
                ),
                read_source_func=self.read_source_func,
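The constructor change above normalizes `str` inputs to `Path` while letting arrays pass through, which is what makes the later `isinstance(..., np.ndarray)` assertions for mypy sound. A standalone sketch of that pattern (not the CAREamics API itself):

from pathlib import Path
from typing import Union
import numpy as np

def as_path_or_array(data: Union[str, Path, np.ndarray]) -> Union[Path, np.ndarray]:
    # Strings become Path objects; Path objects and numpy arrays pass through,
    # mirroring the assignments added in the __init__ hunk above.
    return Path(data) if isinstance(data, str) else data

print(as_path_or_array("train/images"))          # PosixPath('train/images') on POSIX
print(as_path_or_array(np.zeros((8, 8))).shape)  # (8, 8)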
careamics/lightning_module.py
CHANGED

@@ -1,3 +1,5 @@
+"""CAREamics Lightning module."""
+
 from typing import Any, Optional, Union
 
 import pytorch_lightning as L

@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
     This class encapsulates the a PyTorch model along with the training, validation,
     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
+    Parameters
+    ----------
+    algorithm_config : Union[AlgorithmModel, dict]
+        Algorithm configuration.
+
     Attributes
     ----------
     model : nn.Module

@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
     """
 
     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
-        """
-        CAREamics Lightning module.
+        """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.

careamics/lightning_prediction_datamodule.py
CHANGED

@@ -1,7 +1,7 @@
 """Prediction Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L

@@ -303,9 +303,6 @@ class PredictDataWrapper(CAREamicsPredictData):
         Batch size.
     tta_transforms : bool, optional
         Use test time augmentation, by default True.
-    transforms : List, optional
-        List of transforms to apply to prediction patches. If None, default
-        transforms are applied.
     read_source_func : Optional[Callable], optional
         Function to read the source data, used if `data_type` is `custom`, by
         default None.

@@ -326,7 +323,6 @@ class PredictDataWrapper(CAREamicsPredictData):
         axes: str = "YX",
         batch_size: int = 1,
         tta_transforms: bool = True,
-        transforms: Optional[List] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,

@@ -356,9 +352,6 @@ class PredictDataWrapper(CAREamicsPredictData):
             Batch size.
         tta_transforms : bool, optional
             Use test time augmentation, by default True.
-        transforms : Optional[List], optional
-            List of transforms to apply to prediction patches. If None, default
-            transforms are applied.
         read_source_func : Optional[Callable], optional
             Function to read the source data, used if `data_type` is `custom`, by
             default None.

@@ -369,7 +362,7 @@ class PredictDataWrapper(CAREamicsPredictData):
         """
         if dataloader_params is None:
             dataloader_params = {}
-        prediction_dict = {
+        prediction_dict: Dict[str, Any] = {
             "data_type": data_type,
             "tile_size": tile_size,
             "tile_overlap": tile_overlap,

@@ -378,12 +371,9 @@ class PredictDataWrapper(CAREamicsPredictData):
             "std": std,
             "tta": tta_transforms,
             "batch_size": batch_size,
+            "transforms": [],
         }
 
-        # if transforms are passed (otherwise it will use the default ones)
-        if transforms is not None:
-            prediction_dict["transforms"] = transforms
-
         # validate configuration
         self.prediction_config = InferenceConfig(**prediction_dict)
 

careamics/lightning_prediction_loop.py
CHANGED

@@ -1,3 +1,5 @@
+"""Lithning prediction loop allowing tiling."""
+
 from typing import Optional
 
 import pytorch_lightning as L

@@ -18,14 +20,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
     """
 
     def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
-        """
-        Calls `on_predict_epoch_end` hook.
+        """Call `on_predict_epoch_end` hook.
 
         Adapted from the parent method.
 
         Returns
         -------
-
+        Optional[_PREDICT_OUTPUT]
+            Prediction output.
         """
         trainer = self.trainer
         call._call_callback_hooks(trainer, "on_predict_epoch_end")

@@ -45,15 +47,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
 
     @_no_grad_context
     def run(self) -> Optional[_PREDICT_OUTPUT]:
-        """
-        Runs the prediction loop.
+        """Run the prediction loop.
 
         Adapted from the parent method in order to stitch the predictions.
 
         Returns
         -------
         Optional[_PREDICT_OUTPUT]
-            Prediction output
+            Prediction output.
         """
         self.setup_data()
         if self.skip:

@@ -86,6 +87,7 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
 
             ########################################################
             ################ CAREamics specific code ###############
+            # TODO: next line is not compatible with muSplit
             is_tiled = len(self.predictions[batch_idx]) == 2
             if is_tiled:
                 # extract the last tile flag and the coordinates (crop and stitch)
careamics/losses/__init__.py
CHANGED

@@ -1,6 +1,5 @@
 """Losses module."""
 
-
+__all__ = ["loss_factory"]
 
-
-# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
+from .loss_factory import loss_factory
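With the explicit re-export in place, the factory can be imported from the subpackage root; a one-line sketch of the resulting import:

from careamics.losses import loss_factory  # re-exported by careamics/losses/__init__.py as of this change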
careamics/losses/loss_factory.py
CHANGED
careamics/losses/losses.py
CHANGED

@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
 """
 
 import torch
-
-# TODO if we are only using the DiceLoss, can we just implement it?
-# from segmentation_models_pytorch.losses import DiceLoss
 from torch.nn import L1Loss, MSELoss
 
 
-def mse_loss(
+def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Mean squared error loss.
 
+    Parameters
+    ----------
+    source : torch.Tensor
+        Source patches.
+    target : torch.Tensor
+        Target patches.
+
     Returns
     -------
     torch.Tensor
         Loss value.
     """
     loss = MSELoss()
-    return loss(
+    return loss(source, target)
 
 
 def n2v_loss(

@@ -34,9 +38,9 @@ def n2v_loss(
 
     Parameters
     ----------
-
+    manipulated_patches : torch.Tensor
         Patches with manipulated pixels.
-
+    original_patches : torch.Tensor
         Noisy patches.
     masks : torch.Tensor
         Array containing masked pixel locations.
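The new `mse_loss` signature shown above is self-contained enough to restate as a runnable sketch; the body below mirrors the hunk (an `MSELoss` instance applied to `source` and `target`) and is only meant to illustrate the change:

import torch
from torch.nn import MSELoss

def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Mirrors the hunk above: plain mean squared error between source and target patches.
    return MSELoss()(source, target)

source = torch.rand(4, 1, 16, 16)
target = torch.rand(4, 1, 16, 16)
print(mse_loss(source, target).item())  # a scalar loss value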
careamics/model_io/bmz_io.py
CHANGED

@@ -104,9 +104,9 @@ def export_to_bmz(
     authors : List[dict]
         Authors of the model.
     input_array : np.ndarray
-        Input array.
+        Input array, should not have been normalized.
     output_array : np.ndarray
-        Output array.
+        Output array, should have been denormalized.
     channel_names : Optional[List[str]], optional
         Channel names, by default None.
     data_description : Optional[str], optional

@@ -178,7 +178,7 @@ def export_to_bmz(
     )
 
     # test model description
-    summary: ValidationSummary = test_model(model_description, decimal=
+    summary: ValidationSummary = test_model(model_description, decimal=2)
     if summary.status == "failed":
         raise ValueError(f"Model description test failed: {summary}")
 
careamics/models/activation.py
CHANGED
careamics/models/layers.py
CHANGED

@@ -162,6 +162,18 @@ def _unpack_kernel_size(
     """Unpack kernel_size to a tuple of ints.
 
     Inspired by Kornia implementation. TODO: link
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Kernel size tuple.
     """
     if isinstance(kernel_size, int):
         kernel_dims = tuple([kernel_size for _ in range(dim)])

@@ -173,7 +185,20 @@ def _unpack_kernel_size(
 def _compute_zero_padding(
     kernel_size: Union[Tuple[int, ...], int], dim: int
 ) -> Tuple[int, ...]:
-    """Utility function that computes zero padding tuple.
+    """Utility function that computes zero padding tuple.
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Zero padding tuple.
+    """
     kernel_dims = _unpack_kernel_size(kernel_size, dim)
     return tuple([(kd - 1) // 2 for kd in kernel_dims])
 

@@ -191,14 +216,19 @@ def get_pascal_kernel_1d(
 
     Parameters
     ----------
-    kernel_size:
-
-
-
+    kernel_size : int
+        Kernel size.
+    norm : bool
+        Normalize the kernel, by default False.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------

@@ -245,19 +275,28 @@ def _get_pascal_kernel_nd(
 ) -> torch.Tensor:
     """Generate pascal filter kernel by kernel size.
 
+    If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
+    otherwise the kernel will be shaped as kernel_size
+
     Inspired by Kornia implementation.
 
     Parameters
     ----------
-    kernel_size:
-
-
-
+    kernel_size : Union[Tuple[int, int], int]
+        Kernel size for the pascal kernel.
+    norm : bool
+        Normalize the kernel, by default True.
+    dim : int
+        Number of dimensions, by default 2.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-
-
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------

@@ -303,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool2d(

@@ -323,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool3d(

@@ -343,21 +418,16 @@ class MaxBlurPool(nn.Module):
 
     Parameters
     ----------
-    dim: int
-        Toggles between 2D and 3D
-    kernel_size: Union[Tuple[int, int], int]
+    dim : int
+        Toggles between 2D and 3D.
+    kernel_size : Union[Tuple[int, int], int]
         Kernel size for max pooling.
-    stride: int
+    stride : int
         Stride for pooling.
-    max_pool_size: int
+    max_pool_size : int
         Max kernel size for max pooling.
-    ceil_mode: bool
-
-
-    Returns
-    -------
-    torch.Tensor
-        The pooled and blurred tensor.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
     """
 
     def __init__(

@@ -368,6 +438,21 @@ class MaxBlurPool(nn.Module):
         max_pool_size: int = 2,
         ceil_mode: bool = False,
     ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        dim : int
+            Dimension of the convolution.
+        kernel_size : Union[Tuple[int, int], int]
+            Kernel size for max pooling.
+        stride : int, optional
+            Stride, by default 2.
+        max_pool_size : int, optional
+            Maximum pool size, by default 2.
+        ceil_mode : bool, optional
+            Ceil mode, by default False. Set to True to match output size of conv2d.
+        """
         super().__init__()
         self.dim = dim
         self.kernel_size = kernel_size

@@ -377,7 +462,18 @@ class MaxBlurPool(nn.Module):
         self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the function.
+        """Forward pass of the function.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
         self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
         if self.dim == 2:
             return _max_blur_pool_by_kernel2d(
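The docstrings added above spell out the `MaxBlurPool` constructor (`dim`, `kernel_size`, `stride=2`, `max_pool_size=2`, `ceil_mode=False`) and the Pascal kernel helpers. A hedged usage sketch: the import path follows the file name in this diff, the `pascal_row` helper is only an illustration of what a 1D Pascal (binomial) kernel contains, and the expected halving of spatial size assumes the default stride of 2:

import math
import torch

# Assumed import path, matching careamics/models/layers.py from the file list above.
from careamics.models.layers import MaxBlurPool

def pascal_row(kernel_size: int) -> list:
    # Binomial coefficients C(n-1, k), i.e. a row of Pascal's triangle before
    # normalization; e.g. kernel_size=5 -> [1, 4, 6, 4, 1].
    return [math.comb(kernel_size - 1, k) for k in range(kernel_size)]

print(pascal_row(5))

pool = MaxBlurPool(dim=2, kernel_size=3)
x = torch.rand(1, 8, 32, 32)
y = pool(x)
print(y.shape)  # roughly (1, 8, 16, 16) with the default stride of 2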