sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/cli.py CHANGED
@@ -1,15 +1,63 @@
1
- """Unified CLI for SLEAP-NN using Click."""
1
+ """Unified CLI for SLEAP-NN using rich-click for styled output."""
2
2
 
3
- import click
3
+ import subprocess
4
+ import tempfile
5
+ import shutil
6
+ from datetime import datetime
7
+
8
+ import rich_click as click
9
+ from click import Command
4
10
  from loguru import logger
5
11
  from pathlib import Path
6
12
  from omegaconf import OmegaConf, DictConfig
13
+ import sleap_io as sio
7
14
  from sleap_nn.predict import run_inference, frame_list
8
15
  from sleap_nn.evaluation import run_evaluation
16
+ from sleap_nn.export.cli import export as export_command
17
+ from sleap_nn.export.cli import predict as predict_command
9
18
  from sleap_nn.train import run_training
19
+ from sleap_nn import __version__
20
+ from sleap_nn.config.utils import get_model_type_from_cfg
10
21
  import hydra
11
22
  import sys
12
- from click import Command
23
+
24
def _configure_rich_click() -> None:
    """Apply the rich-click settings that style help and error output for every command."""
    settings = click.rich_click
    # Interpret command docstrings/help text as Markdown.
    settings.TEXT_MARKUP = "markdown"
    settings.SHOW_ARGUMENTS = True
    settings.GROUP_ARGUMENTS_OPTIONS = True
    settings.STYLE_ERRORS_SUGGESTION = "magenta italic"
    settings.ERRORS_EPILOGUE = (
        "Try 'sleap-nn [COMMAND] --help' for more information."
    )


# Rich-click configuration for styled help (applied once at import time).
_configure_rich_click()
32
+
33
+
34
def is_config_path(arg: str) -> bool:
    """Check if an argument looks like a config file path.

    An argument counts as a config path when it ends with ``.yaml`` or ``.yml``.
    Hydra overrides use ``key=value`` syntax, so an argument containing ``=``
    (e.g. ``+experiment=foo.yaml``) is only treated as a config path when it
    names an existing file. This keeps paths inside generated run directories
    (named like ``<timestamp>.<model_type>.n=<count>``) working.

    Args:
        arg: A raw positional command-line argument.

    Returns:
        True if ``arg`` should be treated as a config file path.
    """
    if not arg.endswith((".yaml", ".yml")):
        return False
    # `key=value` is Hydra override syntax; only accept '=' for real files.
    return "=" not in arg or Path(arg).is_file()
40
+
41
+
42
def split_config_path(config_path: str) -> tuple:
    """Break a config file path into the directory and file name Hydra expects.

    Args:
        config_path: Path (absolute or relative to the CWD) of a config file.

    Returns:
        ``(config_dir, config_name)``: the absolute, POSIX-style parent
        directory of the resolved path, and its final path component.
    """
    resolved = Path(config_path).resolve()
    config_dir = resolved.parent.as_posix()
    config_name = resolved.name
    return config_dir, config_name
53
+
54
+
55
def print_version(ctx, param, value):
    """Click callback for ``--version``: echo the package version and exit.

    Does nothing when the flag is absent or Click is in resilient-parsing mode.
    """
    should_print = value and not ctx.resilient_parsing
    if should_print:
        click.echo(f"sleap-nn {__version__}")
        ctx.exit()
13
61
 
14
62
 
15
63
  class TrainCommand(Command):
@@ -20,73 +68,295 @@ class TrainCommand(Command):
20
68
  show_training_help()
21
69
 
22
70
 
71
def parse_path_map(ctx, param, value):
    """Click callback turning repeated ``OLD NEW`` option pairs into a mapping.

    Returns:
        ``None`` when the option was not supplied. Otherwise, a dict mapping each
        old path (kept exactly as typed) to ``Path(new).as_posix()``. If the same
        old path is given more than once, the last pair wins.
    """
    if not value:
        return None
    return {old: Path(new).as_posix() for old, new in value}
79
+
80
+
23
81
@click.group()
@click.option(
    "--version",
    "-v",
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show version and exit.",
)
def cli():
    """SLEAP-NN: Neural network backend for training and inference for animal pose estimation.

    Use subcommands to run different workflows:

    - `train` - Run training workflow (auto-handles multi-GPU)
    - `track` - Run inference/tracking workflow
    - `eval` - Run evaluation workflow
    - `system` - Display system information and GPU status
    - `export` - Export trained models (e.g., ONNX / TensorRT)
    - `predict` - Run inference with exported models
    """
    # The group does no work itself: Click dispatches to the chosen
    # subcommand, and `--version` is handled eagerly by `print_version`.
    # The subcommand list above uses Markdown bullets because help text is
    # rendered with TEXT_MARKUP="markdown".
    pass
34
102
 
35
103
 
104
def _get_num_devices_from_config(cfg: DictConfig) -> int:
    """Determine the number of devices from config.

    User preferences take precedence over auto-detection:
    - trainer_device_indices=[0, 1] → 2 devices (user choice)
    - trainer_devices=N (int or numeric string) → N devices (user choice)
    - trainer_devices=[0, 1] or "0,1" (device-id list) → 2 devices
    - trainer_devices="auto", -1, or unset → auto-detect available GPUs

    The accepted ``trainer_devices`` forms mirror Lightning's
    ``Trainer(devices=...)`` argument. This assumes the value is forwarded to
    Lightning unchanged, which is not visible from this module.

    Args:
        cfg: Composed training configuration.

    Returns:
        Number of devices to use for training.
    """
    import torch

    # User preference: explicit device indices (highest priority)
    device_indices = OmegaConf.select(
        cfg, "trainer_config.trainer_device_indices", default=None
    )
    if device_indices is not None and len(device_indices) > 0:
        return len(device_indices)

    # User preference: explicit device count or device-id list
    devices = OmegaConf.select(cfg, "trainer_config.trainer_devices", default="auto")

    # A list of device ids means one device per entry.
    if OmegaConf.is_list(devices) or isinstance(devices, (list, tuple)):
        return max(len(devices), 1)

    if isinstance(devices, str):
        # Comma-separated device ids, e.g. "0,1" (a trailing comma is allowed).
        if "," in devices:
            return max(len([d for d in devices.split(",") if d.strip()]), 1)
        # Numeric strings such as "2" (e.g. quoted in YAML) are plain counts.
        try:
            devices = int(devices)
        except ValueError:
            pass  # e.g. "auto" / "None": resolved by auto-detection below.

    # -1 means "all available devices", so it is resolved by auto-detection.
    if isinstance(devices, int) and devices != -1:
        return devices

    # Auto-detect only when user hasn't specified (devices is "auto", -1 or None)
    if devices in ("auto", None, "None", -1):
        accelerator = OmegaConf.select(
            cfg, "trainer_config.trainer_accelerator", default="auto"
        )
        # Only CUDA is expanded to multiple devices; CPU, MPS and anything
        # else fall back to a single device.
        if accelerator != "cpu" and torch.cuda.is_available():
            return torch.cuda.device_count()
        return 1

    return 1
146
+
147
+
148
def _finalize_config(cfg: DictConfig) -> DictConfig:
    """Finalize configuration by generating run_name if not provided.

    This runs ONCE before the training subprocess is spawned, ensuring all
    workers get the same run_name.

    Args:
        cfg: Composed training configuration. Modified in place.

    Returns:
        The same ``cfg``. ``trainer_config.ckpt_dir`` defaults to "." if unset,
        and ``trainer_config.run_name`` is generated if unset.
    """

    def _is_unset(value) -> bool:
        # None, "" and the literal string "None" all count as "not provided".
        return value is None or value == "" or value == "None"

    def _as_path_list(paths) -> list:
        # Accept a single path string as well as a list of paths; iterating a
        # bare string would otherwise yield its individual characters.
        if paths is None:
            return []
        if isinstance(paths, str):
            return [paths]
        return list(paths)

    # Resolve ckpt_dir first
    ckpt_dir = OmegaConf.select(cfg, "trainer_config.ckpt_dir", default=None)
    if _is_unset(ckpt_dir):
        cfg.trainer_config.ckpt_dir = "."

    # Generate run_name if not provided
    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
    if _is_unset(run_name):
        # Get model type from config
        model_type = get_model_type_from_cfg(cfg)

        # Count labeled frames across train + val labels. Only frame counts
        # are needed, so video backends are not opened; the stored video paths
        # may not be reachable yet (path replacement is applied later).
        label_paths = _as_path_list(cfg.data_config.train_labels_path)
        label_paths += _as_path_list(
            OmegaConf.select(cfg, "data_config.val_labels_path", default=None)
        )
        n_frames = sum(len(sio.load_slp(p, open_videos=False)) for p in label_paths)

        # Generate full run_name with timestamp
        timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
        run_name = f"{timestamp}.{model_type}.n={n_frames}"
        cfg.trainer_config.run_name = run_name

        logger.info(f"Generated run_name: {run_name}")

    return cfg
181
+
182
+
36
183
def show_training_help():
    """Display training help information with rich formatting.

    Renders a Markdown usage guide (usage, common overrides, examples, tips)
    inside a cyan-bordered rich ``Panel`` printed to the console.
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel

    console = Console()

    # Kept flush-left on purpose: indented Markdown lines would render as code
    # blocks. Batch size is set per data loader
    # (`trainer_config.train_data_loader.batch_size`), not on `trainer_config`.
    help_md = """
## Usage

```
sleap-nn train <config.yaml> [overrides]
sleap-nn train --config <path/to/config.yaml> [overrides]
```

## Common Overrides

| Override | Description |
|----------|-------------|
| `trainer_config.max_epochs=100` | Set maximum training epochs |
| `trainer_config.train_data_loader.batch_size=32` | Set training batch size |
| `trainer_config.save_ckpt=true` | Enable checkpoint saving |

## Examples

**Start a new training run:**
```bash
sleap-nn train path/to/config.yaml
sleap-nn train --config path/to/config.yaml
```

**With overrides:**
```bash
sleap-nn train config.yaml trainer_config.max_epochs=100
```

**Resume training:**
```bash
sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
```

**Legacy usage (still supported):**
```bash
sleap-nn train --config-dir /path/to/dir --config-name myrun
```

## Tips

- Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
- For Hydra flags and completion, use `--hydra-help`
- Config documentation: https://nn.sleap.ai/config/
"""
    console.print(
        Panel(
            Markdown(help_md),
            title="[bold cyan]sleap-nn train[/bold cyan]",
            subtitle="Train SLEAP models from a config YAML file",
            border_style="cyan",
        )
    )
63
244
 
64
245
 
65
246
  @cli.command(cls=TrainCommand)
66
- @click.option("--config-name", "-c", type=str, help="Configuration file name")
67
247
  @click.option(
68
- "--config-dir", "-d", type=str, default=".", help="Configuration directory path"
248
+ "--config",
249
+ type=str,
250
+ help="Path to configuration file (e.g., path/to/config.yaml)",
251
+ )
252
+ @click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
253
+ @click.option(
254
+ "--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
255
+ )
256
+ @click.option(
257
+ "--video-paths",
258
+ "-v",
259
+ multiple=True,
260
+ help="Video paths to replace existing paths in the labels file. "
261
+ "Order must match the order of videos in the labels file. "
262
+ "Can be specified multiple times. "
263
+ "Example: --video-paths /path/to/vid1.mp4 --video-paths /path/to/vid2.mp4",
264
+ )
265
+ @click.option(
266
+ "--video-path-map",
267
+ nargs=2,
268
+ multiple=True,
269
+ callback=parse_path_map,
270
+ metavar="OLD NEW",
271
+ help="Map old video path to new path. Takes two arguments: old path and new path. "
272
+ "Can be specified multiple times. "
273
+ 'Example: --video-path-map "/old/vid.mp`4" "/new/vid.mp4"',
274
+ )
275
+ @click.option(
276
+ "--prefix-map",
277
+ nargs=2,
278
+ multiple=True,
279
+ callback=parse_path_map,
280
+ metavar="OLD NEW",
281
+ help="Map old path prefix to new prefix. Takes two arguments: old prefix and new prefix. "
282
+ "Updates ALL videos that share the same prefix. Useful when moving data between machines. "
283
+ "Can be specified multiple times. "
284
+ 'Example: --prefix-map "/old/server/path" "/new/local/path"',
285
+ )
286
+ @click.option(
287
+ "--video-config",
288
+ type=str,
289
+ hidden=True,
290
+ help="Path to video replacement config YAML (internal use for multi-GPU).",
69
291
  )
70
292
  @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
71
- def train(config_name, config_dir, overrides):
293
+ def train(
294
+ config,
295
+ config_name,
296
+ config_dir,
297
+ video_paths,
298
+ video_path_map,
299
+ prefix_map,
300
+ video_config,
301
+ overrides,
302
+ ):
72
303
  """Run training workflow with Hydra config overrides.
73
304
 
305
+ Automatically detects multi-GPU setups and handles run_name synchronization
306
+ by spawning training in a subprocess with a pre-generated config.
307
+
74
308
  Examples:
75
- sleap-nn train --config-name myconfig --config-dir /path/to/config_dir/
76
- sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
77
- sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
309
+ sleap-nn train path/to/config.yaml
310
+ sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
311
+ sleap-nn train config.yaml trainer_config.trainer_devices=4
78
312
  """
79
- # Show help if no config name provided
80
- if not config_name:
313
+ # Convert overrides to a mutable list
314
+ overrides = list(overrides)
315
+
316
+ # Check if the first positional arg is a config path (not a Hydra override)
317
+ config_from_positional = None
318
+ if overrides and is_config_path(overrides[0]):
319
+ config_from_positional = overrides.pop(0)
320
+
321
+ # Resolve config path with priority:
322
+ # 1. Positional config path (e.g., sleap-nn train config.yaml)
323
+ # 2. --config flag (e.g., sleap-nn train --config config.yaml)
324
+ # 3. Legacy --config-dir/--config-name flags
325
+ if config_from_positional:
326
+ config_dir, config_name = split_config_path(config_from_positional)
327
+ elif config:
328
+ config_dir, config_name = split_config_path(config)
329
+ elif config_name:
330
+ config_dir = Path(config_dir).resolve().as_posix()
331
+ else:
332
+ # No config provided - show help
81
333
  show_training_help()
82
334
  return
83
335
 
84
- # Initialize Hydra manually
85
- # resolve the path to the config directory (hydra expects absolute path)
86
- config_dir = Path(config_dir).resolve().as_posix()
336
+ # Check video path options early
337
+ # If --video-config is provided (from subprocess), load from file
338
+ if video_config:
339
+ video_cfg = OmegaConf.load(video_config)
340
+ video_paths = tuple(video_cfg.video_paths) if video_cfg.video_paths else ()
341
+ video_path_map = (
342
+ dict(video_cfg.video_path_map) if video_cfg.video_path_map else None
343
+ )
344
+ prefix_map = dict(video_cfg.prefix_map) if video_cfg.prefix_map else None
345
+
346
+ has_video_paths = len(video_paths) > 0
347
+ has_video_path_map = video_path_map is not None
348
+ has_prefix_map = prefix_map is not None
349
+ options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
350
+
351
+ if options_used > 1:
352
+ raise click.UsageError(
353
+ "Cannot use multiple path replacement options. "
354
+ "Choose one of: --video-paths, --video-path-map, or --prefix-map."
355
+ )
356
+
357
+ # Load config to detect device count
87
358
  with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
88
- # Compose config with overrides
89
- cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
359
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
90
360
 
91
361
  # Validate config
92
362
  if not hasattr(cfg, "model_config") or not cfg.model_config:
@@ -95,9 +365,118 @@ def train(config_name, config_dir, overrides):
95
365
  )
96
366
  raise click.Abort()
97
367
 
368
+ num_devices = _get_num_devices_from_config(cfg)
369
+
370
+ # Check if run_name is already set (means we're a subprocess or user provided it)
371
+ run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
372
+ run_name_is_set = run_name is not None and run_name != "" and run_name != "None"
373
+
374
+ # Multi-GPU path: spawn subprocess with finalized config
375
+ # Only do this if run_name is NOT set (otherwise we'd loop infinitely or user set it)
376
+ if num_devices > 1 and not run_name_is_set:
377
+ logger.info(
378
+ f"Detected {num_devices} devices, using subprocess for run_name sync..."
379
+ )
380
+
381
+ # Load and finalize config (generate run_name, apply overrides)
382
+ with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
383
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
384
+ cfg = _finalize_config(cfg)
385
+
386
+ # Save finalized config to temp file
387
+ temp_dir = tempfile.mkdtemp(prefix="sleap_nn_train_")
388
+ temp_config_path = Path(temp_dir) / "training_config.yaml"
389
+ OmegaConf.save(cfg, temp_config_path)
390
+ logger.info(f"Saved finalized config to: {temp_config_path}")
391
+
392
+ # Save video replacement config if needed (so subprocess doesn't need CLI args)
393
+ temp_video_config_path = None
394
+ if options_used == 1:
395
+ video_replacement_config = {
396
+ "video_paths": list(video_paths) if has_video_paths else None,
397
+ "video_path_map": dict(video_path_map) if has_video_path_map else None,
398
+ "prefix_map": dict(prefix_map) if has_prefix_map else None,
399
+ }
400
+ temp_video_config_path = Path(temp_dir) / "video_replacement.yaml"
401
+ OmegaConf.save(
402
+ OmegaConf.create(video_replacement_config), temp_video_config_path
403
+ )
404
+ logger.info(f"Saved video replacement config to: {temp_video_config_path}")
405
+
406
+ # Build subprocess command (no video args - they're in the temp file)
407
+ cmd = [sys.executable, "-m", "sleap_nn.cli", "train", str(temp_config_path)]
408
+ if temp_video_config_path:
409
+ cmd.extend(["--video-config", str(temp_video_config_path)])
410
+
411
+ logger.info(f"Launching subprocess: {' '.join(cmd)}")
412
+
413
+ try:
414
+ process = subprocess.Popen(cmd)
415
+ result = process.wait()
416
+ if result != 0:
417
+ logger.error(f"Training failed with exit code {result}")
418
+ sys.exit(result)
419
+ except KeyboardInterrupt:
420
+ logger.info("Training interrupted, terminating subprocess...")
421
+ process.terminate()
422
+ try:
423
+ process.wait(timeout=5)
424
+ except subprocess.TimeoutExpired:
425
+ process.kill()
426
+ process.wait()
427
+ sys.exit(1)
428
+ finally:
429
+ shutil.rmtree(temp_dir, ignore_errors=True)
430
+ logger.info("Cleaned up temporary files")
431
+
432
+ return
433
+
434
+ # Single GPU (or subprocess worker): run directly
435
+ with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
436
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
437
+
98
438
  logger.info("Input config:")
99
439
  logger.info("\n" + OmegaConf.to_yaml(cfg))
100
- run_training(cfg)
440
+
441
+ # Handle video path replacement options
442
+ train_labels = None
443
+ val_labels = None
444
+
445
+ if options_used == 1:
446
+ # Load train labels
447
+ train_labels = [
448
+ sio.load_slp(path) for path in cfg.data_config.train_labels_path
449
+ ]
450
+
451
+ # Load val labels if they exist
452
+ if (
453
+ cfg.data_config.val_labels_path is not None
454
+ and len(cfg.data_config.val_labels_path) > 0
455
+ ):
456
+ val_labels = [
457
+ sio.load_slp(path) for path in cfg.data_config.val_labels_path
458
+ ]
459
+
460
+ # Build replacement arguments based on option used
461
+ if has_video_paths:
462
+ replace_kwargs = {
463
+ "new_filenames": [Path(p).as_posix() for p in video_paths]
464
+ }
465
+ elif has_video_path_map:
466
+ replace_kwargs = {"filename_map": video_path_map}
467
+ else: # has_prefix_map
468
+ replace_kwargs = {"prefix_map": prefix_map}
469
+
470
+ # Apply replacement to train labels
471
+ for labels in train_labels:
472
+ labels.replace_filenames(**replace_kwargs)
473
+
474
+ # Apply replacement to val labels if they exist
475
+ if val_labels:
476
+ for labels in val_labels:
477
+ labels.replace_filenames(**replace_kwargs)
478
+
479
+ run_training(config=cfg, train_labels=train_labels, val_labels=val_labels)
101
480
 
102
481
 
103
482
  @cli.command()
@@ -209,6 +588,18 @@ def train(config_name, config_dir, overrides):
209
588
  default=False,
210
589
  help="Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling.",
211
590
  )
591
+ @click.option(
592
+ "--exclude_user_labeled",
593
+ is_flag=True,
594
+ default=False,
595
+ help="Skip frames that have user-labeled instances. Useful when predicting on entire video but skipping already-labeled frames.",
596
+ )
597
+ @click.option(
598
+ "--only_predicted_frames",
599
+ is_flag=True,
600
+ default=False,
601
+ help="Only run inference on frames that already have predictions. Requires .slp input file. Useful for re-predicting with a different model.",
602
+ )
212
603
  @click.option(
213
604
  "--no_empty_frames",
214
605
  is_flag=True,
@@ -275,14 +666,14 @@ def train(config_name, config_dir, overrides):
275
666
  @click.option(
276
667
  "--queue_maxsize",
277
668
  type=int,
278
- default=8,
669
+ default=32,
279
670
  help="Maximum size of the frame buffer queue.",
280
671
  )
281
672
  @click.option(
282
673
  "--crop_size",
283
674
  type=int,
284
675
  default=None,
285
- help="Crop size. If not provided, the crop size from training_config.yaml is used.",
676
+ help="Crop size. If not provided, the crop size from training_config.yaml is used. If `input_scale` is provided, then the cropped image will be resized according to `input_scale`.",
286
677
  )
287
678
  @click.option(
288
679
  "--peak_threshold",
@@ -290,6 +681,36 @@ def train(config_name, config_dir, overrides):
290
681
  default=0.2,
291
682
  help="Minimum confidence map value to consider a peak as valid.",
292
683
  )
684
+ @click.option(
685
+ "--filter_overlapping",
686
+ is_flag=True,
687
+ default=False,
688
+ help=(
689
+ "Enable filtering of overlapping instances after inference using greedy NMS. "
690
+ "Applied independently of tracking. (default: False)"
691
+ ),
692
+ )
693
+ @click.option(
694
+ "--filter_overlapping_method",
695
+ type=click.Choice(["iou", "oks"]),
696
+ default="iou",
697
+ help=(
698
+ "Similarity metric for filtering overlapping instances. "
699
+ "'iou': bounding box intersection-over-union. "
700
+ "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
701
+ ),
702
+ )
703
+ @click.option(
704
+ "--filter_overlapping_threshold",
705
+ type=float,
706
+ default=0.8,
707
+ help=(
708
+ "Similarity threshold for filtering overlapping instances. "
709
+ "Instances with similarity above this threshold are removed, "
710
+ "keeping the higher-scoring instance. "
711
+ "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
712
+ ),
713
+ )
293
714
  @click.option(
294
715
  "--integral_refinement",
295
716
  type=str,
@@ -422,6 +843,12 @@ def train(config_name, config_dir, overrides):
422
843
  default=0,
423
844
  help="IOU to use when culling instances *after* tracking. (default: 0)",
424
845
  )
846
+ @click.option(
847
+ "--gui",
848
+ is_flag=True,
849
+ default=False,
850
+ help="Output JSON progress for GUI integration instead of Rich progress bar.",
851
+ )
425
852
  def track(**kwargs):
426
853
  """Run Inference and Tracking workflow."""
427
854
  # Convert model_paths from tuple to list
@@ -474,5 +901,21 @@ def eval(**kwargs):
474
901
  run_evaluation(**kwargs)
475
902
 
476
903
 
904
@cli.command()
def system():
    """Display system information and GPU status.

    Shows Python version, platform, PyTorch version, CUDA availability,
    driver version with compatibility check, GPU details, and package versions.
    """
    # Imported lazily so the diagnostics module is only loaded when this
    # subcommand actually runs.
    from sleap_nn import system_info

    system_info.print_system_info()
914
+
915
+
916
# Attach the `export` and `predict` commands (defined in sleap_nn.export.cli)
# to the top-level `sleap-nn` group, in that order.
for _subcommand in (export_command, predict_command):
    cli.add_command(_subcommand)


# Entry point for `python -m sleap_nn.cli` (also used by the multi-GPU
# training path, which re-launches this module in a subprocess).
if __name__ == "__main__":
    cli()