careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,124 @@
1
+ """A class chaining transforms together."""
2
+
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.data_model import TRANSFORMS_UNION
8
+
9
+ from .n2v_manipulate import N2VManipulate
10
+ from .normalize import Normalize
11
+ from .transform import Transform
12
+ from .xy_flip import XYFlip
13
+ from .xy_random_rotate90 import XYRandomRotate90
14
+
15
ALL_TRANSFORMS = {
    "Normalize": Normalize,
    "N2VManipulate": N2VManipulate,
    "XYFlip": XYFlip,
    "XYRandomRotate90": XYRandomRotate90,
}


def get_all_transforms() -> Dict[str, type]:
    """Return all the transforms accepted by CAREamics.

    A shallow copy of the module-level registry is returned, so that a caller
    adding or removing entries in the result cannot alter the registry that
    every subsequent lookup (e.g. in ``Compose``) relies on.

    Returns
    -------
    dict
        A dictionary with all the transforms accepted by CAREamics, where the keys are
        the transform names and the values are the transform classes.
    """
    # copy: handing out the global dict itself would let callers mutate it
    return dict(ALL_TRANSFORMS)
33
+
34
+
35
class Compose:
    """A class chaining transforms together.

    Parameters
    ----------
    transform_list : List[TRANSFORMS_UNION]
        A list of transform configuration models. The ``name`` attribute of each
        model selects the transform class, and the dictionary returned by its
        ``model_dump()`` is passed as keyword arguments to that class.

    Attributes
    ----------
    _callable_transforms : Callable
        A callable that applies the transforms, in list order, to the input data.
    """

    def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
        """Instantiate a Compose object.

        Parameters
        ----------
        transform_list : List[TRANSFORMS_UNION]
            A list of transform configuration models. The ``name`` attribute of
            each model selects the transform class, and the dictionary returned
            by its ``model_dump()`` is passed as keyword arguments to that class.

        Raises
        ------
        KeyError
            If a model's ``name`` is not a key of the mapping returned by
            ``get_all_transforms``.
        """
        # retrieve all available transforms
        all_transforms = get_all_transforms()

        # instantiate all transforms, preserving the order of `transform_list`
        transforms = [all_transforms[t.name](**t.model_dump()) for t in transform_list]

        self._callable_transforms = self._chain_transforms(transforms)

    def _chain_transforms(self, transforms: List[Transform]) -> Callable:
        """Chain the transforms together.

        Each transform is called with the unpacked output of the previous one,
        so every transform in the chain must return a tuple.

        Parameters
        ----------
        transforms : List[Transform]
            A list of transforms to chain together.

        Returns
        -------
        Callable
            A callable that applies the transforms in order to the input data.
        """

        def _chain(
            patch: np.ndarray, target: Optional[np.ndarray]
        ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
            """Chain transforms on the input data.

            Parameters
            ----------
            patch : np.ndarray
                Input data.
            target : Optional[np.ndarray]
                Target data, or None if there is no target.

            Returns
            -------
            Tuple[np.ndarray, Optional[np.ndarray]]
                The output of the last transform, or ``(patch, target)`` unchanged
                if there are no transforms.
            """
            params = (patch, target)

            for t in transforms:
                # the previous output tuple becomes the positional arguments
                params = t(*params)

            return params

        return _chain

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, ...]:
        """Apply the transforms to the input data.

        Parameters
        ----------
        patch : np.ndarray
            The input data.
        target : Optional[np.ndarray], optional
            Target data, by default None.

        Returns
        -------
        Tuple[np.ndarray, ...]
            The output of the last transform in the chain, or ``(patch, target)``
            if the chain is empty. Elements may be None (e.g. a missing target).
        """
        return self._callable_transforms(patch, target)
@@ -1,26 +1,53 @@
1
+ """N2V manipulation transform."""
2
+
1
3
  from typing import Any, Literal, Optional, Tuple
2
4
 
3
5
  import numpy as np
4
- from albumentations import ImageOnlyTransform
5
6
 
6
7
  from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
8
+ from careamics.transforms.transform import Transform
7
9
 
8
10
  from .pixel_manipulation import median_manipulate, uniform_manipulate
9
11
  from .struct_mask_parameters import StructMaskParameters
10
12
 
11
13
 
12
- class N2VManipulate(ImageOnlyTransform):
14
+ class N2VManipulate(Transform):
13
15
  """
14
16
  Default augmentation for the N2V model.
15
17
 
16
- This transform expects (Z)YXC dimensions.
18
+ This transform expects C(Z)YX dimensions.
17
19
 
18
20
  Parameters
19
21
  ----------
20
- mask_pixel_percentage : float
21
- Approximate percentage of pixels to be masked.
22
+ roi_size : int, optional
23
+ Size of the replacement area, by default 11.
24
+ masked_pixel_percentage : float, optional
25
+ Percentage of pixels to mask, by default 0.2.
26
+ strategy : Literal[ "uniform", "median" ], optional
27
+ Replaccement strategy, uniform or median, by default uniform.
28
+ remove_center : bool, optional
29
+ Whether to remove central pixel from patch, by default True.
30
+ struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
31
+ StructN2V mask axis, by default "none".
32
+ struct_mask_span : int, optional
33
+ StructN2V mask span, by default 5.
34
+ seed : Optional[int], optional
35
+ Random seed, by default None.
36
+
37
+ Attributes
38
+ ----------
39
+ masked_pixel_percentage : float
40
+ Percentage of pixels to mask.
22
41
  roi_size : int
23
- Size of the ROI the new pixel value is sampled from, by default 11.
42
+ Size of the replacement area.
43
+ strategy : Literal[ "uniform", "median" ]
44
+ Replaccement strategy, uniform or median.
45
+ remove_center : bool
46
+ Whether to remove central pixel from patch.
47
+ struct_mask : Optional[StructMaskParameters]
48
+ StructN2V mask parameters.
49
+ rng : Generator
50
+ Random number generator.
24
51
  """
25
52
 
26
53
  def __init__(
@@ -33,29 +60,31 @@ class N2VManipulate(ImageOnlyTransform):
33
60
  remove_center: bool = True,
34
61
  struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
35
62
  struct_mask_span: int = 5,
63
+ seed: Optional[int] = None, # TODO use in pixel manipulation
36
64
  ):
37
65
  """Constructor.
38
66
 
39
67
  Parameters
40
68
  ----------
41
69
  roi_size : int, optional
42
- Size of the replacement area, by default 11
70
+ Size of the replacement area, by default 11.
43
71
  masked_pixel_percentage : float, optional
44
- Percentage of pixels to mask, by default 0.2
72
+ Percentage of pixels to mask, by default 0.2.
45
73
  strategy : Literal[ "uniform", "median" ], optional
46
- Replaccement strategy, uniform or median, by default uniform
74
+ Replaccement strategy, uniform or median, by default uniform.
47
75
  remove_center : bool, optional
48
- Whether to remove central pixel from patch, by default True
76
+ Whether to remove central pixel from patch, by default True.
49
77
  struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
50
- StructN2V mask axis, by default "none"
78
+ StructN2V mask axis, by default "none".
51
79
  struct_mask_span : int, optional
52
- StructN2V mask span, by default 5
80
+ StructN2V mask span, by default 5.
81
+ seed : Optional[int], optional
82
+ Random seed, by default None.
53
83
  """
54
- super().__init__(p=1)
55
84
  self.masked_pixel_percentage = masked_pixel_percentage
56
85
  self.roi_size = roi_size
57
86
  self.strategy = strategy
58
- self.remove_center = remove_center
87
+ self.remove_center = remove_center # TODO is this ever used?
59
88
 
60
89
  if struct_mask_axis == SupportedStructAxis.NONE:
61
90
  self.struct_mask: Optional[StructMaskParameters] = None
@@ -65,23 +94,35 @@ class N2VManipulate(ImageOnlyTransform):
65
94
  span=struct_mask_span,
66
95
  )
67
96
 
68
- def apply(
69
- self, patch: np.ndarray, **kwargs: Any
97
+ # numpy random generator
98
+ self.rng = np.random.default_rng(seed=seed)
99
+
100
+ def __call__(
101
+ self, patch: np.ndarray, *args: Any, **kwargs: Any
70
102
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
71
103
  """Apply the transform to the image.
72
104
 
73
105
  Parameters
74
106
  ----------
75
- image : np.ndarray
76
- Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
107
+ patch : np.ndarray
108
+ Image patch, 2D or 3D, shape C(Z)YX.
109
+ *args : Any
110
+ Additional arguments, unused.
111
+ **kwargs : Any
112
+ Additional keyword arguments, unused.
113
+
114
+ Returns
115
+ -------
116
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
117
+ Masked patch, original patch, and mask.
77
118
  """
78
119
  masked = np.zeros_like(patch)
79
120
  mask = np.zeros_like(patch)
80
121
  if self.strategy == SupportedPixelManipulation.UNIFORM:
81
122
  # Iterate over the channels to apply manipulation separately
82
- for c in range(patch.shape[-1]):
83
- masked[..., c], mask[..., c] = uniform_manipulate(
84
- patch=patch[..., c],
123
+ for c in range(patch.shape[0]):
124
+ masked[c, ...], mask[c, ...] = uniform_manipulate(
125
+ patch=patch[c, ...],
85
126
  mask_pixel_percentage=self.masked_pixel_percentage,
86
127
  subpatch_size=self.roi_size,
87
128
  remove_center=self.remove_center,
@@ -89,9 +130,9 @@ class N2VManipulate(ImageOnlyTransform):
89
130
  )
90
131
  elif self.strategy == SupportedPixelManipulation.MEDIAN:
91
132
  # Iterate over the channels to apply manipulation separately
92
- for c in range(patch.shape[-1]):
93
- masked[..., c], mask[..., c] = median_manipulate(
94
- patch=patch[..., c],
133
+ for c in range(patch.shape[0]):
134
+ masked[c, ...], mask[c, ...] = median_manipulate(
135
+ patch=patch[c, ...],
95
136
  mask_pixel_percentage=self.masked_pixel_percentage,
96
137
  subpatch_size=self.roi_size,
97
138
  struct_params=self.struct_mask,
@@ -101,13 +142,3 @@ class N2VManipulate(ImageOnlyTransform):
101
142
 
102
143
  # TODO why return patch?
103
144
  return masked, patch, mask
104
-
105
- def get_transform_init_args_names(self) -> Tuple[str, ...]:
106
- """Get the transform parameters.
107
-
108
- Returns
109
- -------
110
- Tuple[str, ...]
111
- Transform parameters.
112
- """
113
- return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")
@@ -1,27 +1,35 @@
1
- from typing import Any
1
+ """Normalization and denormalization transforms for image patches."""
2
+
3
+ from typing import Optional, Tuple
2
4
 
3
5
  import numpy as np
4
- from albumentations import DualTransform
6
+
7
+ from careamics.transforms.transform import Transform
5
8
 
6
9
 
7
- class Normalize(DualTransform):
10
+ class Normalize(Transform):
8
11
  """
9
12
  Normalize an image or image patch.
10
13
 
11
- Normalization is a zero mean and unit variance. This transform expects (Z)YXC
14
+ Normalization is a zero mean and unit variance. This transform expects C(Z)YX
12
15
  dimensions.
13
16
 
14
17
  Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
15
18
  division by zero and that it returns a float32 image.
16
19
 
20
+ Parameters
21
+ ----------
22
+ mean : float
23
+ Mean value.
24
+ std : float
25
+ Standard deviation value.
26
+
17
27
  Attributes
18
28
  ----------
19
29
  mean : float
20
30
  Mean value.
21
31
  std : float
22
32
  Standard deviation value.
23
- eps : float
24
- Epsilon value to avoid division by zero.
25
33
  """
26
34
 
27
35
  def __init__(
@@ -29,61 +37,82 @@ class Normalize(DualTransform):
29
37
  mean: float,
30
38
  std: float,
31
39
  ):
32
- super().__init__(always_apply=True, p=1)
40
+ """Constructor.
33
41
 
42
+ Parameters
43
+ ----------
44
+ mean : float
45
+ Mean value.
46
+ std : float
47
+ Standard deviation value.
48
+ """
34
49
  self.mean = mean
35
50
  self.std = std
36
51
  self.eps = 1e-6
37
52
 
38
- def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
39
- """
40
- Apply the transform to the image.
53
+ def __call__(
54
+ self, patch: np.ndarray, target: Optional[np.ndarray] = None
55
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
56
+ """Apply the transform to the source patch and the target (optional).
41
57
 
42
58
  Parameters
43
59
  ----------
44
60
  patch : np.ndarray
45
- Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
61
+ Patch, 2D or 3D, shape C(Z)YX.
62
+ target : Optional[np.ndarray], optional
63
+ Target for the patch, by default None.
46
64
 
47
65
  Returns
48
66
  -------
49
- np.ndarray
50
- Normalized image or image patch.
67
+ Tuple[np.ndarray, Optional[np.ndarray]]
68
+ Transformed patch and target.
51
69
  """
52
- return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
70
+ norm_patch = self._apply(patch)
71
+ norm_target = self._apply(target) if target is not None else None
53
72
 
54
- def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray:
55
- """
56
- Apply the transform to the mask.
73
+ return norm_patch, norm_target
57
74
 
58
- The mask is returned as is.
75
+ def _apply(self, patch: np.ndarray) -> np.ndarray:
76
+ """
77
+ Apply the transform to the image.
59
78
 
60
79
  Parameters
61
80
  ----------
62
- mask : np.ndarray
63
- Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
81
+ patch : np.ndarray
82
+ Image patch, 2D or 3D, shape C(Z)YX.
83
+
84
+ Returns
85
+ -------
86
+ np.ndarray
87
+ Normalizedimage patch.
64
88
  """
65
- return mask
89
+ return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
66
90
 
67
91
 
68
- class Denormalize(DualTransform):
92
+ class Denormalize:
69
93
  """
70
94
  Denormalize an image or image patch.
71
95
 
72
96
  Denormalization is performed expecting a zero mean and unit variance input. This
73
- transform expects (Z)YXC dimensions.
97
+ transform expects C(Z)YX dimensions.
74
98
 
75
99
  Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
76
100
  division by zero during the normalization step, which is taken into account during
77
101
  denormalization.
78
102
 
103
+ Parameters
104
+ ----------
105
+ mean : float
106
+ Mean value.
107
+ std : float
108
+ Standard deviation value.
109
+
79
110
  Attributes
80
111
  ----------
81
112
  mean : float
82
113
  Mean value.
83
114
  std : float
84
115
  Standard deviation value.
85
- eps : float
86
- Epsilon value to avoid division by zero.
87
116
  """
88
117
 
89
118
  def __init__(
@@ -91,19 +120,53 @@ class Denormalize(DualTransform):
91
120
  mean: float,
92
121
  std: float,
93
122
  ):
94
- super().__init__(always_apply=True, p=1)
123
+ """Constructor.
95
124
 
125
+ Parameters
126
+ ----------
127
+ mean : float
128
+ Mean.
129
+ std : float
130
+ Standard deviation.
131
+ """
96
132
  self.mean = mean
97
133
  self.std = std
98
134
  self.eps = 1e-6
99
135
 
100
- def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
136
+ def __call__(
137
+ self, patch: np.ndarray, target: Optional[np.ndarray] = None
138
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
139
+ """Apply the transform to the source patch and the target (optional).
140
+
141
+ Parameters
142
+ ----------
143
+ patch : np.ndarray
144
+ Patch, 2D or 3D, shape C(Z)YX.
145
+ target : Optional[np.ndarray], optional
146
+ Target for the patch, by default None.
147
+
148
+ Returns
149
+ -------
150
+ Tuple[np.ndarray, Optional[np.ndarray]]
151
+ Transformed patch and target.
152
+ """
153
+ norm_patch = self._apply(patch)
154
+ norm_target = self._apply(target) if target is not None else None
155
+
156
+ return norm_patch, norm_target
157
+
158
+ def _apply(self, patch: np.ndarray) -> np.ndarray:
101
159
  """
102
160
  Apply the transform to the image.
103
161
 
104
162
  Parameters
105
163
  ----------
106
164
  patch : np.ndarray
107
- Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
165
+ Image patch, 2D or 3D, shape C(Z)YX.
166
+
167
+ Returns
168
+ -------
169
+ np.ndarray
170
+ Denormalized image patch.
108
171
  """
109
172
  return patch * (self.std + self.eps) + self.mean
@@ -4,7 +4,8 @@ Pixel manipulation methods.
4
4
  Pixel manipulation is used in N2V and similar algorithm to replace the value of
5
5
  masked pixels.
6
6
  """
7
- from typing import Optional, Tuple, Union
7
+
8
+ from typing import Optional, Tuple
8
9
 
9
10
  import numpy as np
10
11
 
@@ -14,7 +15,7 @@ from .struct_mask_parameters import StructMaskParameters
14
15
  def _apply_struct_mask(
15
16
  patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
16
17
  ) -> np.ndarray:
17
- """Applies structN2V masks to patch.
18
+ """Apply structN2V masks to patch.
18
19
 
19
20
  Each point in `coords` corresponds to the center of a mask, masks are paremeterized
20
21
  by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
@@ -97,7 +98,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
97
98
 
98
99
 
99
100
  def _get_stratified_coords(
100
- mask_pixel_perc: float, shape: Union[Tuple[int, int], Tuple[int, int, int]]
101
+ mask_pixel_perc: float, shape: Tuple[int, ...]
101
102
  ) -> np.ndarray:
102
103
  """
103
104
  Generate coordinates of the pixels to mask.
@@ -246,9 +247,8 @@ def uniform_manipulate(
246
247
  subpatch_size : int
247
248
  Size of the subpatch the new pixel value is sampled from, by default 11.
248
249
  remove_center : bool
249
- Whether to remove the center pixel from the subpatch, by default False. See
250
- uniform with/without central pixel in the documentation. #TODO add link
251
- struct_params: Optional[StructMaskParameters]
250
+ Whether to remove the center pixel from the subpatch, by default False.
251
+ struct_params : Optional[StructMaskParameters]
252
252
  Parameters for the structN2V mask (axis and span).
253
253
 
254
254
  Returns
@@ -322,7 +322,7 @@ def median_manipulate(
322
322
  Approximate percentage of pixels to be masked.
323
323
  subpatch_size : int
324
324
  Size of the subpatch the new pixel value is sampled from, by default 11.
325
- struct_params: Optional[StructMaskParameters]
325
+ struct_params : Optional[StructMaskParameters]
326
326
  Parameters for the structN2V mask (axis and span).
327
327
 
328
328
  Returns
@@ -1,3 +1,5 @@
1
+ """Class representing the parameters of structN2V masks."""
2
+
1
3
  from dataclasses import dataclass
2
4
  from typing import Literal
3
5
 
@@ -6,7 +8,7 @@ from typing import Literal
6
8
  class StructMaskParameters:
7
9
  """Parameters of structN2V masks.
8
10
 
9
- Parameters
11
+ Attributes
10
12
  ----------
11
13
  axis : Literal[0, 1]
12
14
  Axis along which to apply the mask, horizontal (0) or vertical (1).
@@ -0,0 +1,24 @@
1
+ """A general parent class for transforms."""
2
+
3
+ from typing import Any
4
+
5
+
6
class Transform:
    """A general parent class for transforms.

    Subclasses must override ``__call__``. When transforms are chained by
    ``Compose``, the output of one transform is unpacked as the positional
    arguments of the next, so implementations are expected to return a tuple.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the transform.

        Parameters
        ----------
        *args : Any
            Arguments.
        **kwargs : Any
            Keyword arguments.

        Returns
        -------
        Any
            Transformed data.

        Raises
        ------
        NotImplementedError
            Always, as the base class provides no transformation; subclasses
            must implement their own ``__call__``.
        """
        # Fail loudly: silently returning None would only surface later as an
        # unrelated unpacking error when the result is chained into the next step.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__."
        )
@@ -1,7 +1,7 @@
1
1
  """Test-time augmentations."""
2
+
2
3
  from typing import List
3
4
 
4
- import numpy as np
5
5
  from torch import Tensor, flip, mean, rot90, stack
6
6
 
7
7
 
@@ -48,7 +48,7 @@ class ImageRestorationTTA:
48
48
  augmented_flip.append(flip(x_, dims=(-3, -1)))
49
49
  return augmented_flip
50
50
 
51
- def backward(self, x: List[Tensor]) -> np.ndarray:
51
+ def backward(self, x: List[Tensor]) -> Tensor:
52
52
  """Undo the test-time augmentation.
53
53
 
54
54
  Parameters