careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/manipulation/__init__.py DELETED
@@ -1,4 +0,0 @@
- """Pixel manipulation functions for N2V."""
-
-
- from .pixel_manipulation import default_manipulate as default_manipulate
careamics/manipulation/pixel_manipulation.py DELETED
@@ -1,158 +0,0 @@
- """
- Pixel manipulation methods.
-
- Pixel manipulation is used in N2V and similar algorithms to replace the value of
- masked pixels.
- """
- from typing import Callable, Optional, Tuple
-
- import numpy as np
-
-
- def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
-     """
-     Randomly sample a jitter to be applied to the masking grid.
-
-     This is done to account for cases where the step size is not an integer.
-
-     Parameters
-     ----------
-     step : float
-         Step size of the grid, output of np.linspace.
-     rng : np.random.Generator
-         Random number generator.
-
-     Returns
-     -------
-     np.ndarray
-         Array of random jitter to be added to the grid.
-     """
-     # Define the random jitter to be added to the grid
-     odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2))
-
-     # Round the step size to the nearest integer depending on the jitter
-     return np.floor(step) if odd_jitter == 0 else np.ceil(step)
-
-
- def get_stratified_coords(
-     mask_pixel_perc: float,
-     shape: Tuple[int, ...],
- ) -> np.ndarray:
-     """
-     Generate coordinates of the pixels to mask.
-
-     Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
-     the distance between masked pixels is approximately the same.
-
-     Parameters
-     ----------
-     mask_pixel_perc : float
-         Actual (quasi) percentage of masked pixels across the whole image. Used in
-         calculating the distance between masked pixels across each axis.
-     shape : Tuple[int, ...]
-         Shape of the input patch.
-
-     Returns
-     -------
-     np.ndarray
-         Array of coordinates of the masked pixels.
-     """
-     rng = np.random.default_rng()
-
-     # Define the approximate distance between masked pixels
-     mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
-         np.int32
-     )
-
-     # Define a grid of coordinates for each axis in the input patch and the step size
-     pixel_coords = []
-     for axis_size in shape:
-         # make sure axis size is evenly divisible by box size
-         num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
-         axis_pixel_coords, step = np.linspace(
-             0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
-         )
-         # explain
-         pixel_coords.append(axis_pixel_coords.T)
-
-     # Create a meshgrid of coordinates for each axis in the input patch
-     coordinate_grid_list = np.meshgrid(*pixel_coords)
-     coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
-
-     grid_random_increment = rng.integers(
-         _odd_jitter_func(float(step), rng)
-         * np.ones_like(coordinate_grid).astype(np.int32)
-         - 1,
-         size=coordinate_grid.shape,
-         endpoint=True,
-     )
-     coordinate_grid += grid_random_increment
-     coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
-     return coordinate_grid
-
-
- def default_manipulate(
-     patch: np.ndarray,
-     mask_pixel_percentage: float,
-     roi_size: int = 11,
-     augmentations: Optional[Callable] = None,
- ) -> Tuple[np.ndarray, ...]:
-     """
-     Manipulate pixels in a patch, i.e. replace the masked values.
-
-     Parameters
-     ----------
-     patch : np.ndarray
-         Image patch, 2D or 3D, shape (y, x) or (z, y, x).
-     mask_pixel_percentage : float
-         Approximate percentage of pixels to be masked.
-     roi_size : int
-         Size of the ROI the new pixel value is sampled from, by default 11.
-     augmentations : Callable, optional
-         Augmentations to apply, by default None.
-
-     Returns
-     -------
-     Tuple[np.ndarray, ...]
-         Tuple containing the manipulated patch, the original patch and the mask.
-     """
-     original_patch = patch.copy()
-
-     # Get the coordinates of the pixels to be replaced
-     roi_centers = get_stratified_coords(mask_pixel_percentage, patch.shape)
-     rng = np.random.default_rng()
-
-     # Generate coordinate grid for ROI
-     roi_span_full = np.arange(-np.floor(roi_size / 2), np.ceil(roi_size / 2)).astype(
-         np.int32
-     )
-     # Remove the center pixel from the grid
-     roi_span_wo_center = roi_span_full[roi_span_full != 0]
-
-     # Randomly select coordinates from the grid
-     random_increment = rng.choice(roi_span_wo_center, size=roi_centers.shape)
-
-     # Clip the coordinates to the patch size
-     replacement_coords = np.clip(
-         roi_centers + random_increment,
-         0,
-         [patch.shape[i] - 1 for i in range(len(patch.shape))],
-     )
-     # Get the replacement pixels from all rois
-     replacement_pixels = patch[tuple(replacement_coords.T.tolist())]
-
-     # Replace the original pixels with the replacement pixels
-     patch[tuple(roi_centers.T.tolist())] = replacement_pixels
-     mask = np.where(patch != original_patch, 1, 0).astype(np.uint8)
-
-     patch, original_patch, mask = (
-         (patch, original_patch, mask)
-         if augmentations is None
-         else augmentations(patch, original_patch, mask)
-     )
-
-     return (
-         np.expand_dims(patch, 0),
-         np.expand_dims(original_patch, 0),
-         np.expand_dims(mask, 0),
-     )
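For context, here is a minimal usage sketch of the rc1 masking API removed above. The 64x64 patch and the 0.5% masking rate are illustrative assumptions, not defaults taken from the package.

import numpy as np

# `default_manipulate` is the rc1 function deleted in the hunk above
patch = np.random.default_rng(0).normal(size=(64, 64)).astype(np.float32)
masked, original, mask = default_manipulate(patch.copy(), mask_pixel_percentage=0.5)

# Each output gains a leading singleton (channel) dimension
assert masked.shape == original.shape == mask.shape == (1, 64, 64)
# The mask flags exactly the pixels whose values were replaced
assert np.array_equal((masked != original).astype(np.uint8), mask)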
careamics/prediction/prediction_utils.py DELETED
@@ -1,102 +0,0 @@
- """
- Prediction convenience functions.
-
- These functions are used during prediction.
- """
- from typing import List
-
- import numpy as np
- import torch
-
-
- def stitch_prediction(
-     tiles: List[np.ndarray],
-     stitching_data: List,
- ) -> np.ndarray:
-     """
-     Stitch tiles back together to form a full image.
-
-     Parameters
-     ----------
-     tiles : List[Tuple[np.ndarray, List[int]]]
-         Cropped tiles and their respective stitching coordinates.
-     stitching_data : List
-         List of coordinates obtained from
-         dataset.tiling.compute_crop_and_stitch_coords_1d.
-
-     Returns
-     -------
-     np.ndarray
-         Full image.
-     """
-     # Get whole sample shape
-     input_shape = stitching_data[0][0]
-     predicted_image = np.zeros(input_shape, dtype=np.float32)
-     for tile, (_, overlap_crop_coords, stitch_coords) in zip(tiles, stitching_data):
-         # Compute coordinates for cropping predicted tile
-         slices = tuple([slice(c[0], c[1]) for c in overlap_crop_coords])
-
-         # Crop predicted tile according to overlap coordinates
-         cropped_tile = tile.squeeze()[slices]
-
-         # Insert cropped tile into predicted image using stitch coordinates
-         predicted_image[
-             (..., *[slice(c[0], c[1]) for c in stitch_coords])
-         ] = cropped_tile
-     return predicted_image
-
-
- def tta_forward(x: np.ndarray) -> List:
-     """
-     Augment 8-fold an array.
-
-     The augmentation is performed using all 90 deg rotations and their flipped version,
-     as well as the original image flipped.
-
-     Parameters
-     ----------
-     x : torch.tensor
-         Data to augment.
-
-     Returns
-     -------
-     List
-         Stack of augmented images.
-     """
-     x_aug = [
-         x,
-         torch.rot90(x, 1, dims=(2, 3)),
-         torch.rot90(x, 2, dims=(2, 3)),
-         torch.rot90(x, 3, dims=(2, 3)),
-     ]
-     x_aug_flip = x_aug.copy()
-     for x_ in x_aug:
-         x_aug_flip.append(torch.flip(x_, dims=(1, 3)))
-     return x_aug_flip
-
-
- def tta_backward(x_aug: List) -> np.ndarray:
-     """
-     Invert `tta_forward` and average the 8 images.
-
-     Parameters
-     ----------
-     x_aug : List
-         Stack of 8-fold augmented images.
-
-     Returns
-     -------
-     np.ndarray
-         Average of de-augmented x_aug.
-     """
-     x_deaug = [
-         x_aug[0],
-         np.rot90(x_aug[1], -1),
-         np.rot90(x_aug[2], -2),
-         np.rot90(x_aug[3], -3),
-         np.fliplr(x_aug[4]),
-         np.rot90(np.fliplr(x_aug[5]), -1),
-         np.rot90(np.fliplr(x_aug[6]), -2),
-         np.rot90(np.fliplr(x_aug[7]), -3),
-     ]
-     return np.mean(x_deaug, 0)
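To illustrate the removed test-time augmentation helpers, a small round-trip sketch: with an identity "model", averaging the de-augmented views recovers the input. The tensor shape and values are illustrative assumptions.

import numpy as np
import torch

x = torch.arange(16.0).reshape(1, 1, 4, 4)         # (S, C, Y, X)

views = tta_forward(x)                              # 4 rotations plus their X-flips
outputs = [v.squeeze().numpy() for v in views]      # stand-in for model predictions

restored = tta_backward(outputs)                    # de-rotate/de-flip and average
assert np.allclose(restored, x.squeeze().numpy())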
careamics/utils/ascii_logo.txt DELETED
@@ -1,9 +0,0 @@
- ...... ...... ........ ........ ....
- -+++----+- -+++--+++- :+++---+++: :+++----- .--:
- .+++ .: +++. .+++. :+++ :+++ :+++ :------. .---:----..:----. :--- :----: :----:.
- .+++ .+++. .+++. :+++ -++= :+++ +=....=+++ :+++-..=+++-..=++= -+++ .+++-..++ +++-..=+.
- .+++ .++++++++++. :++++++++=. :++++++: .+++. :+++ :+++ -+++ -+++ :+++ .+++=.
- .+++ .+++. .+++. :+++ -+++ :+++ :=++==++++. :+++ :+++ -+++ -+++ :+++ .-=+++=:
- .+++ .. .+++. .+++. :+++ :+++ :+++ .+++. .+++. :+++ :+++ -+++ -+++ :+++ .. .. :+++.
- -++=-::-+= .+++. .+++. :+++ :+++ :+++-:::: =++=--=+++. :+++ :+++ -+++ -+++ =++=:-+= =+-:=++=
- ...... ... ... ... ... ........ .... ... ... ... .... .... .... .....
careamics/utils/augment.py DELETED
@@ -1,65 +0,0 @@
- """Augmentation module."""
- from typing import Tuple
-
- import numpy as np
-
-
- # TODO: unused?
- def _flip_and_rotate(
-     image: np.ndarray, rotate_state: int, flip_state: int
- ) -> np.ndarray:
-     """
-     Apply the given number of 90 degrees rotations and flip to an array.
-
-     Parameters
-     ----------
-     image : np.ndarray
-         Array containing single image or patch, 2D or 3D.
-     rotate_state : int
-         Number of 90 degree rotations to apply.
-     flip_state : int
-         0 or 1, whether to flip the array or not.
-
-     Returns
-     -------
-     np.ndarray
-         Flipped and rotated array.
-     """
-     rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))
-     flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated
-     return flipped.copy()
-
-
- def augment_batch(
-     patch: np.ndarray,
-     original_image: np.ndarray,
-     mask: np.ndarray,
-     seed: int = 42,
- ) -> Tuple[np.ndarray, ...]:
-     """
-     Apply augmentation function to patches and masks.
-
-     Parameters
-     ----------
-     patch : np.ndarray
-         Array containing single image or patch, 2D or 3D with masked pixels.
-     original_image : np.ndarray
-         Array containing original image or patch, 2D or 3D.
-     mask : np.ndarray
-         Array containing only masked pixels, 2D or 3D.
-     seed : int, optional
-         Seed for random number generator, controls the rotation and flipping.
-
-     Returns
-     -------
-     Tuple[np.ndarray, ...]
-         Tuple of augmented arrays.
-     """
-     rng = np.random.default_rng(seed=seed)
-     rotate_state = rng.integers(0, 4)
-     flip_state = rng.integers(0, 2)
-     return (
-         _flip_and_rotate(patch, rotate_state, flip_state),
-         _flip_and_rotate(original_image, rotate_state, flip_state),
-         _flip_and_rotate(mask, rotate_state, flip_state),
-     )
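A short sketch of the removed augment_batch helper: the same random rotation and flip are applied to the patch, the original image and the mask, keeping them aligned. The toy arrays and seed are illustrative assumptions.

import numpy as np

patch = np.arange(16.0).reshape(4, 4)
original = patch.copy()
mask = (patch > 10).astype(np.uint8)

aug_patch, aug_original, aug_mask = augment_batch(patch, original, mask, seed=42)

# Identical inputs stay identical after the shared transformation
assert np.array_equal(aug_patch, aug_original)
assert aug_mask.shape == mask.shape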
careamics/utils/normalization.py DELETED
@@ -1,55 +0,0 @@
- """
- Normalization submodule.
-
- These methods are used to normalize and denormalize images.
- """
- import numpy as np
-
-
- def normalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
-     """
-     Normalize an image using mean and standard deviation.
-
-     Images are normalised by subtracting the mean and dividing by the standard
-     deviation.
-
-     Parameters
-     ----------
-     img : np.ndarray
-         Image to normalize.
-     mean : float
-         Mean.
-     std : float
-         Standard deviation.
-
-     Returns
-     -------
-     np.ndarray
-         Normalized array.
-     """
-     zero_mean = img - mean
-     return zero_mean / std
-
-
- def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray:
-     """
-     Denormalize an image using mean and standard deviation.
-
-     Images are denormalised by multiplying by the standard deviation and adding the
-     mean.
-
-     Parameters
-     ----------
-     img : np.ndarray
-         Image to denormalize.
-     mean : float
-         Mean.
-     std : float
-         Standard deviation.
-
-     Returns
-     -------
-     np.ndarray
-         Denormalized array.
-     """
-     return img * std + mean
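A round-trip sketch of the removed normalize/denormalize pair, using statistics computed from a toy image (the image itself is an illustrative assumption):

import numpy as np

img = np.random.default_rng(0).normal(loc=100.0, scale=10.0, size=(32, 32))
mean, std = img.mean(), img.std()

norm = normalize(img, mean, std)         # (img - mean) / std
restored = denormalize(norm, mean, std)  # norm * std + mean

assert np.allclose(restored, img)
assert abs(norm.mean()) < 1e-6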
careamics/utils/validators.py DELETED
@@ -1,156 +0,0 @@
- """
- Validator functions.
-
- These functions are used to validate dimensions and axes of inputs.
- """
- from typing import List
-
- import numpy as np
-
- AXES = "STCZYX"
-
-
- def check_axes_validity(axes: str) -> bool:
-     """
-     Sanity check on axes.
-
-     The constraints on the axes are the following:
-     - must be a combination of 'STCZYX'
-     - must not contain duplicates
-     - must contain at least 2 contiguous axes: X and Y
-     - must contain at most 4 axes
-     - cannot contain both S and T axes
-     - C is currently not allowed
-
-     Parameters
-     ----------
-     axes : str
-         Axes to validate.
-
-     Returns
-     -------
-     bool
-         True if axes are valid, False otherwise.
-     """
-     _axes = axes.upper()
-
-     # Minimum is 2 (XY) and maximum is 4 (TZYX)
-     if len(_axes) < 2 or len(_axes) > 4:
-         raise ValueError(
-             f"Invalid axes {axes}. Must contain at least 2 and at most 4 axes."
-         )
-
-     # all characters must be in REF_AXES = 'STCZYX'
-     if not all(s in AXES for s in _axes):
-         raise ValueError(f"Invalid axes {axes}. Must be a combination of {AXES}.")
-
-     # check for repeating characters
-     for i, s in enumerate(_axes):
-         if i != _axes.rfind(s):
-             raise ValueError(
-                 f"Invalid axes {axes}. Cannot contain duplicate axes"
-                 f" (got multiple {axes[i]})."
-             )
-
-     # currently no implementation for C
-     if "C" in _axes:
-         raise NotImplementedError("Currently, C axis is not supported.")
-
-     # prevent S and T axes together
-     if "T" in _axes and "S" in _axes:
-         raise NotImplementedError(
-             f"Invalid axes {axes}. Cannot contain both S and T axes."
-         )
-
-     # prior: X and Y contiguous (#FancyComments)
-     # right now the next check is invalidating this, but in the future, we might
-     # allow random order of axes (or at least XY and YX)
-     if "XY" not in _axes and "YX" not in _axes:
-         raise ValueError(f"Invalid axes {axes}. X and Y must be contiguous.")
-
-     # check that the axes are in the right order
-     for i, s in enumerate(_axes):
-         if i < len(_axes) - 1:
-             index_s = AXES.find(s)
-             index_next = AXES.find(_axes[i + 1])
-
-             if index_s > index_next:
-                 raise ValueError(
-                     f"Invalid axes {axes}. Axes must be in the order {AXES}."
-                 )
-
-     return True
-
-
- def check_array_validity(array: np.ndarray, axes: str) -> None:
-     """
-     Check that the numpy array is compatible with the axes.
-
-     Parameters
-     ----------
-     array : np.ndarray
-         Numpy array.
-     axes : str
-         Valid axes (see check_axes_validity).
-     """
-     if len(array.shape) - 2 != len(axes):
-         raise ValueError(
-             f"Array has {len(array.shape)} dimensions, but axes are {len(axes)}."
-             f"Externally provided arrays must have extra dimensions for batch and"
-             f"channel to be compatible with the batchnorm layers."
-         )
-
-
- def check_tiling_validity(tile_shape: List[int], overlaps: List[int]) -> None:
-     """
-     Check that the tiling parameters are valid.
-
-     Parameters
-     ----------
-     tile_shape : List[int]
-         Shape of the tiles.
-     overlaps : List[int]
-         Overlap between tiles.
-
-     Raises
-     ------
-     ValueError
-         If one of the parameters is None.
-     ValueError
-         If one of the elements is zero.
-     ValueError
-         If one of the elements is not divisible by 2.
-     ValueError
-         If the number of elements in `overlaps` and `tile_shape` is different.
-     ValueError
-         If one of the overlaps is larger than the corresponding tile shape.
-     """
-     # cannot be None
-     if tile_shape is None or overlaps is None:
-         raise ValueError(
-             "Cannot use tiling without specifying `tile_shape` and "
-             "`overlaps`, make sure they have been correctly specified."
-         )
-
-     # non-zero and divisible by two
-     for dims_list in [tile_shape, overlaps]:
-         for dim in dims_list:
-             if dim < 1:
-                 raise ValueError(f"Entry must be non-null positive (got {dim}).")
-
-             if dim % 2 != 0:
-                 raise ValueError(f"Entry must be divisible by 2 (got {dim}).")
-
-     # same length
-     if len(overlaps) != len(tile_shape):
-         raise ValueError(
-             f"Overlaps ({len(overlaps)}) and tile shape ({len(tile_shape)}) must "
-             f"have the same number of dimensions."
-         )
-
-     # overlaps smaller than tile shape
-     for overlap, tile_dim in zip(overlaps, tile_shape):
-         if overlap >= tile_dim:
-             raise ValueError(
-                 f"Overlap ({overlap}) must be smaller than tile shape ({tile_dim})."
-             )
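To make the removed axes and tiling checks concrete, a few illustrative calls (the axes strings and tile sizes are example values):

# Valid axes follow the STCZYX convention and keep X and Y contiguous
assert check_axes_validity("YX")
assert check_axes_validity("SZYX")

try:
    check_axes_validity("XYZ")   # X before Y violates the STCZYX order
except ValueError as err:
    print(err)

# Tiles must have even, matching dimensions, with overlaps smaller than the tiles
check_tiling_validity(tile_shape=[64, 64], overlaps=[32, 32])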
careamics/utils/wandb.py DELETED
@@ -1,121 +0,0 @@
- """
- A WandB logger for CAREamics.
-
- Implements a WandB class for use within the Engine.
- """
- import sys
- from pathlib import Path
- from typing import Dict, Union
-
- import torch
- import wandb
-
- from ..config import Configuration
-
-
- def is_notebook() -> bool:
-     """
-     Check if the code is executed from a notebook or a qtconsole.
-
-     Returns
-     -------
-     bool
-         True if the code is executed from a notebook, False otherwise.
-     """
-     try:
-         from IPython import get_ipython
-
-         shell = get_ipython().__class__.__name__
-         if shell == "ZMQInteractiveShell":
-             return True  # Jupyter notebook or qtconsole
-         else:
-             return False
-     except (NameError, ModuleNotFoundError):
-         return False
-
-
- class WandBLogging:
-     """
-     WandB logging class.
-
-     Parameters
-     ----------
-     experiment_name : str
-         Name of the experiment.
-     log_path : Path
-         Path in which to save the WandB log.
-     config : Configuration
-         Configuration of the model.
-     model_to_watch : torch.nn.Module
-         Model.
-     save_code : bool, optional
-         Whether to save the code, by default True.
-     """
-
-     def __init__(
-         self,
-         experiment_name: str,
-         log_path: Path,
-         config: Configuration,
-         model_to_watch: torch.nn.Module,
-         save_code: bool = True,
-     ):
-         """
-         Constructor.
-
-         Parameters
-         ----------
-         experiment_name : str
-             Name of the experiment.
-         log_path : Path
-             Path in which to save the WandB log.
-         config : Configuration
-             Configuration of the model.
-         model_to_watch : torch.nn.Module
-             Model.
-         save_code : bool, optional
-             Whether to save the code, by default True.
-         """
-         self.run = wandb.init(
-             project="careamics-restoration",
-             dir=log_path,
-             name=experiment_name,
-             config=config.model_dump() if config else None,
-             # save_code=save_code,
-         )
-         if model_to_watch:
-             wandb.watch(model_to_watch, log="all", log_freq=1)
-         if save_code:
-             if is_notebook():
-                 # Get all sys path and select the root
-                 code_path = Path([p for p in sys.path if "caremics" in p][-1]).parent
-             else:
-                 code_path = Path("../")
-             self.log_code(code_path)
-
-     def log_metrics(self, metric_dict: Dict) -> None:
-         """
-         Log metrics to wandb.
-
-         Parameters
-         ----------
-         metric_dict : Dict
-             New metrics entry.
-         """
-         self.run.log(metric_dict, commit=True)
-
-     def log_code(self, code_path: Union[str, Path]) -> None:
-         """
-         Log code to wandb.
-
-         Parameters
-         ----------
-         code_path : Union[str, Path]
-             Path to the code.
-         """
-         self.run.log_code(
-             root=code_path,
-             include_fn=lambda path: path.endswith(".py")
-             or path.endswith(".yml")
-             or path.endswith(".yaml"),
-         )
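A hedged sketch of how the removed rc1 WandB logger was set up; the experiment name, log directory and stand-in model are placeholders, and a configured wandb account is assumed.

from pathlib import Path
import torch

model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in model

logger = WandBLogging(
    experiment_name="n2v-demo",            # placeholder name
    log_path=Path("wandb_logs"),           # placeholder directory
    config=None,                           # a Configuration instance in real use
    model_to_watch=model,
    save_code=False,                       # skip code upload in this sketch
)
logger.log_metrics({"epoch": 0, "train_loss": 0.123})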