sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
@@ -109,10 +109,10 @@ class GeometricConfig:
      Attributes:
          rotation_min: (float) Minimum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `-15.0`.
          rotation_max: (float) Maximum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `15.0`.
-         rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
+         rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
          scale_min: (float) Minimum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `0.9`.
          scale_max: (float) Maximum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `1.1`.
-         scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
+         scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
          translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. *Default*: `0.0`.
          translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. *Default*: `0.0`.
          translate_p: (float, optional) Probability of applying random translation independently. If set, translation is applied separately from rotation/scale. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
@@ -129,10 +129,10 @@ class GeometricConfig:
 
      rotation_min: float = field(default=-15.0, validator=validators.ge(-180))
      rotation_max: float = field(default=15.0, validator=validators.le(180))
-     rotation_p: Optional[float] = field(default=1.0)
+     rotation_p: Optional[float] = field(default=None)
      scale_min: float = field(default=0.9, validator=validators.ge(0))
      scale_max: float = field(default=1.1, validator=validators.ge(0))
-     scale_p: Optional[float] = field(default=1.0)
+     scale_p: Optional[float] = field(default=None)
      translate_width: float = 0.0
      translate_height: float = 0.0
      translate_p: Optional[float] = field(default=None)
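Taken together, the two hunks above flip the per-transform probabilities from `1.0` to `None`, so rotation, scale, and translation stay bundled under `affine_p` unless a caller opts in explicitly. A minimal sketch of the two modes (field values here are illustrative, not the defaults):

    from sleap_nn.config.data_config import GeometricConfig

    # Bundled (legacy) mode: leave rotation_p/scale_p/translate_p as None and
    # gate rotation, scale, and translation on a single affine_p probability.
    bundled = GeometricConfig(rotation_min=-30.0, rotation_max=30.0, affine_p=1.0)

    # Independent mode: setting any per-transform probability makes that
    # transform apply separately from the others.
    independent = GeometricConfig(rotation_p=0.5, scale_p=0.25)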
@@ -198,8 +198,6 @@ class DataConfig:
          cache_img_path: (str) Path to save `.jpg` images created with `torch_dataset_cache_img_disk` data pipeline framework. If `None`, the path provided in `trainer_config.save_ckpt` is used. The `train_imgs` and `val_imgs` dirs are created inside this path. *Default*: `None`.
          use_existing_imgs: (bool) Use existing train and val images/ chunks in the `cache_img_path` for `torch_dataset_cache_img_disk` frameworks. If `True`, the `cache_img_path` should have `train_imgs` and `val_imgs` dirs. *Default*: `False`.
          delete_cache_imgs_after_training: (bool) If `False`, the images (torch_dataset_cache_img_disk) are retained after training. Else, the files are deleted. *Default*: `True`.
-         parallel_caching: (bool) If `True`, use parallel processing to cache images (significantly faster for large datasets). *Default*: `True`.
-         cache_workers: (int) Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). *Default*: `0`.
          preprocessing: Configuration options related to data preprocessing.
          use_augmentations_train: (bool) True if the data augmentation should be applied to the training data, else False. *Default*: `True`.
          augmentation_config: Configurations related to augmentation. (only if `use_augmentations_train` is `True`)
@@ -219,13 +217,9 @@ class DataConfig:
      cache_img_path: Optional[str] = None
      use_existing_imgs: bool = False
      delete_cache_imgs_after_training: bool = True
-     parallel_caching: bool = True
-     cache_workers: int = 0
      preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
      use_augmentations_train: bool = True
-     augmentation_config: Optional[AugmentationConfig] = field(
-         factory=lambda: AugmentationConfig(geometric=GeometricConfig())
-     )
+     augmentation_config: Optional[AugmentationConfig] = None
      skeletons: Optional[list] = None
 
 
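Because `augmentation_config` no longer defaults to a geometric-only `AugmentationConfig`, configs that relied on the implicit default now have to build it explicitly. A minimal sketch, assuming the remaining `DataConfig` fields keep the defaults shown above:

    from sleap_nn.config.data_config import (
        AugmentationConfig,
        DataConfig,
        GeometricConfig,
    )

    # Pre-0.1.0a1, use_augmentations_train=True implicitly created
    # AugmentationConfig(geometric=GeometricConfig()); now it must be passed.
    data_cfg = DataConfig(
        use_augmentations_train=True,
        augmentation_config=AugmentationConfig(geometric=GeometricConfig()),
    )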
@@ -463,9 +463,9 @@ def get_data_config(
      crop_size: Optional[int] = None,
      min_crop_size: Optional[int] = 100,
      crop_padding: Optional[int] = None,
-     use_augmentations_train: bool = True,
+     use_augmentations_train: bool = False,
      intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
-     geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
+     geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
  ):
      """Train a pose-estimation model with SLEAP-NN framework.
 
@@ -517,7 +517,7 @@ def get_data_config(
              crop size. If `None`, padding is auto-computed based on augmentation settings.
              Only used when `crop_size` is `None`. Default: None.
          use_augmentations_train: True if the data augmentation should be applied to the
-             training data, else False. Default: True.
+             training data, else False. Default: False.
          intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
              or list of strings from the above allowed values. To have custom values, pass
              a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -529,8 +529,7 @@ def get_data_config(
              or list of strings from the above allowed values. To have custom values, pass
              a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
              For eg: {
-                 "rotation_min": -45,
-                 "rotation_max": 45,
+                 "rotation": 45,
                  "affine_p": 1.0
              }
      """
@@ -178,69 +178,19 @@ class ReduceLROnPlateauConfig:
          raise ValueError(message)
 
 
- @define
- class CosineAnnealingWarmupConfig:
-     """Configuration for Cosine Annealing with Linear Warmup scheduler.
-
-     The learning rate increases linearly during warmup, then decreases following
-     a cosine curve to the minimum value.
-
-     Attributes:
-         warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
-         max_epochs: (int) Total number of training epochs. Will be overridden by
-             trainer's max_epochs if not specified. *Default*: `None`.
-         warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
-         eta_min: (float) Minimum learning rate at end of cosine decay. *Default*: `0.0`.
-     """
-
-     warmup_epochs: int = field(default=5, validator=validators.ge(0))
-     max_epochs: Optional[int] = None
-     warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
-     eta_min: float = field(default=0.0, validator=validators.ge(0))
-
-
- @define
- class LinearWarmupLinearDecayConfig:
-     """Configuration for Linear Warmup + Linear Decay scheduler.
-
-     The learning rate increases linearly during warmup, then decreases linearly
-     to the end learning rate.
-
-     Attributes:
-         warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
-         max_epochs: (int) Total number of training epochs. Will be overridden by
-             trainer's max_epochs if not specified. *Default*: `None`.
-         warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
-         end_lr: (float) Learning rate at end of training. *Default*: `0.0`.
-     """
-
-     warmup_epochs: int = field(default=5, validator=validators.ge(0))
-     max_epochs: Optional[int] = None
-     warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
-     end_lr: float = field(default=0.0, validator=validators.ge(0))
-
-
  @define
  class LRSchedulerConfig:
      """Configuration for lr_scheduler.
 
-     Only one scheduler should be configured at a time. If multiple are set,
-     priority order is: cosine_annealing_warmup > linear_warmup_linear_decay >
-     step_lr > reduce_lr_on_plateau.
-
      Attributes:
          step_lr: Configuration for StepLR scheduler.
          reduce_lr_on_plateau: Configuration for ReduceLROnPlateau scheduler.
-         cosine_annealing_warmup: Configuration for Cosine Annealing with Linear Warmup scheduler.
-         linear_warmup_linear_decay: Configuration for Linear Warmup + Linear Decay scheduler.
      """
 
      step_lr: Optional[StepLRConfig] = None
      reduce_lr_on_plateau: Optional[ReduceLROnPlateauConfig] = field(
          factory=ReduceLROnPlateauConfig
      )
-     cosine_annealing_warmup: Optional[CosineAnnealingWarmupConfig] = None
-     linear_warmup_linear_decay: Optional[LinearWarmupLinearDecayConfig] = None
 
 
  @define
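With the warmup schedulers removed, `LRSchedulerConfig` is back to a two-option switch. A sketch of both remaining options, using the defaults shown above:

    from sleap_nn.config.trainer_config import LRSchedulerConfig, StepLRConfig

    # Default: ReduceLROnPlateau is constructed via the field factory.
    plateau = LRSchedulerConfig()

    # Opting into StepLR instead; cosine_annealing_warmup and
    # linear_warmup_linear_decay are no longer valid keyword arguments.
    step = LRSchedulerConfig(step_lr=StepLRConfig(), reduce_lr_on_plateau=None)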
@@ -258,26 +208,6 @@ class EarlyStoppingConfig:
      stop_training_on_plateau: bool = True
 
 
- @define
- class EvalConfig:
-     """Configuration for epoch-end evaluation.
-
-     Attributes:
-         enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
-         frequency: (int) Evaluate every N epochs. *Default*: `1`.
-         oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
-         oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
-         match_threshold: (float) Maximum distance in pixels for centroid matching.
-             Only used for centroid model evaluation. *Default*: `50.0`.
-     """
-
-     enabled: bool = False
-     frequency: int = field(default=1, validator=validators.ge(1))
-     oks_stddev: float = field(default=0.025, validator=validators.gt(0))
-     oks_scale: Optional[float] = None
-     match_threshold: float = field(default=50.0, validator=validators.gt(0))
-
-
  @define
  class HardKeypointMiningConfig:
      """Configuration for online hard keypoint mining.
@@ -380,7 +310,6 @@ class TrainerConfig:
          factory=HardKeypointMiningConfig
      )
      zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
-     eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config
 
      @staticmethod
      def validate_optimizer_name(value):
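Code that toggled epoch-end evaluation through `TrainerConfig.eval` has no direct replacement in this release. A minimal sketch of the change, assuming the other `TrainerConfig` fields keep their defaults:

    from sleap_nn.config.trainer_config import TrainerConfig

    cfg = TrainerConfig()
    # 0.1.0: cfg.eval.enabled = True enabled epoch-end metrics via EvalConfig.
    # 0.1.0a1: the eval field is gone; only the ZMQ config consumed by the
    # SLEAP GUI remains on the trainer config.
    assert not hasattr(cfg, "eval")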
@@ -1,15 +1,12 @@
- """This module implements data pipeline blocks for augmentation operations.
+ """This module implements data pipeline blocks for augmentation operations."""
 
- Uses Skia (skia-python) for ~1.5x faster augmentation compared to Kornia.
- """
-
- from typing import Optional, Tuple
+ from typing import Any, Dict, Optional, Tuple, Union
+ import kornia as K
  import torch
-
- from sleap_nn.data.skia_augmentation import (
-     apply_intensity_augmentation_skia,
-     apply_geometric_augmentation_skia,
- )
+ from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
+ from kornia.augmentation.container import AugmentationSequential
+ from kornia.augmentation.utils.param_validation import _range_bound
+ from kornia.core import Tensor
 
 
  def apply_intensity_augmentation(
@@ -27,8 +24,8 @@ def apply_intensity_augmentation(
      brightness_min: Optional[float] = 1.0,
      brightness_max: Optional[float] = 1.0,
      brightness_p: float = 0.0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """Apply intensity augmentation on image and instances.
+ ) -> Tuple[torch.Tensor]:
+     """Apply kornia intensity augmentation on image and instances.
 
      Args:
          image: Input image. Shape: (n_samples, C, H, W)
@@ -49,23 +46,66 @@ def apply_intensity_augmentation(
      Returns:
          Returns tuple: (image, instances) with augmentation applied.
      """
-     return apply_intensity_augmentation_skia(
-         image=image,
-         instances=instances,
-         uniform_noise_min=uniform_noise_min,
-         uniform_noise_max=uniform_noise_max,
-         uniform_noise_p=uniform_noise_p,
-         gaussian_noise_mean=gaussian_noise_mean,
-         gaussian_noise_std=gaussian_noise_std,
-         gaussian_noise_p=gaussian_noise_p,
-         contrast_min=contrast_min,
-         contrast_max=contrast_max,
-         contrast_p=contrast_p,
-         brightness_min=brightness_min,
-         brightness_max=brightness_max,
-         brightness_p=brightness_p,
+     aug_stack = []
+     if uniform_noise_p > 0:
+         aug_stack.append(
+             RandomUniformNoise(
+                 noise=(uniform_noise_min, uniform_noise_max),
+                 p=uniform_noise_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+     if gaussian_noise_p > 0:
+         aug_stack.append(
+             K.augmentation.RandomGaussianNoise(
+                 mean=gaussian_noise_mean,
+                 std=gaussian_noise_std,
+                 p=gaussian_noise_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+     if contrast_p > 0:
+         aug_stack.append(
+             K.augmentation.RandomContrast(
+                 contrast=(contrast_min, contrast_max),
+                 p=contrast_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+     if brightness_p > 0:
+         aug_stack.append(
+             K.augmentation.RandomBrightness(
+                 brightness=(brightness_min, brightness_max),
+                 p=brightness_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+
+     augmenter = AugmentationSequential(
+         *aug_stack,
+         data_keys=["input", "keypoints"],
+         keepdim=True,
+         same_on_batch=True,
      )
 
+     inst_shape = instances.shape
+     # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+     # or
+     # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+     instances = instances.reshape(inst_shape[0], -1, 2)
+     # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
+
+     aug_image, aug_instances = augmenter(image, instances)
+
+     # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+     # or
+     # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+     return aug_image, aug_instances.reshape(*inst_shape)
+
 
  def apply_geometric_augmentation(
      image: torch.Tensor,
@@ -88,8 +128,8 @@ def apply_geometric_augmentation(
      mixup_lambda_min: Optional[float] = 0.01,
      mixup_lambda_max: Optional[float] = 0.05,
      mixup_p: float = 0.0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """Apply geometric augmentation on image and instances.
+ ) -> Tuple[torch.Tensor]:
+     """Apply kornia geometric augmentation on image and instances.
 
      Args:
          image: Input image. Shape: (n_samples, C, H, W)
@@ -120,25 +160,176 @@ def apply_geometric_augmentation(
      Returns:
          Returns tuple: (image, instances) with augmentation applied.
      """
-     return apply_geometric_augmentation_skia(
-         image=image,
-         instances=instances,
-         rotation_min=rotation_min,
-         rotation_max=rotation_max,
-         rotation_p=rotation_p,
-         scale_min=scale_min,
-         scale_max=scale_max,
-         scale_p=scale_p,
-         translate_width=translate_width,
-         translate_height=translate_height,
-         translate_p=translate_p,
-         affine_p=affine_p,
-         erase_scale_min=erase_scale_min,
-         erase_scale_max=erase_scale_max,
-         erase_ratio_min=erase_ratio_min,
-         erase_ratio_max=erase_ratio_max,
-         erase_p=erase_p,
-         mixup_lambda_min=mixup_lambda_min,
-         mixup_lambda_max=mixup_lambda_max,
-         mixup_p=mixup_p,
+     aug_stack = []
+
+     # Check if any individual probability is set
+     use_independent = (
+         rotation_p is not None or scale_p is not None or translate_p is not None
      )
+
+     if use_independent:
+         # New behavior: Apply augmentations independently with separate probabilities
+         if rotation_p is not None and rotation_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomRotation(
+                     degrees=(rotation_min, rotation_max),
+                     p=rotation_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+
+         if scale_p is not None and scale_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomAffine(
+                     degrees=0,  # No rotation
+                     translate=None,  # No translation
+                     scale=(scale_min, scale_max),
+                     p=scale_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+
+         if translate_p is not None and translate_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomAffine(
+                     degrees=0,  # No rotation
+                     translate=(translate_width, translate_height),
+                     scale=None,  # No scaling
+                     p=translate_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+     elif affine_p > 0:
+         # Legacy behavior: Bundled affine transformation
+         aug_stack.append(
+             K.augmentation.RandomAffine(
+                 degrees=(rotation_min, rotation_max),
+                 translate=(translate_width, translate_height),
+                 scale=(scale_min, scale_max),
+                 p=affine_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+
+     if erase_p > 0:
+         aug_stack.append(
+             K.augmentation.RandomErasing(
+                 scale=(erase_scale_min, erase_scale_max),
+                 ratio=(erase_ratio_min, erase_ratio_max),
+                 p=erase_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+     if mixup_p > 0:
+         aug_stack.append(
+             K.augmentation.RandomMixUpV2(
+                 lambda_val=(mixup_lambda_min, mixup_lambda_max),
+                 p=mixup_p,
+                 keepdim=True,
+                 same_on_batch=True,
+             )
+         )
+
+     augmenter = AugmentationSequential(
+         *aug_stack,
+         data_keys=["input", "keypoints"],
+         keepdim=True,
+         same_on_batch=True,
+     )
+
+     inst_shape = instances.shape
+     # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+     # or
+     # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+     instances = instances.reshape(inst_shape[0], -1, 2)
+     # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
+
+     aug_image, aug_instances = augmenter(image, instances)
+
+     # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
+     # or
+     # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
+     return aug_image, aug_instances.reshape(*inst_shape)
+
+
+ class RandomUniformNoise(IntensityAugmentationBase2D):
+     """Data transformer for applying random uniform noise to input images.
+
+     This is a custom Kornia augmentation inheriting from `IntensityAugmentationBase2D`.
+     Uniform noise within (min_val, max_val) is applied to the entire input image.
+
+     Note: Inverse transform is not implemented and re-applying the same transformation
+     in the example below does not work when included in an AugmentationSequential class.
+
+     Args:
+         noise: 2-tuple (min_val, max_val); 0.0 <= min_val <= max_val <= 1.0.
+         p: probability for applying an augmentation. This param controls the
+             augmentation probabilities element-wise for a batch.
+         p_batch: probability for applying an augmentation to a batch. This param
+             controls the augmentation probabilities batch-wise.
+         same_on_batch: apply the same transformation across the batch.
+         keepdim: whether to keep the output shape the same as input `True` or broadcast it
+             to the batch form `False`.
+
+     Examples:
+         >>> rng = torch.manual_seed(0)
+         >>> img = torch.rand(1, 1, 2, 2)
+         >>> RandomUniformNoise(noise=(0.0, 0.1), p=1.)(img)
+         tensor([[[[0.9607, 0.5865],
+                   [0.2705, 0.5920]]]])
+
+     To apply the exact augmentation again, you may take advantage of the previous parameter state:
+         >>> input = torch.rand(1, 3, 32, 32)
+         >>> aug = RandomUniformNoise(noise=(0.0, 0.1), p=1.)
+         >>> (aug(input) == aug(input, params=aug._params)).all()
+         tensor(True)
+
+     Ref: `kornia.augmentation._2d.intensity.gaussian_noise
+     <https://kornia.readthedocs.io/en/latest/_modules/kornia/augmentation/_2d/intensity/gaussian_noise.html#RandomGaussianNoise>`_.
+     """
+
+     def __init__(
+         self,
+         noise: Tuple[float, float],
+         p: float = 0.5,
+         p_batch: float = 1.0,
+         clip_output: bool = True,
+         same_on_batch: bool = False,
+         keepdim: bool = False,
+     ) -> None:
+         """Initialize the class."""
+         super().__init__(
+             p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
+         )
+         self.flags = {
+             "uniform_noise": _range_bound(noise, "uniform_noise", bounds=(0.0, 1.0))
+         }
+         self.clip_output = clip_output
+
+     def apply_transform(
+         self,
+         input: Tensor,
+         params: Dict[str, Tensor],
+         flags: Dict[str, Any],
+         transform: Optional[Tensor] = None,
+     ) -> Tensor:
+         """Compute the uniform noise, add, and clamp output."""
+         if "uniform_noise" in params:
+             uniform_noise = params["uniform_noise"]
+         else:
+             uniform_noise = (
+                 torch.FloatTensor(input.shape)
+                 .uniform_(flags["uniform_noise"][0], flags["uniform_noise"][1])
+                 .to(input.device)
+             )
+             self._params["uniform_noise"] = uniform_noise
+         if self.clip_output:
+             return torch.clamp(
+                 input + uniform_noise, 0.0, 1.0
+             )  # RandomGaussianNoise doesn't clamp.
+         return input + uniform_noise
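A quick usage sketch of the two rewritten entry points above, assuming the parameters not shown here keep the defaults from the config classes earlier in this diff (shapes follow the docstrings; values are illustrative):

    import torch

    from sleap_nn.data.augmentation import (
        apply_geometric_augmentation,
        apply_intensity_augmentation,
    )

    image = torch.rand(1, 1, 256, 256)        # (n_samples, C, H, W)
    instances = torch.rand(1, 2, 5, 2) * 256  # (n_samples, n_instances, n_nodes, 2)

    # Intensity path: only transforms with p > 0 are appended to the stack.
    img1, pts1 = apply_intensity_augmentation(
        image, instances,
        uniform_noise_min=0.0, uniform_noise_max=0.04, uniform_noise_p=1.0,
    )

    # Geometric, bundled (legacy) path: no per-transform probabilities set,
    # so rotation/scale/translate ride on one RandomAffine gated by affine_p.
    img2, pts2 = apply_geometric_augmentation(image, instances, affine_p=1.0)

    # Geometric, independent path: any non-None per-transform probability
    # switches modes; here only rotation is applied, with p=0.9.
    img3, pts3 = apply_geometric_augmentation(image, instances, rotation_p=0.9)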