careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,158 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Pixel manipulation methods.
|
|
3
|
-
|
|
4
|
-
Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
5
|
-
masked pixels.
|
|
6
|
-
"""
|
|
7
|
-
from typing import Callable, Optional, Tuple
|
|
8
|
-
|
|
9
|
-
import numpy as np
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
13
|
-
"""
|
|
14
|
-
Randomly sample a jitter to be applied to the masking grid.
|
|
15
|
-
|
|
16
|
-
This is done to account for cases where the step size is not an integer.
|
|
17
|
-
|
|
18
|
-
Parameters
|
|
19
|
-
----------
|
|
20
|
-
step : float
|
|
21
|
-
Step size of the grid, output of np.linspace.
|
|
22
|
-
rng : np.random.Generator
|
|
23
|
-
Random number generator.
|
|
24
|
-
|
|
25
|
-
Returns
|
|
26
|
-
-------
|
|
27
|
-
np.ndarray
|
|
28
|
-
Array of random jitter to be added to the grid.
|
|
29
|
-
"""
|
|
30
|
-
# Define the random jitter to be added to the grid
|
|
31
|
-
odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2))
|
|
32
|
-
|
|
33
|
-
# Round the step size to the nearest integer depending on the jitter
|
|
34
|
-
return np.floor(step) if odd_jitter == 0 else np.ceil(step)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def get_stratified_coords(
    mask_pixel_perc: float,
    shape: Tuple[int, ...],
) -> np.ndarray:
    """
    Generate coordinates of the pixels to mask.

    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
    the distance between masked pixels is approximately the same.

    A regular grid is laid over the patch with a spacing chosen so that roughly
    `mask_pixel_perc` percent of the pixels are selected. Each grid point is then
    shifted by a random offset in `[0, step - 1]`, where `step` is the (integer
    rounded) grid step of that point's own axis.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis. Must be
        strictly positive.
    shape : Tuple[int, ...]
        Shape of the input patch.

    Returns
    -------
    np.ndarray
        Array of shape (n_masked_pixels, len(shape)) with the coordinates of the
        masked pixels, clipped to the patch boundaries.

    Raises
    ------
    ValueError
        If `mask_pixel_perc` is not strictly positive.
    """
    if mask_pixel_perc <= 0:
        raise ValueError(
            f"mask_pixel_perc must be strictly positive (got {mask_pixel_perc})."
        )

    rng = np.random.default_rng()

    # Approximate distance between masked pixels along each axis. It is kept at
    # least 1 so that very large percentages cannot cause a division by zero.
    mask_pixel_distance = max(
        int(np.round((100 / mask_pixel_perc) ** (1 / len(shape)))), 1
    )

    # Grid coordinates along each axis, and the integer step of each axis
    pixel_coords = []
    axis_steps = []
    for axis_size in shape:
        # number of grid points needed to cover the axis with the given spacing
        num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
        axis_pixel_coords, step = np.linspace(
            0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
        )
        pixel_coords.append(axis_pixel_coords)
        # the step may be fractional: round it randomly per axis
        axis_steps.append(_odd_jitter_func(float(step), rng))

    # Cartesian product of the per-axis coordinates, one row per grid point
    coordinate_grid_list = np.meshgrid(*pixel_coords)
    coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T

    # Random offset in [0, step - 1] for each coordinate, using the step of the
    # coordinate's own axis (the columns of `coordinate_grid`).
    max_increment = np.broadcast_to(
        np.asarray(axis_steps, dtype=np.int64) - 1, coordinate_grid.shape
    )
    grid_random_increment = rng.integers(
        max_increment,
        size=coordinate_grid.shape,
        endpoint=True,
    )
    coordinate_grid = coordinate_grid + grid_random_increment
    coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
    return coordinate_grid
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def default_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    roi_size: int = 11,
    augmentations: Optional[Callable] = None,
) -> Tuple[np.ndarray, ...]:
    """
    Manipulate pixel in a patch, i.e. replace the masked value.

    Pixels are selected with `get_stratified_coords`. Each selected pixel is
    replaced by the value of another pixel drawn from a `roi_size`-wide
    neighbourhood around it, with the centre excluded. The input array is not
    modified.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    roi_size : int
        Size of the ROI the new pixel value is sampled from, by default 11.
    augmentations : Callable, optional
        Augmentations to apply, by default None. Called as
        `augmentations(patch, original_patch, mask)` and expected to return the
        three arrays.

    Returns
    -------
    Tuple[np.ndarray]
        Tuple containing the manipulated patch, the original patch and the mask,
        each with an extra leading singleton axis. The mask (uint8) is 1 at every
        pixel whose value was taken from a different location.
    """
    original_patch = patch.copy()
    # Work on a copy so the caller's array is never modified in place.
    patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    roi_centers = get_stratified_coords(mask_pixel_percentage, patch.shape)
    rng = np.random.default_rng()

    # Generate coordinate grid for ROI
    roi_span_full = np.arange(-np.floor(roi_size / 2), np.ceil(roi_size / 2)).astype(
        np.int32
    )
    # Remove the center pixel from the grid
    roi_span_wo_center = roi_span_full[roi_span_full != 0]

    # Randomly select coordinates from the grid
    random_increment = rng.choice(roi_span_wo_center, size=roi_centers.shape)

    # Clip the coordinates to the patch size
    replacement_coords = np.clip(
        roi_centers + random_increment,
        0,
        [patch.shape[i] - 1 for i in range(len(patch.shape))],
    )
    # Get the replacement pixels from all rois (read before any write)
    replacement_pixels = patch[tuple(replacement_coords.T.tolist())]

    # Replace the original pixels with the replacement pixels
    patch[tuple(roi_centers.T.tolist())] = replacement_pixels

    # Build the mask from coordinates, not from a value comparison: a pixel
    # replaced by an equal-valued neighbour is still masked, and NaNs elsewhere
    # are not. Centres whose clipped replacement coordinate is the centre itself
    # (possible at the patch border) were not actually replaced and are excluded.
    moved = np.any(replacement_coords != roi_centers, axis=1)
    mask = np.zeros(patch.shape, dtype=np.uint8)
    mask[tuple(roi_centers[moved].T)] = 1

    patch, original_patch, mask = (
        (patch, original_patch, mask)
        if augmentations is None
        else augmentations(patch, original_patch, mask)
    )

    return (
        np.expand_dims(patch, 0),
        np.expand_dims(original_patch, 0),
        np.expand_dims(mask, 0),
    )
|
|
@@ -1,106 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Prediction convenience functions.
|
|
3
|
-
|
|
4
|
-
These functions are used during prediction.
|
|
5
|
-
"""
|
|
6
|
-
from typing import List
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def stitch_prediction(
    tiles: List[np.ndarray],
    stitching_data: List,
) -> np.ndarray:
    """
    Stitch tiles back together to form a full image.

    Parameters
    ----------
    tiles : List[np.ndarray]
        Predicted tiles. Each tile is squeezed (all singleton axes removed)
        before being cropped.
    stitching_data : List
        One entry per tile, each a 3-tuple
        `(input_shape, overlap_crop_coords, stitch_coords)` obtained from
        dataset.tiling.compute_crop_and_stitch_coords_1d. The `input_shape` of
        the first entry defines the output shape. Coordinates are lists of
        `(start, end)` pairs, one per spatial axis.

    Returns
    -------
    np.ndarray
        Full image, float32.

    Raises
    ------
    ValueError
        If `stitching_data` is empty.
    """
    if len(stitching_data) == 0:
        raise ValueError("Cannot stitch prediction: `stitching_data` is empty.")

    # Get whole sample shape
    input_shape = stitching_data[0][0]
    predicted_image = np.zeros(input_shape, dtype=np.float32)

    for tile, (_, overlap_crop_coords, stitch_coords) in zip(tiles, stitching_data):
        # Region of the tile kept after removing the overlap
        crop_slices = tuple(slice(c[0], c[1]) for c in overlap_crop_coords)

        # Crop predicted tile according to overlap coordinates.
        # NOTE(review): `squeeze` also drops spatial axes of size 1; tiles are
        # assumed to have non-singleton spatial dimensions.
        cropped_tile = tile.squeeze()[crop_slices]

        # Insert cropped tile into predicted image using stitch coordinates
        stitch_slices = tuple(slice(c[0], c[1]) for c in stitch_coords)
        predicted_image[(..., *stitch_slices)] = cropped_tile

    return predicted_image
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def tta_forward(x: torch.Tensor) -> List[torch.Tensor]:
    """
    Augment 8-fold an array.

    The augmentation uses the four 90 degree rotations in the plane of the last
    two dimensions (YX), followed by the X-flipped version of each rotation.

    Tensors should be of shape SC(Z)YX, with S and C potentially singleton dimensions.

    Parameters
    ----------
    x : torch.Tensor
        Data to augment.

    Returns
    -------
    List
        Stack of augmented images: rotations by 0, 1, 2 and 3 quarter turns,
        then the same four rotations flipped along X.
    """
    rotations = [x] + [torch.rot90(x, k, dims=(-2, -1)) for k in (1, 2, 3)]
    # Flip along X only. Flipping dim 1 (C) would permute the channels seen by
    # the model whenever C > 1.
    flipped = [torch.flip(rotated, dims=(-1,)) for rotated in rotations]
    return rotations + flipped


def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray:
    """
    Invert `tta_forward` and average the 8 images.

    The function takes a list of torch tensors and returns a numpy array.

    Parameters
    ----------
    x_aug : List[torch.Tensor]
        Stack of 8-fold augmented images, in the order produced by `tta_forward`.

    Returns
    -------
    np.ndarray
        Average of de-augmented x_aug.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be converted
    arrays = [tensor.detach().cpu().numpy() for tensor in x_aug]

    # Undo the rotations, then undo flip + rotation for the flipped half
    x_deaug = [np.rot90(arr, -k, axes=(-2, -1)) for k, arr in enumerate(arrays[:4])]
    x_deaug += [
        np.rot90(np.flip(arr, axis=-1), -k, axes=(-2, -1))
        for k, arr in enumerate(arrays[4:8])
    ]
    return np.mean(x_deaug, 0)
|
careamics/utils/ascii_logo.txt
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
...... ...... ........ ........ ....
|
|
2
|
-
-+++----+- -+++--+++- :+++---+++: :+++----- .--:
|
|
3
|
-
.+++ .: +++. .+++. :+++ :+++ :+++ :------. .---:----..:----. :--- :----: :----:.
|
|
4
|
-
.+++ .+++. .+++. :+++ -++= :+++ +=....=+++ :+++-..=+++-..=++= -+++ .+++-..++ +++-..=+.
|
|
5
|
-
.+++ .++++++++++. :++++++++=. :++++++: .+++. :+++ :+++ -+++ -+++ :+++ .+++=.
|
|
6
|
-
.+++ .+++. .+++. :+++ -+++ :+++ :=++==++++. :+++ :+++ -+++ -+++ :+++ .-=+++=:
|
|
7
|
-
.+++ .. .+++. .+++. :+++ :+++ :+++ .+++. .+++. :+++ :+++ -+++ -+++ :+++ .. .. :+++.
|
|
8
|
-
-++=-::-+= .+++. .+++. :+++ :+++ :+++-:::: =++=--=+++. :+++ :+++ -+++ -+++ =++=:-+= =+-:=++=
|
|
9
|
-
...... ... ... ... ... ........ .... ... ... ... .... .... .... .....
|
careamics/utils/augment.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
"""Augmentation module."""
|
|
2
|
-
from typing import Tuple
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
# TODO: unused?
|
|
8
|
-
def _flip_and_rotate(
|
|
9
|
-
image: np.ndarray, rotate_state: int, flip_state: int
|
|
10
|
-
) -> np.ndarray:
|
|
11
|
-
"""
|
|
12
|
-
Apply the given number of 90 degrees rotations and flip to an array.
|
|
13
|
-
|
|
14
|
-
Parameters
|
|
15
|
-
----------
|
|
16
|
-
image : np.ndarray
|
|
17
|
-
Array containing single image or patch, 2D or 3D.
|
|
18
|
-
rotate_state : int
|
|
19
|
-
Number of 90 degree rotations to apply.
|
|
20
|
-
flip_state : int
|
|
21
|
-
0 or 1, whether to flip the array or not.
|
|
22
|
-
|
|
23
|
-
Returns
|
|
24
|
-
-------
|
|
25
|
-
np.ndarray
|
|
26
|
-
Flipped and rotated array.
|
|
27
|
-
"""
|
|
28
|
-
rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))
|
|
29
|
-
flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated
|
|
30
|
-
return flipped.copy()
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def augment_batch(
    patch: np.ndarray,
    original_image: np.ndarray,
    mask: np.ndarray,
    seed: int = 42,
) -> Tuple[np.ndarray, ...]:
    """
    Apply one random rotation/flip, identically, to a patch, its original and its mask.

    Parameters
    ----------
    patch : np.ndarray
        Array containing single image or patch, 2D or 3D with masked pixels.
    original_image : np.ndarray
        Array containing original image or patch, 2D or 3D.
    mask : np.ndarray
        Array containing only masked pixels, 2D or 3D.
    seed : int, optional
        Seed for random number generator, controls the rotation and flipping.

    Returns
    -------
    Tuple[np.ndarray, ...]
        Tuple of augmented arrays, in the order (patch, original_image, mask).

    Notes
    -----
    NOTE(review): the generator is re-seeded on every call, so with the default
    `seed=42` every call draws the same rotation and flip. Pass a different seed
    per call to obtain varied augmentations.
    """
    generator = np.random.default_rng(seed=seed)
    # Draw order (rotation first, then flip) determines the result for a given seed
    n_rotations = generator.integers(0, 4)
    do_flip = generator.integers(0, 2)
    return tuple(
        _flip_and_rotate(array, n_rotations, do_flip)
        for array in (patch, original_image, mask)
    )
|
careamics/utils/normalization.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Normalization submodule.
|
|
3
|
-
|
|
4
|
-
These methods are used to normalize and denormalize images.
|
|
5
|
-
"""
|
|
6
|
-
import numpy as np
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def normalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Standardise an image: subtract `mean`, then divide by `std`.

    Parameters
    ----------
    img : np.ndarray
        Image to normalize.
    mean : float
        Value subtracted from every pixel.
    std : float
        Value the centred image is divided by.

    Returns
    -------
    np.ndarray
        Normalized array.
    """
    return (img - mean) / std
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    """
    Undo `normalize`: scale by `std`, then shift by `mean`.

    Parameters
    ----------
    img : np.ndarray
        Image to denormalize.
    mean : float
        Value added back after scaling.
    std : float
        Factor every pixel is multiplied by.

    Returns
    -------
    np.ndarray
        Denormalized array.
    """
    rescaled = img * std
    return rescaled + mean
|
careamics/utils/validators.py
DELETED
|
@@ -1,170 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Validator functions.
|
|
3
|
-
|
|
4
|
-
These functions are used to validate dimensions and axes of inputs.
|
|
5
|
-
"""
|
|
6
|
-
from typing import List
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
|
|
10
|
-
AXES = "STCZYX"
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def check_axes_validity(axes: str) -> None:
|
|
14
|
-
"""
|
|
15
|
-
Sanity check on axes.
|
|
16
|
-
|
|
17
|
-
The constraints on the axes are the following:
|
|
18
|
-
- must be a combination of 'STCZYX'
|
|
19
|
-
- must not contain duplicates
|
|
20
|
-
- must contain at least 2 contiguous axes: X and Y
|
|
21
|
-
- must contain at most 4 axes
|
|
22
|
-
- cannot contain both S and T axes
|
|
23
|
-
- C is currently not allowed
|
|
24
|
-
|
|
25
|
-
Parameters
|
|
26
|
-
----------
|
|
27
|
-
axes : str
|
|
28
|
-
Axes to validate.
|
|
29
|
-
"""
|
|
30
|
-
_axes = axes.upper()
|
|
31
|
-
|
|
32
|
-
# Minimum is 2 (XY) and maximum is 4 (TZYX)
|
|
33
|
-
if len(_axes) < 2 or len(_axes) > 4:
|
|
34
|
-
raise ValueError(
|
|
35
|
-
f"Invalid axes {axes}. Must contain at least 2 and at most 4 axes."
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
# all characters must be in REF_AXES = 'STCZYX'
|
|
39
|
-
if not all(s in AXES for s in _axes):
|
|
40
|
-
raise ValueError(f"Invalid axes {axes}. Must be a combination of {AXES}.")
|
|
41
|
-
|
|
42
|
-
# check for repeating characters
|
|
43
|
-
for i, s in enumerate(_axes):
|
|
44
|
-
if i != _axes.rfind(s):
|
|
45
|
-
raise ValueError(
|
|
46
|
-
f"Invalid axes {axes}. Cannot contain duplicate axes"
|
|
47
|
-
f" (got multiple {axes[i]})."
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
# currently no implementation for C
|
|
51
|
-
if "C" in _axes:
|
|
52
|
-
raise NotImplementedError("Currently, C axis is not supported.")
|
|
53
|
-
|
|
54
|
-
# prevent S and T axes at the same time
|
|
55
|
-
if "T" in _axes and "S" in _axes:
|
|
56
|
-
raise NotImplementedError(
|
|
57
|
-
f"Invalid axes {axes}. Cannot contain both S and T axes."
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
# prior: X and Y contiguous (#FancyComments)
|
|
61
|
-
# right now the next check is invalidating this, but in the future, we might
|
|
62
|
-
# allow random order of axes (or at least XY and YX)
|
|
63
|
-
if "XY" not in _axes and "YX" not in _axes:
|
|
64
|
-
raise ValueError(f"Invalid axes {axes}. X and Y must be contiguous.")
|
|
65
|
-
|
|
66
|
-
# check that the axes are in the right order
|
|
67
|
-
for i, s in enumerate(_axes):
|
|
68
|
-
if i < len(_axes) - 1:
|
|
69
|
-
index_s = AXES.find(s)
|
|
70
|
-
index_next = AXES.find(_axes[i + 1])
|
|
71
|
-
|
|
72
|
-
if index_s > index_next:
|
|
73
|
-
raise ValueError(
|
|
74
|
-
f"Invalid axes {axes}. Axes must be in the order {AXES}."
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def add_axes(input_array: np.ndarray, axes: str) -> np.ndarray:
    """
    Add missing axes to the input, typically batch and channel.

    This method validates the axes first. Then it inspects the input array and adds
    missing dimensions if necessary, so that the result has 4 dimensions (2D data)
    or 5 dimensions (data with a Z axis).

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    axes : str
        Axes of `input_array`; case-insensitive, as in `check_axes_validity`.

    Returns
    -------
    np.ndarray
        Array with new singleton axes.
    """
    # validate axes
    check_axes_validity(axes)

    # `check_axes_validity` accepts lower-case axes, so normalise the case here
    # as well; otherwise e.g. "syx" or "zyx" would be misinterpreted below.
    _axes = axes.upper()

    # is 3D
    is_3D = "Z" in _axes

    # number of dims
    n_dims = 5 if is_3D else 4

    # array of dim 2, 3 or 4
    if len(input_array.shape) < n_dims:
        # missing sample axis: prepend it
        if "S" not in _axes and "T" not in _axes:
            input_array = input_array[np.newaxis, ...]

        # still missing C dimension: insert it after the sample axis
        if len(input_array.shape) < n_dims:
            input_array = input_array[:, np.newaxis, ...]

    return input_array
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def check_tiling_validity(tile_shape: List[int], overlaps: List[int]) -> None:
    """
    Check that the tiling parameters are valid.

    Parameters
    ----------
    tile_shape : List[int]
        Shape of the tiles.
    overlaps : List[int]
        Overlap between tiles.

    Raises
    ------
    ValueError
        If one of the parameters is None.
    ValueError
        If one of the elements is smaller than 1.
    ValueError
        If one of the elements is not divisible by 2.
    ValueError
        If `overlaps` and `tile_shape` have a different number of elements.
    ValueError
        If one of the overlaps is not smaller than the corresponding tile size.
    """
    if tile_shape is None or overlaps is None:
        raise ValueError(
            "Cannot use tiling without specifying `tile_shape` and "
            "`overlaps`, make sure they have been correctly specified."
        )

    # tile entries first, then overlap entries, each checked in order
    all_entries = (entry for group in (tile_shape, overlaps) for entry in group)
    for entry in all_entries:
        if entry < 1:
            raise ValueError(f"Entry must be non-null positive (got {entry}).")
        if entry % 2 != 0:
            raise ValueError(f"Entry must be divisible by 2 (got {entry}).")

    n_overlaps, n_tile_dims = len(overlaps), len(tile_shape)
    if n_overlaps != n_tile_dims:
        raise ValueError(
            f"Overlaps ({n_overlaps}) and tile shape ({n_tile_dims}) must "
            "have the same number of dimensions."
        )

    too_large = next(
        ((ov, size) for ov, size in zip(overlaps, tile_shape) if ov >= size), None
    )
    if too_large is not None:
        ov, size = too_large
        raise ValueError(
            f"Overlap ({ov}) must be smaller than tile shape ({size})."
        )
|
careamics/utils/wandb.py
DELETED
|
@@ -1,121 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
A WandB logger for CAREamics.
|
|
3
|
-
|
|
4
|
-
Implements a WandB class for use within the Engine.
|
|
5
|
-
"""
|
|
6
|
-
import sys
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Dict, Union
|
|
9
|
-
|
|
10
|
-
import torch
|
|
11
|
-
import wandb
|
|
12
|
-
|
|
13
|
-
from careamics.config import Configuration
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def is_notebook() -> bool:
    """
    Tell whether the code runs inside a Jupyter notebook or a qtconsole.

    Returns
    -------
    bool
        True when the active IPython shell is a ZMQInteractiveShell, False
        otherwise (including when IPython is not installed).
    """
    try:
        from IPython import get_ipython

        shell_name = get_ipython().__class__.__name__
    except (NameError, ModuleNotFoundError):
        return False

    # Jupyter notebook or qtconsole
    return shell_name == "ZMQInteractiveShell"
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class WandBLogging:
    """
    WandB logging class.

    Opens a WandB run in the "careamics-restoration" project, optionally watches
    a model, and optionally uploads source files (.py, .yml, .yaml).

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    log_path : Path
        Path in which to save the WandB log.
    config : Configuration
        Configuration of the model.
    model_to_watch : torch.nn.Module
        Model.
    save_code : bool, optional
        Whether to save the code, by default True.
    """

    def __init__(
        self,
        experiment_name: str,
        log_path: Path,
        config: Configuration,
        model_to_watch: torch.nn.Module,
        save_code: bool = True,
    ):
        """
        Constructor.

        Parameters
        ----------
        experiment_name : str
            Name of the experiment.
        log_path : Path
            Path in which to save the WandB log.
        config : Configuration
            Configuration of the model.
        model_to_watch : torch.nn.Module
            Model.
        save_code : bool, optional
            Whether to save the code, by default True.
        """
        # Code upload is handled explicitly by `log_code` below (gated by
        # `save_code`), not through an option of `wandb.init`.
        self.run = wandb.init(
            project="careamics-restoration",
            dir=log_path,
            name=experiment_name,
            config=config.model_dump() if config else None,
        )

        # Explicit None check rather than relying on module truthiness
        # (container modules define `__len__`).
        if model_to_watch is not None:
            wandb.watch(model_to_watch, log="all", log_freq=1)

        if save_code:
            code_path = Path("../")
            if is_notebook():
                # Locate the package through sys.path. The previous lookup
                # misspelled the package name ("caremics") and raised IndexError
                # when no entry matched; keep the default path in that case.
                candidates = [p for p in sys.path if "careamics" in p]
                if candidates:
                    code_path = Path(candidates[-1]).parent
            self.log_code(code_path)

    def log_metrics(self, metric_dict: Dict) -> None:
        """
        Log metrics to wandb.

        Parameters
        ----------
        metric_dict : Dict
            New metrics entry.
        """
        self.run.log(metric_dict, commit=True)

    def log_code(self, code_path: Union[str, Path]) -> None:
        """
        Log code to wandb.

        Only files ending in .py, .yml or .yaml are uploaded.

        Parameters
        ----------
        code_path : Union[str, Path]
            Path to the code.
        """
        self.run.log_code(
            root=code_path,
            include_fn=lambda path: path.endswith((".py", ".yml", ".yaml")),
        )
|