careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118) hide show
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,185 @@
1
+ """Module containing functions to create `CAREamicsPredictData`."""
2
+
3
+ from pathlib import Path
4
+ from typing import Callable, Dict, Literal, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config import Configuration, create_inference_configuration
10
+ from careamics.utils import check_path_exists
11
+
12
+ from ..lightning_prediction_datamodule import CAREamicsPredictData
13
+
14
+
15
def create_pred_datamodule(
    source: Union[CAREamicsPredictData, Path, str, NDArray],
    config: Configuration,
    batch_size: Optional[int] = None,
    tile_size: Optional[Tuple[int, ...]] = None,
    tile_overlap: Tuple[int, ...] = (48, 48),
    axes: Optional[str] = None,
    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
    tta_transforms: bool = True,
    dataloader_params: Optional[Dict] = None,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
) -> CAREamicsPredictData:
    """
    Create a `CAREamicsPredictData` module.

    Parameters
    ----------
    source : CAREamicsPredictData, pathlib.Path, str or numpy.ndarray
        Data to predict on. A `CAREamicsPredictData` instance is returned
        unchanged, in which case the remaining arguments are not applied to it.
    config : Configuration
        Global configuration.
    batch_size : int, optional
        Batch size for prediction. If None, `config.data_config.batch_size` is
        used.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles.
    axes : str, optional
        Axes of the input data, by default None.
    data_type : {"array", "tiff", "custom"}, optional
        Type of the input data.
    tta_transforms : bool, default=True
        Whether to apply test-time augmentation.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader. A "batch_size" entry is ignored
        because the batch size of the prediction configuration takes priority.
        The caller's dictionary is not modified.
    read_source_func : Callable, optional
        Function to read the source data (only used for path sources).
    extension_filter : str, default=""
        Filter for the file extension (only used for path sources).

    Returns
    -------
    prediction datamodule: CAREamicsPredictData
        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.

    Raises
    ------
    ValueError
        If `source` is not a `CAREamicsPredictData` instance, a path or a numpy
        array.
    """
    # Reuse the training batch size if none is provided explicitly
    if batch_size is None:
        batch_size = config.data_config.batch_size

    # Create the prediction config, reusing the training config for missing
    # parameters
    prediction_config = create_inference_configuration(
        configuration=config,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        data_type=data_type,
        axes=axes,
        tta_transforms=tta_transforms,
        batch_size=batch_size,
    )

    # Copy before dropping "batch_size" so the caller's dictionary is left
    # untouched (priority is given to the batch size in the config).
    dataloader_params = {} if dataloader_params is None else dict(dataloader_params)
    dataloader_params.pop("batch_size", None)

    if isinstance(source, CAREamicsPredictData):
        return source
    if isinstance(source, (Path, str)):
        return _create_from_path(
            source=source,
            pred_config=prediction_config,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
            dataloader_params=dataloader_params,
        )
    if isinstance(source, np.ndarray):
        return _create_from_array(
            source=source,
            pred_config=prediction_config,
            dataloader_params=dataloader_params,
        )
    raise ValueError(
        "Invalid input. Expected a CAREamicsPredictData instance, paths or "
        f"NDArray (got {type(source)})."
    )
110
+
111
+
112
def _create_from_path(
    source: Union[Path, str],
    pred_config: Configuration,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
    dataloader_params: Optional[Dict] = None,
    **kwargs,
) -> CAREamicsPredictData:
    """
    Create `CAREamicsPredictData` from path.

    Parameters
    ----------
    source : pathlib.Path or str
        Path to the data to predict on, validated with `check_path_exists`
        before the data module is created.
    pred_config : Configuration
        Prediction configuration.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, default=""
        Filter for the file extension of the data to read.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader.
    **kwargs
        Unused. Added for a signature compatible with `_create_from_array`.

    Returns
    -------
    prediction datamodule: CAREamicsPredictData
        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
    """
    source_path = check_path_exists(source)

    return CAREamicsPredictData(
        pred_config=pred_config,
        pred_data=source_path,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        dataloader_params=dataloader_params,
    )
153
+
154
+
155
def _create_from_array(
    source: NDArray,
    pred_config: Configuration,
    dataloader_params: Optional[Dict] = None,
    **kwargs,
) -> CAREamicsPredictData:
    """
    Create `CAREamicsPredictData` from array.

    Parameters
    ----------
    source : numpy.ndarray
        Data to predict on.
    pred_config : Configuration
        Prediction configuration.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader.
    **kwargs
        Unused. Added for compatible function signature with `_create_from_path`.

    Returns
    -------
    prediction datamodule: CAREamicsPredictData
        Subclass of `pytorch_lightning.LightningDataModule` for creating predictions.
    """
    return CAREamicsPredictData(
        pred_config=pred_config,
        pred_data=source,
        dataloader_params=dataloader_params,
    )
@@ -0,0 +1,165 @@
1
+ """Module containing functions to convert prediction outputs to desired form."""
2
+
3
+ from typing import Any, List, Literal, Tuple, Union, overload
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from ..config.tile_information import TileInformation
9
+ from .stitch_prediction import stitch_prediction
10
+
11
+
12
def convert_outputs(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], NDArray]:
    """
    Convert the outputs of `Trainer.predict` to numpy arrays.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`. When `tiled` is True,
        each element is a pair of (batch of tiles, list of tile information).
    tiled : bool
        Whether the predictions are tiled. Tiled predictions are stitched back
        into full images.

    Returns
    -------
    list of numpy.ndarray or numpy.ndarray
        One array per predicted image. Untiled outputs keep a leading sample axis
        of size 1, whereas stitched (tiled) outputs do not. A single output is
        returned directly rather than in a list, and an empty `predictions` is
        returned as is.
    """
    # Nothing predicted: hand back the (empty) input unchanged.
    if len(predictions) == 0:
        return predictions

    if tiled:
        tile_batches, tile_infos = combine_batches(predictions, True)
        # Drop the (always singleton) sample axis; stitching expects no S axis.
        outputs = stitch_prediction([tile[0] for tile in tile_batches], tile_infos)
    else:
        outputs = combine_batches(predictions, False)

    # TODO: add this in? Returns output with same axes as input
    # Won't work with tiling rn because stitch_prediction func removes S axis
    # predictions = reshape(predictions, axes)
    # At least make sure stitched prediction and non-tiled prediction have matching axes

    # TODO: might want to remove this (unwrapping of a single output)
    return outputs[0] if len(outputs) == 1 else outputs
53
+
54
+
55
# Overload for type checkers: tiled=True yields (tiles, tile infos).
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[True]
) -> Tuple[List[NDArray], List[TileInformation]]: ...


# Overload for type checkers: tiled=False yields a list of arrays.
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[False]
) -> List[NDArray]: ...


# Overload for type checkers: a plain bool yields either form.
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...


def combine_batches(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
    """
    Combine batched predictions into per-sample outputs.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches; the tuple form is returned for tiled predictions.
    """
    combiner = _combine_tiled_batches if tiled else _combine_untiled_batches
    return combiner(predictions)
98
+
99
+
100
def _combine_tiled_batches(
    predictions: List[Tuple[NDArray, List[TileInformation]]]
) -> Tuple[List[NDArray], List[TileInformation]]:
    """
    Combine batches from tiled output.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`, each a pair of
        (batch of tiles, list of tile information).

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches: one array per tile and the flattened tile information.
    """
    batch_arrays: List[NDArray] = []
    tile_infos: List[TileInformation] = []
    for batch_array, batch_infos in predictions:
        batch_arrays.append(batch_array)
        # flatten the per-batch lists of tile information into a single list
        tile_infos.extend(batch_infos)

    prediction_tiles: List[NDArray] = _combine_untiled_batches(batch_arrays)
    return prediction_tiles, tile_infos
124
+
125
+
126
def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
    """
    Combine batches from un-tiled output.

    The batches are concatenated along the first (batch) axis and then split into
    one array per sample; each returned array keeps a leading axis of size 1.

    Parameters
    ----------
    predictions : list of numpy.ndarray
        Predictions that are output from `Trainer.predict`, one array per batch.

    Returns
    -------
    list of numpy.ndarray
        Combined batches, one array per sample. Empty if there is nothing to
        combine.
    """
    # np.concatenate raises on an empty sequence: nothing to combine.
    if len(predictions) == 0:
        return []

    prediction_concat: NDArray = np.concatenate(predictions, axis=0)

    # np.split requires a positive number of sections, so guard against batches
    # that are all empty along the batch axis.
    if prediction_concat.shape[0] == 0:
        return []

    return np.split(prediction_concat, prediction_concat.shape[0], axis=0)
143
+
144
+
145
def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
    """
    Reshape predictions to have dimensions of input.

    Predictions are expected in SC(Z)YX order. The channel axis (position 1) is
    dropped when `axes` contains no "C". Then the sample axis (position 0) is
    dropped when `axes` contains no "S". In both cases, only the first entry
    along the dropped axis is kept.

    Parameters
    ----------
    predictions : list of numpy.ndarray
        Predictions that are output from `Trainer.predict`.
    axes : str
        Axes of the input data, e.g. "SCYX" or "YX".

    Returns
    -------
    list of numpy.ndarray
        Reshaped predictions. When `axes` contains both "S" and "C", the input
        list itself is returned.
    """
    drop_channel = "C" not in axes
    drop_sample = "S" not in axes
    if not (drop_channel or drop_sample):
        return predictions

    def _match_axes(array: NDArray) -> NDArray:
        # index the channel axis first, while the sample axis is still present
        if drop_channel:
            array = array[:, 0]
        if drop_sample:
            array = array[0]
        return array

    return [_match_axes(array) for array in predictions]
@@ -0,0 +1,100 @@
1
+ """Prediction utility functions."""
2
+
3
+ from typing import List
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.tile_information import TileInformation
8
+
9
+
10
# TODO: why not allow input and output of torch.tensor ?
def stitch_prediction(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> List[np.ndarray]:
    """
    Stitch tiles back together to form a full image(s).

    Each tile is expected to be C(Z)YX, where C is the number of channels and can
    be a singleton dimension. Tiles of consecutive images follow each other, and
    the last tile of each image is flagged by `last_tile` in its tile information.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain tiles
        from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s), one per flagged last tile. Tiles after the final flagged
        tile are not stitched.
    """
    image_predictions: List[np.ndarray] = []
    image_start = 0
    for position, tile_info in enumerate(tile_infos):
        if not tile_info.last_tile:
            continue
        # This tile closes the current image: stitch its run of tiles.
        image_end = position + 1
        image_predictions.append(
            stitch_prediction_single(
                tiles[image_start:image_end], tile_infos[image_start:image_end]
            )
        )
        image_start = image_end
    return image_predictions
52
+
53
+
54
def stitch_prediction_single(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> np.ndarray:
    """
    Stitch tiles back together to form a full image.

    Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
    singleton dimension.

    The spatial size of the output comes from the tile information. When the
    recorded array shape has a channel axis, its size is taken from the predicted
    tiles, because it may differ from the input's number of channels.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image, as float32.
    """
    # retrieve whole array size
    first_info = tile_infos[0]
    array_shape = tuple(first_info.array_shape)
    n_spatial = len(first_info.stitch_coords)
    n_leading = len(array_shape) - n_spatial

    if n_leading > 0:
        # Assumes the axis right before the spatial axes of `array_shape` is the
        # channel axis (C(Z)YX). Size it from the predicted tiles so that models
        # with a different number of output channels stitch correctly, instead of
        # failing to broadcast or silently replicating a single channel.
        output_shape = (
            *array_shape[: n_leading - 1],
            tiles[0].shape[0],
            *array_shape[n_leading:],
        )
    else:
        # No channel axis recorded: use the recorded shape as is.
        output_shape = array_shape
    predicted_image = np.zeros(output_shape, dtype=np.float32)

    for tile, tile_info in zip(tiles, tile_infos):
        n_channels = tile.shape[0]

        # Compute coordinates for cropping predicted tile (channels kept whole)
        slices = (slice(0, n_channels),) + tuple(
            slice(c[0], c[1]) for c in tile_info.overlap_crop_coords
        )

        # Crop predicted tile according to overlap coordinates
        cropped_tile = tile[slices]

        # Insert cropped tile into predicted image using stitch coordinates
        predicted_image[
            (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
        ] = cropped_tile.astype(np.float32)

    return predicted_image
@@ -3,7 +3,7 @@
3
3
  __all__ = [
4
4
  "get_all_transforms",
5
5
  "N2VManipulate",
6
- "NDFlip",
6
+ "XYFlip",
7
7
  "XYRandomRotate90",
8
8
  "ImageRestorationTTA",
9
9
  "Denormalize",
@@ -14,7 +14,7 @@ __all__ = [
14
14
 
15
15
  from .compose import Compose, get_all_transforms
16
16
  from .n2v_manipulate import N2VManipulate
17
- from .nd_flip import NDFlip
18
17
  from .normalize import Denormalize, Normalize
19
18
  from .tta import ImageRestorationTTA
19
+ from .xy_flip import XYFlip
20
20
  from .xy_random_rotate90 import XYRandomRotate90
@@ -1,26 +1,26 @@
1
1
  """A class chaining transforms together."""
2
2
 
3
- from typing import Callable, List, Optional, Tuple
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
4
 
5
5
  import numpy as np
6
6
 
7
7
  from careamics.config.data_model import TRANSFORMS_UNION
8
8
 
9
9
  from .n2v_manipulate import N2VManipulate
10
- from .nd_flip import NDFlip
11
10
  from .normalize import Normalize
12
11
  from .transform import Transform
12
+ from .xy_flip import XYFlip
13
13
  from .xy_random_rotate90 import XYRandomRotate90
14
14
 
15
15
  ALL_TRANSFORMS = {
16
16
  "Normalize": Normalize,
17
17
  "N2VManipulate": N2VManipulate,
18
- "NDFlip": NDFlip,
18
+ "XYFlip": XYFlip,
19
19
  "XYRandomRotate90": XYRandomRotate90,
20
20
  }
21
21
 
22
22
 
23
- def get_all_transforms() -> dict:
23
+ def get_all_transforms() -> Dict[str, type]:
24
24
  """Return all the transforms accepted by CAREamics.
25
25
 
26
26
  Returns
@@ -33,7 +33,19 @@ def get_all_transforms() -> dict:
33
33
 
34
34
 
35
35
  class Compose:
36
- """A class chaining transforms together."""
36
+ """A class chaining transforms together.
37
+
38
+ Parameters
39
+ ----------
40
+ transform_list : List[TRANSFORMS_UNION]
41
+ A list of dictionaries where each dictionary contains the name of a
42
+ transform and its parameters.
43
+
44
+ Attributes
45
+ ----------
46
+ _callable_transforms : Callable
47
+ A callable that applies the transforms to the input data.
48
+ """
37
49
 
38
50
  def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
39
51
  """Instantiate a Compose object.
@@ -68,7 +80,21 @@ class Compose:
68
80
 
69
81
  def _chain(
70
82
  patch: np.ndarray, target: Optional[np.ndarray]
71
- ) -> Tuple[np.ndarray, ...]:
83
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
84
+ """Chain transforms on the input data.
85
+
86
+ Parameters
87
+ ----------
88
+ patch : np.ndarray
89
+ Input data.
90
+ target : Optional[np.ndarray]
91
+ Target data, by default None.
92
+
93
+ Returns
94
+ -------
95
+ Tuple[np.ndarray, Optional[np.ndarray]]
96
+ The output of the transformations.
97
+ """
72
98
  params = (patch, target)
73
99
 
74
100
  for t in transforms:
@@ -88,7 +114,7 @@ class Compose:
88
114
  patch : np.ndarray
89
115
  The input data.
90
116
  target : Optional[np.ndarray], optional
91
- Target data, by default None
117
+ Target data, by default None.
92
118
 
93
119
  Returns
94
120
  -------
@@ -1,3 +1,5 @@
1
+ """N2V manipulation transform."""
2
+
1
3
  from typing import Any, Literal, Optional, Tuple
2
4
 
3
5
  import numpy as np
@@ -17,10 +19,35 @@ class N2VManipulate(Transform):
17
19
 
18
20
  Parameters
19
21
  ----------
20
- mask_pixel_percentage : float
21
- Approximate percentage of pixels to be masked.
22
+ roi_size : int, optional
23
+ Size of the replacement area, by default 11.
24
+ masked_pixel_percentage : float, optional
25
+ Percentage of pixels to mask, by default 0.2.
26
+ strategy : Literal[ "uniform", "median" ], optional
27
+ Replaccement strategy, uniform or median, by default uniform.
28
+ remove_center : bool, optional
29
+ Whether to remove central pixel from patch, by default True.
30
+ struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
31
+ StructN2V mask axis, by default "none".
32
+ struct_mask_span : int, optional
33
+ StructN2V mask span, by default 5.
34
+ seed : Optional[int], optional
35
+ Random seed, by default None.
36
+
37
+ Attributes
38
+ ----------
39
+ masked_pixel_percentage : float
40
+ Percentage of pixels to mask.
22
41
  roi_size : int
23
- Size of the ROI the new pixel value is sampled from, by default 11.
42
+ Size of the replacement area.
43
+ strategy : Literal[ "uniform", "median" ]
44
+ Replaccement strategy, uniform or median.
45
+ remove_center : bool
46
+ Whether to remove central pixel from patch.
47
+ struct_mask : Optional[StructMaskParameters]
48
+ StructN2V mask parameters.
49
+ rng : Generator
50
+ Random number generator.
24
51
  """
25
52
 
26
53
  def __init__(
@@ -33,31 +60,31 @@ class N2VManipulate(Transform):
33
60
  remove_center: bool = True,
34
61
  struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
35
62
  struct_mask_span: int = 5,
36
- seed: Optional[int] = None, # TODO use in pixel manipulation
63
+ seed: Optional[int] = None,
37
64
  ):
38
65
  """Constructor.
39
66
 
40
67
  Parameters
41
68
  ----------
42
69
  roi_size : int, optional
43
- Size of the replacement area, by default 11
70
+ Size of the replacement area, by default 11.
44
71
  masked_pixel_percentage : float, optional
45
- Percentage of pixels to mask, by default 0.2
72
+ Percentage of pixels to mask, by default 0.2.
46
73
  strategy : Literal[ "uniform", "median" ], optional
47
- Replaccement strategy, uniform or median, by default uniform
74
+ Replaccement strategy, uniform or median, by default uniform.
48
75
  remove_center : bool, optional
49
- Whether to remove central pixel from patch, by default True
76
+ Whether to remove central pixel from patch, by default True.
50
77
  struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
51
- StructN2V mask axis, by default "none"
78
+ StructN2V mask axis, by default "none".
52
79
  struct_mask_span : int, optional
53
- StructN2V mask span, by default 5
80
+ StructN2V mask span, by default 5.
54
81
  seed : Optional[int], optional
55
- Random seed, by default None
82
+ Random seed, by default None.
56
83
  """
57
84
  self.masked_pixel_percentage = masked_pixel_percentage
58
85
  self.roi_size = roi_size
59
86
  self.strategy = strategy
60
- self.remove_center = remove_center
87
+ self.remove_center = remove_center # TODO is this ever used?
61
88
 
62
89
  if struct_mask_axis == SupportedStructAxis.NONE:
63
90
  self.struct_mask: Optional[StructMaskParameters] = None
@@ -77,8 +104,17 @@ class N2VManipulate(Transform):
77
104
 
78
105
  Parameters
79
106
  ----------
80
- image : np.ndarray
81
- Image or image patch, 2D or 3D, shape C(Z)YX.
107
+ patch : np.ndarray
108
+ Image patch, 2D or 3D, shape C(Z)YX.
109
+ *args : Any
110
+ Additional arguments, unused.
111
+ **kwargs : Any
112
+ Additional keyword arguments, unused.
113
+
114
+ Returns
115
+ -------
116
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
117
+ Masked patch, original patch, and mask.
82
118
  """
83
119
  masked = np.zeros_like(patch)
84
120
  mask = np.zeros_like(patch)
@@ -91,6 +127,7 @@ class N2VManipulate(Transform):
91
127
  subpatch_size=self.roi_size,
92
128
  remove_center=self.remove_center,
93
129
  struct_params=self.struct_mask,
130
+ rng=self.rng,
94
131
  )
95
132
  elif self.strategy == SupportedPixelManipulation.MEDIAN:
96
133
  # Iterate over the channels to apply manipulation separately
@@ -100,6 +137,7 @@ class N2VManipulate(Transform):
100
137
  mask_pixel_percentage=self.masked_pixel_percentage,
101
138
  subpatch_size=self.roi_size,
102
139
  struct_params=self.struct_mask,
140
+ rng=self.rng,
103
141
  )
104
142
  else:
105
143
  raise ValueError(f"Unknown masking strategy ({self.strategy}).")