careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (103) hide show
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,123 @@
"""XY flip transform."""

from typing import List, Optional, Tuple

import numpy as np

from careamics.transforms.transform import Transform


class XYFlip(Transform):
    """Flip image along X and Y axis, one at a time.

    This transform randomly flips one of the last two axes.

    This transform expects C(Z)YX dimensions.

    Attributes
    ----------
    axis_indices : List[int]
        Indices of the axes that can be flipped.
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.
    seed : Optional[int]
        Random seed.

    Parameters
    ----------
    flip_x : bool, optional
        Whether to flip along the X axis, by default True.
    flip_y : bool, optional
        Whether to flip along the Y axis, by default True.
    p : float, optional
        Probability of applying the transform, by default 0.5.
    seed : Optional[int], optional
        Random seed, by default None.
    """

    def __init__(
        self,
        flip_x: bool = True,
        flip_y: bool = True,
        p: float = 0.5,
        seed: Optional[int] = None,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        flip_x : bool, optional
            Whether to flip along the X axis, by default True.
        flip_y : bool, optional
            Whether to flip along the Y axis, by default True.
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int], optional
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is not in [0, 1] (NaN included), or if neither axis is
            flippable.
        """
        # chained comparison also rejects NaN, which `p < 0 or p > 1` let through
        if not 0 <= p <= 1:
            raise ValueError("Probability must be in [0, 1].")

        if not flip_x and not flip_y:
            raise ValueError("At least one axis must be flippable.")

        # probability to apply the transform
        self.p = p

        # keep the seed so the documented attribute actually exists
        self.seed = seed

        # "flippable" axes, counted from the end so C(Z)YX and CYX both work
        self.axis_indices: List[int] = []

        if flip_y:
            self.axis_indices.append(-2)
        if flip_x:
            self.axis_indices.append(-1)

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target.
        """
        # `rng.random()` draws from [0, 1): applying only when the draw is
        # strictly below `p` yields exactly probability `p`, i.e. never for
        # p=0 and always for p=1 (the previous `>` test could still apply
        # the transform when p=0 and the draw was exactly 0.0).
        if self.rng.random() >= self.p:
            return patch, target

        # choose an axis to flip; the same axis is used for patch and target
        axis = int(self.rng.choice(self.axis_indices))

        patch_transformed = self._apply(patch, axis)
        target_transformed = self._apply(target, axis) if target is not None else None

        return patch_transformed, target_transformed

    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        axis : int
            Axis to flip.

        Returns
        -------
        np.ndarray
            Flipped image patch.
        """
        # np.flip returns a strided view; make it contiguous for downstream use
        return np.ascontiguousarray(np.flip(patch, axis=axis))
@@ -1,95 +1,101 @@
"""Patch transform applying XY random 90 degrees rotations."""

from typing import Optional, Tuple

import numpy as np

from careamics.transforms.transform import Transform


class XYRandomRotate90(Transform):
    """Applies random 90 degree rotations to the YX axis.

    This transform expects C(Z)YX dimensions.

    Attributes
    ----------
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.
    seed : Optional[int]
        Random seed.

    Parameters
    ----------
    p : float
        Probability of applying the transform, by default 0.5.
    seed : Optional[int]
        Random seed, by default None.
    """

    def __init__(self, p: float = 0.5, seed: Optional[int] = None):
        """Constructor.

        Parameters
        ----------
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int]
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is not in [0, 1] (NaN included).
        """
        # chained comparison also rejects NaN, which `p < 0 or p > 1` let through
        if not 0 <= p <= 1:
            raise ValueError("Probability must be in [0, 1].")

        # probability to apply the transform
        self.p = p

        # keep the seed so the documented attribute actually exists
        self.seed = seed

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target.
        """
        # `rng.random()` draws from [0, 1): applying only when the draw is
        # strictly below `p` gives exactly probability `p` (never for p=0,
        # always for p=1).
        if self.rng.random() >= self.p:
            return patch, target

        # number of rotations in {1, 2, 3} (upper bound of `integers` is
        # exclusive); 0 is excluded so an applied transform always rotates
        n_rot = int(self.rng.integers(1, 4))

        # rotate in the YX plane, i.e. the last two axes of C(Z)YX
        axes = (-2, -1)
        patch_transformed = self._apply(patch, n_rot, axes)
        target_transformed = (
            self._apply(target, n_rot, axes) if target is not None else None
        )

        return patch_transformed, target_transformed

    def _apply(
        self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
    ) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape C(Z)YX.
        n_rot : int
            Number of 90 degree rotations.
        axes : Tuple[int, int]
            Axes along which to rotate the patch.

        Returns
        -------
        np.ndarray
            Transformed patch.
        """
        # np.rot90 returns a strided view; make it contiguous for downstream use
        return np.ascontiguousarray(np.rot90(patch, k=n_rot, axes=axes))
@@ -1,6 +1,5 @@
1
1
  """Utils module."""
2
2
 
3
-
4
3
  __all__ = [
5
4
  "cwd",
6
5
  "get_ram_size",
@@ -1,9 +1,25 @@
1
+ """A base class for Enum that allows checking if a value is in the Enum."""
2
+
1
3
  from enum import Enum, EnumMeta
2
4
  from typing import Any
3
5
 
4
6
 
5
7
  class _ContainerEnum(EnumMeta):
8
+ """Metaclass for Enum with __contains__ method."""
9
+
6
10
  def __contains__(cls, item: Any) -> bool:
11
+ """Check if an item is in the Enum.
12
+
13
+ Parameters
14
+ ----------
15
+ item : Any
16
+ Item to check.
17
+
18
+ Returns
19
+ -------
20
+ bool
21
+ True if the item is in the Enum, False otherwise.
22
+ """
7
23
  try:
8
24
  cls(item)
9
25
  except ValueError:
@@ -12,6 +28,18 @@ class _ContainerEnum(EnumMeta):
12
28
 
13
29
  @classmethod
14
30
  def has_value(cls, value: Any) -> bool:
31
+ """Check if a value is in the Enum.
32
+
33
+ Parameters
34
+ ----------
35
+ value : Any
36
+ Value to check.
37
+
38
+ Returns
39
+ -------
40
+ bool
41
+ True if the value is in the Enum, False otherwise.
42
+ """
15
43
  return value in cls._value2member_map_
16
44
 
17
45
 
@@ -3,6 +3,7 @@ Context submodule.
3
3
 
4
4
  A convenience function to change the working directory in order to save data.
5
5
  """
6
+
6
7
  import os
7
8
  from contextlib import contextmanager
8
9
  from pathlib import Path
@@ -3,6 +3,7 @@ Logging submodule.
3
3
 
4
4
  The methods are responsible for the in-console logger.
5
5
  """
6
+
6
7
  import logging
7
8
  import sys
8
9
  import time
@@ -3,6 +3,7 @@ Metrics submodule.
3
3
 
4
4
  This module contains various metrics and a metrics tracking class.
5
5
  """
6
+
6
7
  from typing import Union
7
8
 
8
9
  import numpy as np
@@ -1,3 +1,5 @@
1
+ """Utility functions for paths."""
2
+
1
3
  from pathlib import Path
2
4
  from typing import Union
3
5
 
careamics/utils/ram.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Utility function to get RAM size."""
2
+
1
3
  import psutil
2
4
 
3
5
 
@@ -1,102 +1,108 @@
1
1
  """Receptive field calculation for computing the tile overlap."""
2
2
 
3
+ # TODO better docstring and function names
3
4
  # Adapted from: https://github.com/frgfm/torch-scan
4
5
 
5
- import math
6
- import warnings
7
- from typing import Tuple, Union
6
+ # import math
7
+ # import warnings
8
+ # from typing import Tuple, Union
8
9
 
9
- from torch import Tensor, nn
10
- from torch.nn import Module
11
- from torch.nn.modules.batchnorm import _BatchNorm
12
- from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
13
- from torch.nn.modules.pooling import (
14
- _AdaptiveAvgPoolNd,
15
- _AdaptiveMaxPoolNd,
16
- _AvgPoolNd,
17
- _MaxPoolNd,
18
- )
10
+ # from torch import Tensor, nn
11
+ # from torch.nn import Module
12
+ # from torch.nn.modules.batchnorm import _BatchNorm
13
+ # from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
14
+ # from torch.nn.modules.pooling import (
15
+ # _AdaptiveAvgPoolNd,
16
+ # _AdaptiveMaxPoolNd,
17
+ # _AvgPoolNd,
18
+ # _MaxPoolNd,
19
+ # )
19
20
 
20
21
 
21
- def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
22
- """Estimate the spatial receptive field of the module.
22
+ # def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
23
+ # """Estimate the spatial receptive field of the module.
23
24
 
24
- Args:
25
- module (torch.nn.Module): PyTorch module
26
- inp (torch.Tensor): input to the module
27
- out (torch.Tensor): output of the module
28
- Returns:
29
- receptive field
30
- effective stride
31
- effective padding
32
- """
33
- if isinstance(
34
- module,
35
- (
36
- nn.Identity,
37
- nn.Flatten,
38
- nn.ReLU,
39
- nn.ELU,
40
- nn.LeakyReLU,
41
- nn.ReLU6,
42
- nn.Tanh,
43
- nn.Sigmoid,
44
- _BatchNorm,
45
- nn.Dropout,
46
- nn.Linear,
47
- ),
48
- ):
49
- return 1.0, 1.0, 0.0
50
- elif isinstance(module, _ConvTransposeNd):
51
- return rf_convtransposend(module, inp, out)
52
- elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):
53
- return rf_aggregnd(module, inp, out)
54
- elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
55
- return rf_adaptive_poolnd(module, inp, out)
56
- else:
57
- warnings.warn(
58
- f"Module type not supported: {module.__class__.__name__}", stacklevel=1
59
- )
60
- return 1.0, 1.0, 0.0
25
+ # Parameters
26
+ # ----------
27
+ # module : Module
28
+ # Module to estimate the receptive field.
29
+ # inp : Tensor
30
+ # Input tensor.
31
+ # out : Tensor
32
+ # Output tensor.
61
33
 
34
+ # Returns
35
+ # -------
36
+ # Tuple[float, float, float]
37
+ # Receptive field, effective stride and padding.
38
+ # """
39
+ # if isinstance(
40
+ # module,
41
+ # (
42
+ # nn.Identity,
43
+ # nn.Flatten,
44
+ # nn.ReLU,
45
+ # nn.ELU,
46
+ # nn.LeakyReLU,
47
+ # nn.ReLU6,
48
+ # nn.Tanh,
49
+ # nn.Sigmoid,
50
+ # _BatchNorm,
51
+ # nn.Dropout,
52
+ # nn.Linear,
53
+ # ),
54
+ # ):
55
+ # return 1.0, 1.0, 0.0
56
+ # elif isinstance(module, _ConvTransposeNd):
57
+ # return rf_convtransposend(module, inp, out)
58
+ # elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):
59
+ # return rf_aggregnd(module, inp, out)
60
+ # elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
61
+ # return rf_adaptive_poolnd(module, inp, out)
62
+ # else:
63
+ # warnings.warn(
64
+ # f"Module type not supported: {module.__class__.__name__}", stacklevel=1
65
+ # )
66
+ # return 1.0, 1.0, 0.0
62
67
 
63
- def rf_convtransposend(
64
- module: _ConvTransposeNd, _: Tensor, __: Tensor
65
- ) -> Tuple[float, float, float]:
66
- k = (
67
- module.kernel_size[0]
68
- if isinstance(module.kernel_size, tuple)
69
- else module.kernel_size
70
- )
71
- s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
72
- return -k, 1.0 / s, 0.0
73
68
 
69
+ # def rf_convtransposend(
70
+ # module: _ConvTransposeNd, _: Tensor, __: Tensor
71
+ # ) -> Tuple[float, float, float]:
72
+ # k = (
73
+ # module.kernel_size[0]
74
+ # if isinstance(module.kernel_size, tuple)
75
+ # else module.kernel_size
76
+ # )
77
+ # s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
78
+ # return -k, 1.0 / s, 0.0
74
79
 
75
- def rf_aggregnd(
76
- module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor
77
- ) -> Tuple[float, float, float]:
78
- k = (
79
- module.kernel_size[0]
80
- if isinstance(module.kernel_size, tuple)
81
- else module.kernel_size
82
- )
83
- if hasattr(module, "dilation"):
84
- d = (
85
- module.dilation[0]
86
- if isinstance(module.dilation, tuple)
87
- else module.dilation
88
- )
89
- k = d * (k - 1) + 1
90
- s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
91
- p = module.padding[0] if isinstance(module.padding, tuple) else module.padding
92
- return k, s, p # type: ignore[return-value]
93
80
 
81
+ # def rf_aggregnd(
82
+ # module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor
83
+ # ) -> Tuple[float, float, float]:
84
+ # k = (
85
+ # module.kernel_size[0]
86
+ # if isinstance(module.kernel_size, tuple)
87
+ # else module.kernel_size
88
+ # )
89
+ # if hasattr(module, "dilation"):
90
+ # d = (
91
+ # module.dilation[0]
92
+ # if isinstance(module.dilation, tuple)
93
+ # else module.dilation
94
+ # )
95
+ # k = d * (k - 1) + 1
96
+ # s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
97
+ # p = module.padding[0] if isinstance(module.padding, tuple) else module.padding
98
+ # return k, s, p # type: ignore[return-value]
94
99
 
95
- def rf_adaptive_poolnd(
96
- _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor
97
- ) -> Tuple[int, int, float]:
98
- stride = math.ceil(inp.shape[-1] / out.shape[-1])
99
- kernel_size = stride
100
- padding = (inp.shape[-1] - kernel_size * stride) / 2
101
100
 
102
- return kernel_size, stride, padding
101
+ # def rf_adaptive_poolnd(
102
+ # _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor
103
+ # ) -> Tuple[int, int, float]:
104
+ # stride = math.ceil(inp.shape[-1] / out.shape[-1])
105
+ # kernel_size = stride
106
+ # padding = (inp.shape[-1] - kernel_size * stride) / 2
107
+
108
+ # return kernel_size, stride, padding
@@ -3,6 +3,7 @@ Convenience functions using torch.
3
3
 
4
4
  These functions are used to control certain aspects and behaviours of PyTorch.
5
5
  """
6
+
6
7
  import inspect
7
8
  from typing import Dict, Union
8
9