sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
sleap_nn/__init__.py CHANGED
@@ -48,4 +48,9 @@ logger.add(
      format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
  )
 
- __version__ = "0.0.5"
+ __version__ = "0.1.0a0"
+
+ # Public API
+ from sleap_nn.evaluation import load_metrics
+
+ __all__ = ["load_metrics", "__version__"]
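The package root now re-exports `load_metrics`. A minimal usage sketch, assuming `load_metrics` accepts the path to a trained model directory (the path below is a placeholder):

```python
import sleap_nn

print(sleap_nn.__version__)  # "0.1.0a0"

# Load metrics saved during evaluation; "models/my_model" is a hypothetical path.
metrics = sleap_nn.load_metrics("models/my_model")
```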
sleap_nn/cli.py CHANGED
@@ -4,14 +4,24 @@ import click
  from loguru import logger
  from pathlib import Path
  from omegaconf import OmegaConf, DictConfig
+ import sleap_io as sio
  from sleap_nn.predict import run_inference, frame_list
  from sleap_nn.evaluation import run_evaluation
  from sleap_nn.train import run_training
+ from sleap_nn import __version__
  import hydra
  import sys
  from click import Command
 
 
+ def print_version(ctx, param, value):
+     """Print version and exit."""
+     if not value or ctx.resilient_parsing:
+         return
+     click.echo(f"sleap-nn {__version__}")
+     ctx.exit()
+
+
  class TrainCommand(Command):
      """Custom command class that overrides help behavior for train command."""
 
@@ -20,7 +30,26 @@ class TrainCommand(Command):
          show_training_help()
 
 
+ def parse_path_map(ctx, param, value):
+     """Parse (old, new) path pairs into a dictionary for path mapping options."""
+     if not value:
+         return None
+     result = {}
+     for old_path, new_path in value:
+         result[old_path] = Path(new_path).as_posix()
+     return result
+
+
  @click.group()
+ @click.option(
+     "--version",
+     "-v",
+     is_flag=True,
+     callback=print_version,
+     expose_value=False,
+     is_eager=True,
+     help="Show version and exit.",
+ )
  def cli():
      """SLEAP-NN: Neural network backend for training and inference for animal pose estimation.
 
@@ -29,6 +58,7 @@ def cli():
      train - Run training workflow
      track - Run inference/ tracking workflow
      eval - Run evaluation workflow
+     system - Display system information and GPU status
      """
      pass
 
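Because `--version` is declared with `is_eager=True` and `expose_value=False`, click invokes `print_version` before resolving any subcommand and exits immediately. A quick sketch exercising it with click's test runner:

```python
from click.testing import CliRunner

from sleap_nn.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["--version"])
assert result.exit_code == 0
assert result.output.startswith("sleap-nn ")
```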
@@ -67,8 +97,38 @@ For a detailed list of all available config options, please refer to https://nn.
  @click.option(
      "--config-dir", "-d", type=str, default=".", help="Configuration directory path"
  )
+ @click.option(
+     "--video-paths",
+     "-v",
+     multiple=True,
+     help="Video paths to replace existing paths in the labels file. "
+     "Order must match the order of videos in the labels file. "
+     "Can be specified multiple times. "
+     "Example: --video-paths /path/to/vid1.mp4 --video-paths /path/to/vid2.mp4",
+ )
+ @click.option(
+     "--video-path-map",
+     nargs=2,
+     multiple=True,
+     callback=parse_path_map,
+     metavar="OLD NEW",
+     help="Map old video path to new path. Takes two arguments: old path and new path. "
+     "Can be specified multiple times. "
+     'Example: --video-path-map "/old/vid.mp4" "/new/vid.mp4"',
+ )
+ @click.option(
+     "--prefix-map",
+     nargs=2,
+     multiple=True,
+     callback=parse_path_map,
+     metavar="OLD NEW",
+     help="Map old path prefix to new prefix. Takes two arguments: old prefix and new prefix. "
+     "Updates ALL videos that share the same prefix. Useful when moving data between machines. "
+     "Can be specified multiple times. "
+     'Example: --prefix-map "/old/server/path" "/new/local/path"',
+ )
  @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
- def train(config_name, config_dir, overrides):
+ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
      """Run training workflow with Hydra config overrides.
 
      Examples:
@@ -97,7 +157,62 @@ def train(config_name, config_dir, overrides):
 
      logger.info("Input config:")
      logger.info("\n" + OmegaConf.to_yaml(cfg))
-     run_training(cfg)
+
+     # Handle video path replacement options
+     train_labels = None
+     val_labels = None
+
+     # Check that only one replacement option is used
+     # video_paths is a tuple (empty if not used), others are None or dict
+     has_video_paths = len(video_paths) > 0
+     has_video_path_map = video_path_map is not None
+     has_prefix_map = prefix_map is not None
+     options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
+
+     if options_used > 1:
+         raise click.UsageError(
+             "Cannot use multiple path replacement options. "
+             "Choose one of: --video-paths, --video-path-map, or --prefix-map."
+         )
+
+     if options_used == 1:
+         # Load train labels
+         train_labels = [
+             sio.load_slp(path) for path in cfg.data_config.train_labels_path
+         ]
+
+         # Load val labels if they exist
+         if (
+             cfg.data_config.val_labels_path is not None
+             and len(cfg.data_config.val_labels_path) > 0
+         ):
+             val_labels = [
+                 sio.load_slp(path) for path in cfg.data_config.val_labels_path
+             ]
+
+         # Build replacement arguments based on option used
+         if has_video_paths:
+             # List of paths (order must match videos in labels file)
+             replace_kwargs = {
+                 "new_filenames": [Path(p).as_posix() for p in video_paths]
+             }
+         elif has_video_path_map:
+             # Dictionary mapping old filenames to new filenames
+             replace_kwargs = {"filename_map": video_path_map}
+         else:  # has_prefix_map
+             # Dictionary mapping old prefixes to new prefixes
+             replace_kwargs = {"prefix_map": prefix_map}
+
+         # Apply replacement to train labels
+         for labels in train_labels:
+             labels.replace_filenames(**replace_kwargs)
+
+         # Apply replacement to val labels if they exist
+         if val_labels:
+             for labels in val_labels:
+                 labels.replace_filenames(**replace_kwargs)
+
+     run_training(config=cfg, train_labels=train_labels, val_labels=val_labels)
 
 
  @cli.command()
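All three flags funnel into `sleap_io.Labels.replace_filenames`, which expects exactly one of `new_filenames`, `filename_map`, or `prefix_map`. A sketch of what `--prefix-map` does under the hood (file paths are placeholders):

```python
import sleap_io as sio

labels = sio.load_slp("labels.slp")  # placeholder path

# Equivalent of: sleap-nn train ... --prefix-map "/old/server/path" "/new/local/path"
labels.replace_filenames(prefix_map={"/old/server/path": "/new/local/path"})
```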
@@ -209,6 +324,18 @@ def train(config_name, config_dir, overrides):
      default=False,
      help="Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling.",
  )
+ @click.option(
+     "--exclude_user_labeled",
+     is_flag=True,
+     default=False,
+     help="Skip frames that have user-labeled instances. Useful when predicting on entire video but skipping already-labeled frames.",
+ )
+ @click.option(
+     "--only_predicted_frames",
+     is_flag=True,
+     default=False,
+     help="Only run inference on frames that already have predictions. Requires .slp input file. Useful for re-predicting with a different model.",
+ )
  @click.option(
      "--no_empty_frames",
      is_flag=True,
@@ -282,7 +409,7 @@ def train(config_name, config_dir, overrides):
      "--crop_size",
      type=int,
      default=None,
-     help="Crop size. If not provided, the crop size from training_config.yaml is used.",
+     help="Crop size. If not provided, the crop size from training_config.yaml is used. If `input_scale` is provided, then the cropped image will be resized according to `input_scale`.",
  )
  @click.option(
      "--peak_threshold",
@@ -474,5 +601,17 @@ def eval(**kwargs):
      run_evaluation(**kwargs)
 
 
+ @cli.command()
+ def system():
+     """Display system information and GPU status.
+
+     Shows Python version, platform, PyTorch version, CUDA availability,
+     driver version with compatibility check, GPU details, and package versions.
+     """
+     from sleap_nn.system_info import print_system_info
+
+     print_system_info()
+
+
  if __name__ == "__main__":
      cli()
@@ -6,7 +6,7 @@ the parameters required to initialize the data config.
 
  from attrs import define, field, validators
  from omegaconf import MISSING
- from typing import Optional, Tuple, Any, List
+ from typing import Optional, Tuple, Any, List, Union
  from loguru import logger
  import sleap_io as sio
  import yaml
@@ -20,11 +20,15 @@ class PreprocessingConfig:
      Attributes:
          ensure_rgb: (bool) True if the input image should have 3 channels (RGB image). If input has only one channel when this is set to `True`, then the images from single-channel is replicated along the channel axis. If the image has three channels and this is set to False, then we retain the three channels. *Default*: `False`.
          ensure_grayscale: (bool) True if the input image should only have a single channel. If input has three channels (RGB) and this is set to True, then we convert the image to grayscale (single-channel) image. If the source image has only one channel and this is set to False, then we retain the single channel input. *Default*: `False`.
-         max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
-         max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
+         max_height: (int) Maximum height the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
+         max_width: (int) Maximum width the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
          scale: (float) Factor to resize the image dimensions by, specified as a float. *Default*: `1.0`.
-         crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file. *Default*: `None`.
+         crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
+             If `scale` is provided, then the cropped image will be resized according to `scale`. *Default*: `None`.
          min_crop_size: (int) Minimum crop size to be used if `crop_size` is `None`. *Default*: `100`.
+         crop_padding: (int) Padding in pixels to add around the instance bounding box when computing crop size.
+             If `None`, padding is auto-computed based on augmentation settings (rotation/scale).
+             Only used when `crop_size` is `None`. *Default*: `None`.
      """
 
      ensure_rgb: bool = False
@@ -36,6 +40,7 @@ class PreprocessingConfig:
      )
      crop_size: Optional[int] = None
      min_crop_size: Optional[int] = 100  # to help app work in case of error
+     crop_padding: Optional[int] = None
 
      def validate_scale(self):
          """Scale Validation.
@@ -104,11 +109,14 @@ class GeometricConfig:
      Attributes:
          rotation_min: (float) Minimum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `-15.0`.
          rotation_max: (float) Maximum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `15.0`.
+         rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
          scale_min: (float) Minimum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `0.9`.
          scale_max: (float) Maximum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `1.1`.
+         scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
          translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. *Default*: `0.0`.
          translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. *Default*: `0.0`.
-         affine_p: (float) Probability of applying random affine transformations. *Default*: `0.0`.
+         translate_p: (float, optional) Probability of applying random translation independently. If set, translation is applied separately from rotation/scale. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
+         affine_p: (float) Probability of applying random affine transformations (rotation, scale, translate bundled together). Used for backwards compatibility when individual `*_p` params are not set. *Default*: `0.0`.
          erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. *Default*: `0.0001`.
          erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. *Default*: `0.01`.
          erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. *Default*: `1.0`.
@@ -121,10 +129,13 @@
 
      rotation_min: float = field(default=-15.0, validator=validators.ge(-180))
      rotation_max: float = field(default=15.0, validator=validators.le(180))
+     rotation_p: Optional[float] = field(default=None)
      scale_min: float = field(default=0.9, validator=validators.ge(0))
      scale_max: float = field(default=1.1, validator=validators.ge(0))
+     scale_p: Optional[float] = field(default=None)
      translate_width: float = 0.0
      translate_height: float = 0.0
+     translate_p: Optional[float] = field(default=None)
      affine_p: float = field(default=0.0, validator=validate_proportion)
      erase_scale_min: float = 0.0001
      erase_scale_max: float = 0.01
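A sketch contrasting the two modes (same assumed module path as above); with per-transform probabilities set, each transform gets its own independent coin flip, while leaving them `None` keeps the legacy bundled affine:

```python
from sleap_nn.config.data_config import GeometricConfig  # assumed module path

# Legacy bundled behavior: rotation, scale, and translation applied all-or-nothing.
legacy = GeometricConfig(affine_p=0.5)

# New independent behavior: each transform is sampled separately.
independent = GeometricConfig(rotation_p=0.5, scale_p=0.25, translate_p=0.1)
```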
@@ -149,6 +160,28 @@ class AugmentationConfig:
      geometric: Optional[GeometricConfig] = None
 
 
+ def validate_test_file_path(instance, attribute, value):
+     """Validate test_file_path to accept str or List[str].
+
+     Args:
+         instance: The instance being validated.
+         attribute: The attribute being validated.
+         value: The value to validate.
+
+     Raises:
+         ValueError: If value is not None, str, or list of strings.
+     """
+     if value is None:
+         return
+     if isinstance(value, str):
+         return
+     if isinstance(value, (list, tuple)) and all(isinstance(p, str) for p in value):
+         return
+     message = f"{attribute.name} must be a string or list of strings, got {type(value).__name__}"
+     logger.error(message)
+     raise ValueError(message)
+
+
  @define
  class DataConfig:
      """Data configuration.
@@ -157,7 +190,8 @@ class DataConfig:
      train_labels_path: (List[str]) List of paths to training data (`.slp` file(s)). *Default*: `None`.
      val_labels_path: (List[str]) List of paths to validation data (`.slp` file(s)). *Default*: `None`.
      validation_fraction: (float) Float between 0 and 1 specifying the fraction of the training set to sample for generating the validation set. The remaining labeled frames will be left in the training set. If the `validation_labels` are already specified, this has no effect. *Default*: `0.1`.
-     test_file_path: (str) Path to test dataset (`.slp` file or `.mp4` file). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
+     use_same_data_for_val: (bool) If `True`, use the same data for both training and validation (train = val). Useful for intentional overfitting on small datasets. When enabled, `val_labels_path` and `validation_fraction` are ignored. *Default*: `False`.
+     test_file_path: (str or List[str]) Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
      provider: (str) Provider class to read the input sleap files. Only "LabelsReader" is currently supported for the training pipeline. *Default*: `"LabelsReader"`.
      user_instances_only: (bool) `True` if only user labeled instances should be used for training. If `False`, both user labeled and predicted instances would be used. *Default*: `True`.
      data_pipeline_fw: (str) Framework to create the data loaders. One of [`torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`]. *Default*: `"torch_dataset"`. (Note: When using `torch_dataset`, `num_workers` in `trainer_config` should be set to 0 as multiprocessing doesn't work with pickling video backends.)
@@ -173,7 +207,10 @@
      train_labels_path: Optional[List[str]] = None
      val_labels_path: Optional[List[str]] = None  # TODO : revisit MISSING!
      validation_fraction: float = 0.1
-     test_file_path: Optional[str] = None
+     use_same_data_for_val: bool = False
+     test_file_path: Optional[Any] = field(
+         default=None, validator=validate_test_file_path
+     )
      provider: str = "LabelsReader"
      user_instances_only: bool = True
      data_pipeline_fw: str = "torch_dataset"
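A sketch of the loosened `test_file_path` typing and the new overfitting switch; the validator accepts `None`, a single path, or a list of paths, and raises `ValueError` otherwise (paths are placeholders):

```python
from sleap_nn.config.data_config import DataConfig  # assumed module path

# Single test file (old behavior) still works.
cfg = DataConfig(train_labels_path=["train.slp"], test_file_path="test.slp")

# Multiple test files are now accepted.
cfg = DataConfig(
    train_labels_path=["train.slp"],
    test_file_path=["test_a.slp", "test_b.mp4"],
)

# Intentional overfitting: validation set is the training set.
cfg = DataConfig(train_labels_path=["train.slp"], use_same_data_for_val=True)
```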
@@ -131,27 +131,18 @@ def get_aug_config(
 
      for g in geometric_aug:
          if g == "rotation":
-             aug_config.geometric.affine_p = 1.0
-             aug_config.geometric.scale_min = 1.0
-             aug_config.geometric.scale_max = 1.0
-             aug_config.geometric.translate_height = 0
-             aug_config.geometric.translate_width = 0
+             # Use new independent rotation probability
+             aug_config.geometric.rotation_p = 1.0
          elif g == "scale":
+             # Use new independent scale probability
              aug_config.geometric.scale_min = 0.9
              aug_config.geometric.scale_max = 1.1
-             aug_config.geometric.affine_p = 1.0
-             aug_config.geometric.rotation_min = 0
-             aug_config.geometric.rotation_max = 0
-             aug_config.geometric.translate_height = 0
-             aug_config.geometric.translate_width = 0
+             aug_config.geometric.scale_p = 1.0
          elif g == "translate":
+             # Use new independent translate probability
              aug_config.geometric.translate_height = 0.2
              aug_config.geometric.translate_width = 0.2
-             aug_config.geometric.affine_p = 1.0
-             aug_config.geometric.rotation_min = 0
-             aug_config.geometric.rotation_max = 0
-             aug_config.geometric.scale_min = 1.0
-             aug_config.geometric.scale_max = 1.0
+             aug_config.geometric.translate_p = 1.0
          elif g == "erase_scale":
              aug_config.geometric.erase_p = 1.0
          elif g == "mixup":
@@ -456,7 +447,8 @@
      train_labels_path: Optional[List[str]] = None,
      val_labels_path: Optional[List[str]] = None,
      validation_fraction: float = 0.1,
-     test_file_path: Optional[str] = None,
+     use_same_data_for_val: bool = False,
+     test_file_path: Optional[Union[str, List[str]]] = None,
      provider: str = "LabelsReader",
      user_instances_only: bool = True,
      data_pipeline_fw: str = "torch_dataset",
@@ -470,6 +462,7 @@
      max_width: Optional[int] = None,
      crop_size: Optional[int] = None,
      min_crop_size: Optional[int] = 100,
+     crop_padding: Optional[int] = None,
      use_augmentations_train: bool = False,
      intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
      geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
@@ -486,7 +479,11 @@
          training set to sample for generating the validation set. The remaining
          labeled frames will be left in the training set. If the `validation_labels`
          are already specified, this has no effect. Default: 0.1.
-     test_file_path: Path to test dataset (`.slp` file or `.mp4` file).
+     use_same_data_for_val: If `True`, use the same data for both training and
+         validation (train = val). Useful for intentional overfitting on small
+         datasets. When enabled, `val_labels_path` and `validation_fraction` are
+         ignored. Default: False.
+     test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
          Note: This is used to get evaluation on test set after training is completed.
      provider: Provider class to read the input sleap files. Only "LabelsReader"
          supported for the training pipeline. Default: "LabelsReader".
@@ -508,14 +505,17 @@
          is set to True, then we convert the image to grayscale (single-channel)
          image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
      scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
-     max_height: Maximum height the image should be padded to. If not provided, the
+     max_height: Maximum height the original image should be resized and padded to. If not provided, the
          original image size will be retained. Default: None.
-     max_width: Maximum width the image should be padded to. If not provided, the
+     max_width: Maximum width the original image should be resized and padded to. If not provided, the
          original image size will be retained. Default: None.
      crop_size: Crop size of each instance for centered-instance model.
          If `None`, this would be automatically computed based on the largest instance
-         in the `sio.Labels` file. Default: None.
+         in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
      min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
+     crop_padding: Padding in pixels to add around instance bounding box when computing
+         crop size. If `None`, padding is auto-computed based on augmentation settings.
+         Only used when `crop_size` is `None`. Default: None.
      use_augmentations_train: True if the data augmentation should be applied to the
          training data, else False. Default: False.
      intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
@@ -541,6 +541,7 @@
          scale=scale,
          crop_size=crop_size,
          min_crop_size=min_crop_size,
+         crop_padding=crop_padding,
      )
      augmentation_config = None
      if use_augmentations_train:
@@ -553,6 +554,7 @@
          train_labels_path=train_labels_path,
          val_labels_path=val_labels_path,
          validation_fraction=validation_fraction,
+         use_same_data_for_val=use_same_data_for_val,
          test_file_path=test_file_path,
          provider=provider,
          user_instances_only=user_instances_only,
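A sketch of the updated helper with the new keyword arguments threaded through (assumed module path; paths illustrative):

```python
from sleap_nn.config.data_config import get_data_config  # assumed module path

data_cfg = get_data_config(
    train_labels_path=["train.slp"],
    use_same_data_for_val=True,                   # val = train
    test_file_path=["test_a.slp", "test_b.slp"],  # str or list now accepted
    crop_padding=16,                              # forwarded to PreprocessingConfig
)
```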
@@ -84,6 +84,12 @@ class WandBConfig:
      prv_runid: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`.
      group: (str) Group for wandb logging. *Default*: `None`.
      current_run_id: (str) Run ID for the current model training. (stored once the training starts). *Default*: `None`.
+     viz_enabled: (bool) If True, log pre-rendered matplotlib images to wandb. *Default*: `True`.
+     viz_boxes: (bool) If True, log interactive keypoint boxes. *Default*: `False`.
+     viz_masks: (bool) If True, log confidence map overlay masks. *Default*: `False`.
+     viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
+     viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
+     log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
      """
 
      entity: Optional[str] = None
@@ -95,6 +101,12 @@ class WandBConfig:
      prv_runid: Optional[str] = None
      group: Optional[str] = None
      current_run_id: Optional[str] = None
+     viz_enabled: bool = True
+     viz_boxes: bool = False
+     viz_masks: bool = False
+     viz_box_size: float = 5.0
+     viz_confmap_threshold: float = 0.1
+     log_viz_table: bool = False
 
 
  @define
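A sketch of the new visualization toggles; pre-rendered matplotlib images stay on by default, while boxes, masks, and the legacy table are opt-in (assumed module path):

```python
from sleap_nn.config.trainer_config import WandBConfig  # assumed module path

wandb_cfg = WandBConfig(
    viz_enabled=True,           # pre-rendered matplotlib images (default)
    viz_boxes=True,             # interactive keypoint boxes
    viz_masks=True,             # confidence map overlay masks
    viz_confmap_threshold=0.2,  # mask cutoff
    log_viz_table=False,        # skip the backwards-compat wandb.Table
)
```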
@@ -112,10 +112,13 @@ def apply_geometric_augmentation(
      instances: torch.Tensor,
      rotation_min: Optional[float] = -15.0,
      rotation_max: Optional[float] = 15.0,
+     rotation_p: Optional[float] = None,
      scale_min: Optional[float] = 0.9,
      scale_max: Optional[float] = 1.1,
+     scale_p: Optional[float] = None,
      translate_width: Optional[float] = 0.02,
      translate_height: Optional[float] = 0.02,
+     translate_p: Optional[float] = None,
      affine_p: float = 0.0,
      erase_scale_min: Optional[float] = 0.0001,
      erase_scale_max: Optional[float] = 0.01,
@@ -133,11 +136,18 @@
      instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
      rotation_min: Minimum rotation angle in degrees. Default: -15.0.
      rotation_max: Maximum rotation angle in degrees. Default: 15.0.
+     rotation_p: Probability of applying random rotation independently. If None,
+         falls back to affine_p for bundled behavior. Default: None.
      scale_min: Minimum scaling factor for isotropic scaling. Default: 0.9.
      scale_max: Maximum scaling factor for isotropic scaling. Default: 1.1.
+     scale_p: Probability of applying random scaling independently. If None,
+         falls back to affine_p for bundled behavior. Default: None.
      translate_width: Maximum absolute fraction for horizontal translation. Default: 0.02.
      translate_height: Maximum absolute fraction for vertical translation. Default: 0.02.
-     affine_p: Probability of applying random affine transformations. Default: 0.0.
+     translate_p: Probability of applying random translation independently. If None,
+         falls back to affine_p for bundled behavior. Default: None.
+     affine_p: Probability of applying random affine transformations (rotation, scale,
+         translate bundled). Used when individual *_p params are None. Default: 0.0.
      erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
      erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
      erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
@@ -151,7 +161,49 @@
          Returns tuple: (image, instances) with augmentation applied.
      """
      aug_stack = []
-     if affine_p > 0:
+
+     # Check if any individual probability is set
+     use_independent = (
+         rotation_p is not None or scale_p is not None or translate_p is not None
+     )
+
+     if use_independent:
+         # New behavior: Apply augmentations independently with separate probabilities
+         if rotation_p is not None and rotation_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomRotation(
+                     degrees=(rotation_min, rotation_max),
+                     p=rotation_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+
+         if scale_p is not None and scale_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomAffine(
+                     degrees=0,  # No rotation
+                     translate=None,  # No translation
+                     scale=(scale_min, scale_max),
+                     p=scale_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+
+         if translate_p is not None and translate_p > 0:
+             aug_stack.append(
+                 K.augmentation.RandomAffine(
+                     degrees=0,  # No rotation
+                     translate=(translate_width, translate_height),
+                     scale=None,  # No scaling
+                     p=translate_p,
+                     keepdim=True,
+                     same_on_batch=True,
+                 )
+             )
+     elif affine_p > 0:
+         # Legacy behavior: Bundled affine transformation
          aug_stack.append(
              K.augmentation.RandomAffine(
                  degrees=(rotation_min, rotation_max),
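A usage sketch of the independent mode (assumed module path): rotation is always applied and scaling half the time, each gated separately instead of by one bundled `RandomAffine`:

```python
import torch

from sleap_nn.data.augmentation import apply_geometric_augmentation  # assumed module path

image = torch.rand(1, 1, 256, 256)        # (n_samples, C, H, W)
instances = torch.rand(1, 2, 5, 2) * 256  # (n_samples, n_instances, n_nodes, 2)

aug_image, aug_instances = apply_geometric_augmentation(
    image,
    instances,
    rotation_p=1.0,  # always rotate
    scale_p=0.5,     # scale half the time
    # translate_p left as None: no independent translation is added
)
```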
@@ -177,6 +177,9 @@ class BaseDataset(Dataset):
              if self.user_instances_only:
                  if lf.user_instances is not None and len(lf.user_instances) > 0:
                      lf.instances = lf.user_instances
+                 else:
+                     # Skip frames without user instances
+                     continue
              is_empty = True
              for _, inst in enumerate(lf.instances):
                  if not inst.is_empty:  # filter all NaN instances.
@@ -684,15 +687,12 @@ class CenteredInstanceDataset(BaseDataset):
          the images aren't cached and loaded from the `.slp` file on each access.
      cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
      use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
-     crop_size: Crop size of each instance for centered-instance model.
+     crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
      rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
          disk occurs only once across all workers.
      confmap_head_config: DictConfig object with all the keys in the `head_config` section.
          (required keys: `sigma`, `output_stride`, `part_names` and `anchor_part` depending on the model type ).
      labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
-
-     Note: If scale is provided for centered-instance model, the images are cropped out
-         from the scaled image with the given crop size.
      """
 
      def __init__(
@@ -748,6 +748,9 @@
              if self.user_instances_only:
                  if lf.user_instances is not None and len(lf.user_instances) > 0:
                      lf.instances = lf.user_instances
+                 else:
+                     # Skip frames without user instances
+                     continue
              for inst_idx, inst in enumerate(lf.instances):
                  if not inst.is_empty:  # filter all NaN instances.
                      video_idx = labels[labels_idx].videos.index(lf.video)
@@ -834,13 +837,6 @@
          )
          instances = instances * eff_scale
 
-         # resize image
-         image, instances = apply_resizer(
-             image,
-             instances,
-             scale=self.scale,
-         )
-
          # get the centroids based on the anchor idx
          centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
 
@@ -901,6 +897,13 @@
          sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
          sample["centroid"] = centered_centroid  # (n_samples=1, 2)
 
+         # resize the cropped image
+         sample["instance_image"], sample["instance"] = apply_resizer(
+             sample["instance_image"],
+             sample["instance"],
+             scale=self.scale,
+         )
+
          # Pad the image (if needed) according max stride
          sample["instance_image"] = apply_pad_to_stride(
              sample["instance_image"], max_stride=self.max_stride
@@ -959,7 +962,7 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
          the images aren't cached and loaded from the `.slp` file on each access.
      cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
      use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
-     crop_size: Crop size of each instance for centered-instance model.
+     crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
      rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
          disk occurs only once across all workers.
      confmap_head_config: DictConfig object with all the keys in the `head_config` section.
@@ -967,9 +970,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
      class_vectors_head_config: DictConfig object with all the keys in the `head_config` section.
          (required keys: `classes`, `num_fc_layers`, `num_fc_units`, `output_stride`, `loss_weight`).
      labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
-
-     Note: If scale is provided for centered-instance model, the images are cropped out
-         from the scaled image with the given crop size.
      """
 
      def __init__(
@@ -1082,13 +1082,6 @@
          )
          instances = instances * eff_scale
 
-         # resize image
-         image, instances = apply_resizer(
-             image,
-             instances,
-             scale=self.scale,
-         )
-
          # get class vectors
          track_ids = torch.Tensor(
              [
@@ -1165,6 +1158,13 @@
          sample["instance"] = center_instance  # (n_samples=1, n_nodes, 2)
          sample["centroid"] = centered_centroid  # (n_samples=1, 2)
 
+         # resize image
+         sample["instance_image"], sample["instance"] = apply_resizer(
+             sample["instance_image"],
+             sample["instance"],
+             scale=self.scale,
+         )
+
          # Pad the image (if needed) according max stride
          sample["instance_image"] = apply_pad_to_stride(
              sample["instance_image"], max_stride=self.max_stride