careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (81)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +80 -44
  4. careamics/config/algorithm_model.py +5 -3
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +8 -1
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -2
  12. careamics/config/configuration_factory.py +4 -16
  13. careamics/config/data_model.py +10 -14
  14. careamics/config/inference_model.py +0 -65
  15. careamics/config/optimizer_models.py +4 -4
  16. careamics/config/support/__init__.py +0 -2
  17. careamics/config/support/supported_activations.py +2 -0
  18. careamics/config/support/supported_algorithms.py +3 -1
  19. careamics/config/support/supported_architectures.py +2 -0
  20. careamics/config/support/supported_data.py +2 -0
  21. careamics/config/support/supported_loggers.py +2 -0
  22. careamics/config/support/supported_losses.py +2 -0
  23. careamics/config/support/supported_optimizers.py +2 -0
  24. careamics/config/support/supported_pixel_manipulations.py +3 -3
  25. careamics/config/support/supported_struct_axis.py +2 -0
  26. careamics/config/support/supported_transforms.py +4 -15
  27. careamics/config/tile_information.py +2 -0
  28. careamics/config/transformations/__init__.py +3 -2
  29. careamics/config/transformations/xy_flip_model.py +43 -0
  30. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  31. careamics/conftest.py +12 -0
  32. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  33. careamics/dataset/dataset_utils/file_utils.py +4 -3
  34. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  35. careamics/dataset/dataset_utils/read_utils.py +2 -0
  36. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  37. careamics/dataset/in_memory_dataset.py +71 -32
  38. careamics/dataset/iterable_dataset.py +155 -68
  39. careamics/dataset/patching/patching.py +56 -15
  40. careamics/dataset/patching/random_patching.py +8 -2
  41. careamics/dataset/patching/sequential_patching.py +14 -8
  42. careamics/dataset/patching/tiled_patching.py +3 -1
  43. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  44. careamics/dataset/zarr_dataset.py +2 -0
  45. careamics/lightning_datamodule.py +45 -19
  46. careamics/lightning_module.py +8 -2
  47. careamics/lightning_prediction_datamodule.py +3 -13
  48. careamics/lightning_prediction_loop.py +8 -6
  49. careamics/losses/__init__.py +2 -3
  50. careamics/losses/loss_factory.py +1 -1
  51. careamics/losses/losses.py +11 -7
  52. careamics/model_io/bmz_io.py +3 -3
  53. careamics/models/activation.py +2 -0
  54. careamics/models/layers.py +121 -25
  55. careamics/models/model_factory.py +1 -1
  56. careamics/models/unet.py +35 -14
  57. careamics/prediction/stitch_prediction.py +2 -6
  58. careamics/transforms/__init__.py +2 -2
  59. careamics/transforms/compose.py +33 -7
  60. careamics/transforms/n2v_manipulate.py +49 -13
  61. careamics/transforms/normalize.py +55 -3
  62. careamics/transforms/pixel_manipulation.py +5 -5
  63. careamics/transforms/struct_mask_parameters.py +3 -1
  64. careamics/transforms/transform.py +10 -19
  65. careamics/transforms/xy_flip.py +123 -0
  66. careamics/transforms/xy_random_rotate90.py +38 -5
  67. careamics/utils/base_enum.py +28 -0
  68. careamics/utils/path_utils.py +2 -0
  69. careamics/utils/ram.py +2 -0
  70. careamics/utils/receptive_field.py +93 -87
  71. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
  72. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  73. careamics/config/noise_models.py +0 -162
  74. careamics/config/support/supported_extraction_strategies.py +0 -25
  75. careamics/config/transformations/nd_flip_model.py +0 -27
  76. careamics/losses/noise_model_factory.py +0 -40
  77. careamics/losses/noise_models.py +0 -524
  78. careamics/transforms/nd_flip.py +0 -67
  79. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  80. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  81. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0

careamics/dataset/patching/sequential_patching.py
@@ -1,3 +1,5 @@
+ """Sequential patching functions."""
+
  from typing import List, Optional, Tuple, Union
 
  import numpy as np
@@ -14,14 +16,14 @@ def _compute_number_of_patches(
 
  Parameters
  ----------
- arr : Tuple[int, ...]
+ arr_shape : Tuple[int, ...]
  Shape of the input array.
- patch_sizes : Tuple[int]
+ patch_sizes : Union[List[int], Tuple[int, ...]]
  Shape of the patches.
 
  Returns
  -------
- Tuple[int]
+ Tuple[int, ...]
  Number of patches in each dimension.
  """
  if len(arr_shape) != len(patch_sizes):
@@ -55,14 +57,14 @@ def _compute_overlap(
 
  Parameters
  ----------
- arr : Tuple[int, ...]
+ arr_shape : Tuple[int, ...]
  Input array shape.
- patch_sizes : Tuple[int]
+ patch_sizes : Union[List[int], Tuple[int, ...]]
  Size of the patches.
 
  Returns
  -------
- Tuple[int]
+ Tuple[int, ...]
  Overlap between patches in each dimension.
  """
  n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -123,6 +125,8 @@ def _compute_patch_views(
  Steps between views.
  output_shape : Tuple[int]
  Shape of the output array.
+ target : Optional[np.ndarray], optional
+ Target array, by default None.
 
  Returns
  -------
@@ -161,11 +165,13 @@ def extract_patches_sequential(
  Input image array.
  patch_size : Tuple[int]
  Patch sizes in each dimension.
+ target : Optional[np.ndarray], optional
+ Target array, by default None.
 
  Returns
  -------
- Generator[Tuple[np.ndarray, ...], None, None]
- Generator of patches.
+ Tuple[np.ndarray, Optional[np.ndarray]]
+ Patches.
  """
  is_3d_patch = len(patch_size) == 3
 
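With the new docstring, extract_patches_sequential returns all patches at once as a (patches, target_patches) tuple instead of a generator, with None as the second element when no target is passed. For intuition, here is a standalone sketch of the ceil-based arithmetic that _compute_number_of_patches and _compute_overlap appear to implement; the helper names and exact rounding are assumptions, not code copied from the package:

    import numpy as np

    def number_of_patches(arr_shape, patch_sizes):
        # ceil(axis / patch) patches are needed to cover each axis
        return tuple(int(np.ceil(s / p)) for s, p in zip(arr_shape, patch_sizes))

    def patch_overlap(arr_shape, patch_sizes):
        # overlap (rounded up) so that n patches of size p span an axis of size s
        n = number_of_patches(arr_shape, patch_sizes)
        return tuple(
            int(np.ceil((ni * p - s) / max(1, ni - 1))) if ni > 1 else 0
            for ni, p, s in zip(n, patch_sizes, arr_shape)
        )

    print(number_of_patches((64, 60), (16, 16)))  # (4, 4)
    print(patch_overlap((64, 60), (16, 16)))      # (0, 2)
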

careamics/dataset/patching/tiled_patching.py
@@ -1,3 +1,5 @@
+ """Tiled patching utilities."""
+
  import itertools
  from typing import Generator, List, Tuple, Union
 
@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
 
  def _compute_crop_and_stitch_coords_1d(
  axis_size: int, tile_size: int, overlap: int
- ) -> Tuple[List[Tuple[int, ...]], ...]:
+ ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
  """
  Compute the coordinates of each tile along an axis, given the overlap.
 
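The tightened return annotation spells out that the helper produces three lists of (start, end) pairs per axis. A self-contained sketch of overlapping 1D tiling in that spirit; it is not the package's _compute_crop_and_stitch_coords_1d, which also returns the stitch and overlap-crop coordinates:

    from typing import List, Tuple

    def tile_spans_1d(axis_size: int, tile_size: int, overlap: int) -> List[Tuple[int, int]]:
        # (start, end) pairs of overlapping tiles covering one axis
        step = tile_size - overlap
        starts = list(range(0, max(axis_size - tile_size, 0) + 1, step))
        if starts[-1] + tile_size < axis_size:   # clamp a final tile to the end
            starts.append(axis_size - tile_size)
        return [(s, s + tile_size) for s in starts]

    print(tile_spans_1d(10, 4, 2))  # [(0, 4), (2, 6), (4, 8), (6, 10)]
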

careamics/dataset/patching/validate_patch_dimension.py
@@ -1,3 +1,5 @@
+ """Patch validation functions."""
+
  from typing import List, Tuple, Union
 
  import numpy as np

careamics/dataset/zarr_dataset.py
@@ -1,3 +1,5 @@
+ """Zarr dataset."""
+
  # from itertools import islice
  # from typing import Callable, Dict, List, Optional, Tuple, Union
 

careamics/lightning_datamodule.py
@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
  Batch size.
  use_in_memory : bool
  Whether to use in memory dataset if possible.
- train_data : Union[Path, str, np.ndarray]
+ train_data : Union[Path, np.ndarray]
  Training data.
- val_data : Optional[Union[Path, str, np.ndarray]]
+ val_data : Optional[Union[Path, np.ndarray]]
  Validation data.
- train_data_target : Optional[Union[Path, str, np.ndarray]]
+ train_data_target : Optional[Union[Path, np.ndarray]]
  Training target data.
- val_data_target : Optional[Union[Path, str, np.ndarray]]
+ val_data_target : Optional[Union[Path, np.ndarray]]
  Validation target data.
  val_percentage : float
  Percentage of the training data to use for validation, if no validation data is
@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
  )
 
  # configuration
- self.data_config = data_config
- self.data_type = data_config.data_type
- self.batch_size = data_config.batch_size
- self.use_in_memory = use_in_memory
+ self.data_config: DataConfig = data_config
+ self.data_type: str = data_config.data_type
+ self.batch_size: int = data_config.batch_size
+ self.use_in_memory: bool = use_in_memory
+
+ # data: make data Path or np.ndarray, use type annotations for mypy
+ self.train_data: Union[Path, np.ndarray] = (
+ Path(train_data) if isinstance(train_data, str) else train_data
+ )
+
+ self.val_data: Union[Path, np.ndarray] = (
+ Path(val_data) if isinstance(val_data, str) else val_data
+ )
+
+ self.train_data_target: Union[Path, np.ndarray] = (
+ Path(train_data_target)
+ if isinstance(train_data_target, str)
+ else train_data_target
+ )
 
- # data
- self.train_data = train_data
- self.val_data = val_data
+ self.val_data_target: Union[Path, np.ndarray] = (
+ Path(val_data_target)
+ if isinstance(val_data_target, str)
+ else val_data_target
+ )
 
- self.train_data_target = train_data_target
- self.val_data_target = val_data_target
+ # validation split
  self.val_percentage = val_percentage
  self.val_minimum_split = val_minimum_split
 
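The constructor now normalizes string inputs to pathlib.Path, so downstream code only has to handle Path or np.ndarray. A minimal standalone sketch of that coercion idiom (the helper name is hypothetical):

    from pathlib import Path
    from typing import Optional, Union

    import numpy as np

    def to_path_or_array(
        data: Optional[Union[Path, str, np.ndarray]]
    ) -> Optional[Union[Path, np.ndarray]]:
        # strings become Path objects; arrays and None pass through unchanged
        return Path(data) if isinstance(data, str) else data

    print(to_path_or_array("train_images"))     # train_images (now a Path)
    print(type(to_path_or_array(np.zeros(4))))  # <class 'numpy.ndarray'>
    print(to_path_or_array(None))               # None
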
@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
  elif data_config.data_type != SupportedData.ARRAY:
  self.read_source_func = get_read_func(data_config.data_type)
 
- self.extension_filter = extension_filter
+ self.extension_filter: str = extension_filter
 
  # Pytorch dataloader parameters
- self.dataloader_params = (
+ self.dataloader_params: Dict[str, Any] = (
  data_config.dataloader_params if data_config.dataloader_params else {}
  )
 
@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
  """
  # if numpy array
  if self.data_type == SupportedData.ARRAY:
+ # mypy checks
+ assert isinstance(self.train_data, np.ndarray)
+ if self.train_data_target is not None:
+ assert isinstance(self.train_data_target, np.ndarray)
+
  # train dataset
  self.train_dataset: DatasetType = InMemoryDataset(
  data_config=self.data_config,
  inputs=self.train_data,
- data_target=self.train_data_target,
+ input_target=self.train_data_target,
  )
 
  # validation dataset
  if self.val_data is not None:
+ # mypy checks
+ assert isinstance(self.val_data, np.ndarray)
+ if self.val_data_target is not None:
+ assert isinstance(self.val_data_target, np.ndarray)
+
  # create its own dataset
  self.val_dataset: DatasetType = InMemoryDataset(
  data_config=self.data_config,
  inputs=self.val_data,
- data_target=self.val_data_target,
+ input_target=self.val_data_target,
  )
  else:
  # extract validation from the training patches
@@ -341,7 +367,7 @@ class CAREamicsTrainData(L.LightningDataModule):
  self.train_dataset = InMemoryDataset(
  data_config=self.data_config,
  inputs=self.train_files,
- data_target=(
+ input_target=(
  self.train_target_files if self.train_data_target else None
  ),
  read_source_func=self.read_source_func,
@@ -352,7 +378,7 @@ class CAREamicsTrainData(L.LightningDataModule):
  self.val_dataset = InMemoryDataset(
  data_config=self.data_config,
  inputs=self.val_files,
- data_target=(
+ input_target=(
  self.val_target_files if self.val_data_target else None
  ),
  read_source_func=self.read_source_func,
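
Besides renaming the data_target keyword to input_target in every InMemoryDataset call, the setup code adds assert isinstance(...) checks so that mypy can narrow the Union[Path, np.ndarray] attributes before they are used as arrays. A standalone sketch of that narrowing pattern, with illustrative names:

    from pathlib import Path
    from typing import Union

    import numpy as np

    def element_count(data: Union[Path, np.ndarray], data_type: str) -> int:
        if data_type == "array":
            # narrows the Union for type checkers, mirroring the added "mypy checks"
            assert isinstance(data, np.ndarray)
            return data.size                      # ndarray-only attribute
        assert isinstance(data, Path)
        return len(list(data.glob("*.tiff")))     # Path-only method

    print(element_count(np.zeros((4, 4)), "array"))  # 16
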

careamics/lightning_module.py
@@ -1,3 +1,5 @@
+ """CAREamics Lightning module."""
+
  from typing import Any, Optional, Union
 
  import pytorch_lightning as L
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
  This class encapsulates the a PyTorch model along with the training, validation,
  and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
+ Parameters
+ ----------
+ algorithm_config : Union[AlgorithmModel, dict]
+ Algorithm configuration.
+
  Attributes
  ----------
  model : nn.Module
@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
  """
 
  def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
- """
- CAREamics Lightning module.
+ """Lightning module for CAREamics.
 
  This class encapsulates the a PyTorch model along with the training, validation,
  and testing logic. It is configured using an `AlgorithmModel` Pydantic class.

careamics/lightning_prediction_datamodule.py
@@ -1,7 +1,7 @@
  """Prediction Lightning data modules."""
 
  from pathlib import Path
- from typing import Any, Callable, List, Literal, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
  import numpy as np
  import pytorch_lightning as L
@@ -303,9 +303,6 @@ class PredictDataWrapper(CAREamicsPredictData):
  Batch size.
  tta_transforms : bool, optional
  Use test time augmentation, by default True.
- transforms : List, optional
- List of transforms to apply to prediction patches. If None, default
- transforms are applied.
  read_source_func : Optional[Callable], optional
  Function to read the source data, used if `data_type` is `custom`, by
  default None.
@@ -326,7 +323,6 @@ class PredictDataWrapper(CAREamicsPredictData):
  axes: str = "YX",
  batch_size: int = 1,
  tta_transforms: bool = True,
- transforms: Optional[List] = None,
  read_source_func: Optional[Callable] = None,
  extension_filter: str = "",
  dataloader_params: Optional[dict] = None,
@@ -356,9 +352,6 @@ class PredictDataWrapper(CAREamicsPredictData):
  Batch size.
  tta_transforms : bool, optional
  Use test time augmentation, by default True.
- transforms : Optional[List], optional
- List of transforms to apply to prediction patches. If None, default
- transforms are applied.
  read_source_func : Optional[Callable], optional
  Function to read the source data, used if `data_type` is `custom`, by
  default None.
@@ -369,7 +362,7 @@ class PredictDataWrapper(CAREamicsPredictData):
  """
  if dataloader_params is None:
  dataloader_params = {}
- prediction_dict = {
+ prediction_dict: Dict[str, Any] = {
  "data_type": data_type,
  "tile_size": tile_size,
  "tile_overlap": tile_overlap,
@@ -378,12 +371,9 @@ class PredictDataWrapper(CAREamicsPredictData):
  "std": std,
  "tta": tta_transforms,
  "batch_size": batch_size,
+ "transforms": [],
  }
 
- # if transforms are passed (otherwise it will use the default ones)
- if transforms is not None:
- prediction_dict["transforms"] = transforms
-
  # validate configuration
  self.prediction_config = InferenceConfig(**prediction_dict)
 
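PredictDataWrapper no longer exposes a transforms argument; the prediction dictionary always carries an empty list, and the whole dictionary is still validated through the Pydantic InferenceConfig. A minimal stand-in sketch of that validate-a-dict pattern; TinyInferenceConfig and its fields are hypothetical, not the real InferenceConfig:

    from typing import Any, Dict, List, Optional, Tuple

    from pydantic import BaseModel

    class TinyInferenceConfig(BaseModel):
        # hypothetical stand-in for careamics' InferenceConfig
        data_type: str
        tile_size: Optional[Tuple[int, int]] = None
        batch_size: int = 1
        transforms: List[Any] = []

    prediction_dict: Dict[str, Any] = {
        "data_type": "tiff",
        "batch_size": 2,
        "transforms": [],  # rc6 always passes an empty list here
    }
    config = TinyInferenceConfig(**prediction_dict)
    print(config.transforms)  # []
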

careamics/lightning_prediction_loop.py
@@ -1,3 +1,5 @@
+ """Lithning prediction loop allowing tiling."""
+
  from typing import Optional
 
  import pytorch_lightning as L
@@ -18,14 +20,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
  """
 
  def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
- """
- Calls `on_predict_epoch_end` hook.
+ """Call `on_predict_epoch_end` hook.
 
  Adapted from the parent method.
 
  Returns
  -------
- the results for all dataloaders
+ Optional[_PREDICT_OUTPUT]
+ Prediction output.
  """
  trainer = self.trainer
  call._call_callback_hooks(trainer, "on_predict_epoch_end")
@@ -45,15 +47,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
 
  @_no_grad_context
  def run(self) -> Optional[_PREDICT_OUTPUT]:
- """
- Runs the prediction loop.
+ """Run the prediction loop.
 
  Adapted from the parent method in order to stitch the predictions.
 
  Returns
  -------
  Optional[_PREDICT_OUTPUT]
- Prediction output
+ Prediction output.
  """
  self.setup_data()
  if self.skip:
@@ -86,6 +87,7 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
 
  ########################################################
  ################ CAREamics specific code ###############
+ # TODO: next line is not compatible with muSplit
  is_tiled = len(self.predictions[batch_idx]) == 2
  if is_tiled:
  # extract the last tile flag and the coordinates (crop and stitch)

careamics/losses/__init__.py
@@ -1,6 +1,5 @@
  """Losses module."""
 
- from .loss_factory import loss_factory
+ __all__ = ["loss_factory"]
 
- # from .noise_model_factory import noise_model_factory as noise_model_factory
- # from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
+ from .loss_factory import loss_factory

careamics/losses/loss_factory.py
@@ -17,7 +17,7 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
 
  Parameters
  ----------
- loss: SupportedLoss
+ loss : Union[SupportedLoss, str]
  Requested loss.
 
  Returns
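
The docstring now reflects that the factory accepts either a SupportedLoss member or its string value. A self-contained sketch of that string-or-enum factory pattern; the enum values and the returned callables are stand-ins, not the actual CAREamics losses:

    from enum import Enum
    from typing import Callable, Union

    class Loss(str, Enum):          # stand-in for careamics' SupportedLoss
        MSE = "mse"
        N2V = "n2v"

    def make_loss(loss: Union[Loss, str]) -> Callable:
        loss_enum = Loss(loss)      # a str-backed Enum accepts either form
        if loss_enum is Loss.MSE:
            return lambda a, b: ((a - b) ** 2).mean()
        raise NotImplementedError(f"{loss_enum} not implemented in this sketch")

    assert make_loss("mse") is make_loss("mse") or True   # plain string
    assert make_loss(Loss.MSE) is not None                # enum member
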

careamics/losses/losses.py
@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
  """
 
  import torch
-
- # TODO if we are only using the DiceLoss, can we just implement it?
- # from segmentation_models_pytorch.losses import DiceLoss
  from torch.nn import L1Loss, MSELoss
 
 
- def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  """
  Mean squared error loss.
 
+ Parameters
+ ----------
+ source : torch.Tensor
+ Source patches.
+ target : torch.Tensor
+ Target patches.
+
  Returns
  -------
  torch.Tensor
  Loss value.
  """
  loss = MSELoss()
- return loss(samples, labels)
+ return loss(source, target)
 
 
  def n2v_loss(
@@ -34,9 +38,9 @@ def n2v_loss(
 
  Parameters
  ----------
- samples : torch.Tensor
+ manipulated_patches : torch.Tensor
  Patches with manipulated pixels.
- labels : torch.Tensor
+ original_patches : torch.Tensor
  Noisy patches.
  masks : torch.Tensor
  Array containing masked pixel locations.
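
The renamed arguments make explicit that the Noise2Void loss compares the manipulated patches against the original noisy patches at the masked pixel locations. A standalone sketch of a masked loss in that spirit; the exact CAREamics formula may differ:

    import torch

    def masked_mse(
        manipulated_patches: torch.Tensor,
        original_patches: torch.Tensor,
        masks: torch.Tensor,
    ) -> torch.Tensor:
        # error only counts where pixels were masked/manipulated
        squared_error = (manipulated_patches - original_patches) ** 2
        return (squared_error * masks).sum() / masks.sum().clamp(min=1)

    prediction = torch.rand(2, 1, 8, 8)
    noisy = torch.rand(2, 1, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.9).float()
    print(masked_mse(prediction, noisy, mask))
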

careamics/model_io/bmz_io.py
@@ -104,9 +104,9 @@ def export_to_bmz(
  authors : List[dict]
  Authors of the model.
  input_array : np.ndarray
- Input array.
+ Input array, should not have been normalized.
  output_array : np.ndarray
- Output array.
+ Output array, should have been denormalized.
  channel_names : Optional[List[str]], optional
  Channel names, by default None.
  data_description : Optional[str], optional
@@ -178,7 +178,7 @@ def export_to_bmz(
  )
 
  # test model description
- summary: ValidationSummary = test_model(model_description, decimal=0)
+ summary: ValidationSummary = test_model(model_description, decimal=2)
  if summary.status == "failed":
  raise ValueError(f"Model description test failed: {summary}")
 

careamics/models/activation.py
@@ -1,3 +1,5 @@
+ """Activations for CAREamics models."""
+
  from typing import Callable, Union
 
  import torch.nn as nn

careamics/models/layers.py
@@ -162,6 +162,18 @@ def _unpack_kernel_size(
  """Unpack kernel_size to a tuple of ints.
 
  Inspired by Kornia implementation. TODO: link
+
+ Parameters
+ ----------
+ kernel_size : Union[Tuple[int, ...], int]
+ Kernel size.
+ dim : int
+ Number of dimensions.
+
+ Returns
+ -------
+ Tuple[int, ...]
+ Kernel size tuple.
  """
  if isinstance(kernel_size, int):
  kernel_dims = tuple([kernel_size for _ in range(dim)])
@@ -173,7 +185,20 @@
  def _compute_zero_padding(
  kernel_size: Union[Tuple[int, ...], int], dim: int
  ) -> Tuple[int, ...]:
- """Utility function that computes zero padding tuple."""
+ """Utility function that computes zero padding tuple.
+
+ Parameters
+ ----------
+ kernel_size : Union[Tuple[int, ...], int]
+ Kernel size.
+ dim : int
+ Number of dimensions.
+
+ Returns
+ -------
+ Tuple[int, ...]
+ Zero padding tuple.
+ """
  kernel_dims = _unpack_kernel_size(kernel_size, dim)
  return tuple([(kd - 1) // 2 for kd in kernel_dims])
 
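A tiny runnable sketch of the same (k - 1) // 2 arithmetic shown above, using stand-in names:

    from typing import Tuple, Union

    def zero_padding(kernel_size: Union[Tuple[int, ...], int], dim: int) -> Tuple[int, ...]:
        # "same" padding for an odd kernel is (k - 1) // 2 along each axis
        dims = (kernel_size,) * dim if isinstance(kernel_size, int) else tuple(kernel_size)
        return tuple((k - 1) // 2 for k in dims)

    print(zero_padding(3, dim=2))       # (1, 1)
    print(zero_padding((3, 5), dim=2))  # (1, 2)
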
@@ -191,14 +216,19 @@ def get_pascal_kernel_1d(
 
  Parameters
  ----------
- kernel_size: height and width of the kernel.
- norm: if to normalize the kernel or not. Default: False.
- device: tensor device
- dtype: tensor dtype
+ kernel_size : int
+ Kernel size.
+ norm : bool
+ Normalize the kernel, by default False.
+ device : Optional[torch.device]
+ Device of the tensor, by default None.
+ dtype : Optional[torch.dtype]
+ Data type of the tensor, by default None.
 
  Returns
  -------
- kernel shaped as :math:`(kernel_size,)`
+ torch.Tensor
+ Pascal kernel.
 
  Examples
  --------
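
For intuition, a standalone sketch that builds the binomial (Pascal) row this docstring describes; the package implementation may differ in device and dtype handling:

    import torch

    def pascal_row(kernel_size: int, norm: bool = False) -> torch.Tensor:
        # rows of Pascal's triangle: 3 -> [1, 2, 1], 5 -> [1, 4, 6, 4, 1]
        row = [1]
        for _ in range(kernel_size - 1):
            row = [a + b for a, b in zip([0] + row, row + [0])]
        kernel = torch.tensor(row, dtype=torch.float32)
        return kernel / kernel.sum() if norm else kernel

    print(pascal_row(3))             # tensor([1., 2., 1.])
    print(pascal_row(3, norm=True))  # tensor([0.2500, 0.5000, 0.2500])
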
@@ -245,19 +275,28 @@ def _get_pascal_kernel_nd(
  ) -> torch.Tensor:
  """Generate pascal filter kernel by kernel size.
 
+ If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
+ otherwise the kernel will be shaped as kernel_size
+
  Inspired by Kornia implementation.
 
  Parameters
  ----------
- kernel_size: height and width of the kernel.
- norm: if to normalize the kernel or not. Default: True.
- device: tensor device
- dtype: tensor dtype
+ kernel_size : Union[Tuple[int, int], int]
+ Kernel size for the pascal kernel.
+ norm : bool
+ Normalize the kernel, by default True.
+ dim : int
+ Number of dimensions, by default 2.
+ device : Optional[torch.device]
+ Device of the tensor, by default None.
+ dtype : Optional[torch.dtype]
+ Data type of the tensor, by default None.
 
  Returns
  -------
- if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
- otherwise the kernel will be shaped as kernel_size
+ torch.Tensor
+ Pascal kernel.
 
  Examples
  --------
@@ -303,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
  """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
 
  Inspired by Kornia implementation.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+ kernel : torch.Tensor
+ Kernel tensor.
+ stride : int
+ Stride.
+ max_pool_size : int
+ Maximum pool size.
+ ceil_mode : bool
+ Ceil mode, by default False. Set to True to match output size of conv2d.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
  """
  # compute local maxima
  x = F.max_pool2d(
@@ -323,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
  """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
 
  Inspired by Kornia implementation.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+ kernel : torch.Tensor
+ Kernel tensor.
+ stride : int
+ Stride.
+ max_pool_size : int
+ Maximum pool size.
+ ceil_mode : bool
+ Ceil mode, by default False. Set to True to match output size of conv2d.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
  """
  # compute local maxima
  x = F.max_pool3d(
@@ -343,21 +418,16 @@ class MaxBlurPool(nn.Module):
 
  Parameters
  ----------
- dim: int
- Toggles between 2D and 3D
- kernel_size: Union[Tuple[int, int], int]
+ dim : int
+ Toggles between 2D and 3D.
+ kernel_size : Union[Tuple[int, int], int]
  Kernel size for max pooling.
- stride: int
+ stride : int
  Stride for pooling.
- max_pool_size: int
+ max_pool_size : int
  Max kernel size for max pooling.
- ceil_mode: bool
- Should be true to match output size of conv2d with same kernel size.
-
- Returns
- -------
- torch.Tensor
- The pooled and blurred tensor.
+ ceil_mode : bool
+ Ceil mode, by default False. Set to True to match output size of conv2d.
  """
 
  def __init__(
@@ -368,6 +438,21 @@ class MaxBlurPool(nn.Module):
  max_pool_size: int = 2,
  ceil_mode: bool = False,
  ) -> None:
+ """Constructor.
+
+ Parameters
+ ----------
+ dim : int
+ Dimension of the convolution.
+ kernel_size : Union[Tuple[int, int], int]
+ Kernel size for max pooling.
+ stride : int, optional
+ Stride, by default 2.
+ max_pool_size : int, optional
+ Maximum pool size, by default 2.
+ ceil_mode : bool, optional
+ Ceil mode, by default False. Set to True to match output size of conv2d.
+ """
  super().__init__()
  self.dim = dim
  self.kernel_size = kernel_size
@@ -377,7 +462,18 @@ class MaxBlurPool(nn.Module):
  self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
 
  def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward pass of the function."""
+ """Forward pass of the function.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor.
+ """
  self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
  if self.dim == 2:
  return _max_blur_pool_by_kernel2d(
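
A hedged usage sketch of the blur-pool layer, assuming careamics 0.1.0rc6 is installed; with the default stride of 2 the spatial size is roughly halved:

    import torch
    from careamics.models.layers import MaxBlurPool

    pool = MaxBlurPool(dim=2, kernel_size=3)   # anti-aliased max pooling
    x = torch.rand(1, 8, 32, 32)               # (batch, channels, H, W)
    y = pool(x)
    print(y.shape)                             # expected around torch.Size([1, 8, 16, 16])
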

careamics/models/model_factory.py
@@ -27,7 +27,7 @@ def model_factory(
  Parameters
  ----------
  model_configuration : Union[UNetModel, VAEModel]
- Model configuration
+ Model configuration.
 
  Returns
  -------