sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py CHANGED
@@ -50,7 +50,7 @@ logger.add(
     colorize=False,
 )
 
-__version__ = "0.1.0"
+__version__ = "0.1.0a1"
 
 # Public API
 from sleap_nn.evaluation import load_metrics
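Note: under PEP 440, `0.1.0a1` is a pre-release that orders *before* the `0.1.0` final, so this diff moves the wheel back to an alpha. A minimal check of the ordering (using the third-party `packaging` library; not part of this diff):

```python
# Verify that 0.1.0a1 is a PEP 440 pre-release ordered before 0.1.0.
from packaging.version import Version

assert Version("0.1.0a1") < Version("0.1.0")
assert Version("0.1.0a1").is_prerelease
```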
sleap_nn/architectures/convnext.py CHANGED
@@ -281,10 +281,6 @@ class ConvNextWrapper(nn.Module):
         # Keep the block output filters the same
         x_in_shape = int(self.arch["channels"][-1] * filters_rate)
 
-        # Encoder channels for skip connections (reversed to match decoder order)
-        # The forward pass uses enc_output[::2][::-1] for skip features
-        encoder_channels = self.arch["channels"][::-1]
-
         self.dec = Decoder(
             x_in_shape=x_in_shape,
             current_stride=self.current_stride,
@@ -297,7 +293,6 @@ class ConvNextWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=self.output_stride,
             up_interpolate=up_interpolate,
-            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
sleap_nn/architectures/encoder_decoder.py CHANGED
@@ -25,7 +25,7 @@ classes.
 See the `EncoderDecoder` base class for requirements for creating new architectures.
 """
 
-from typing import List, Optional, Text, Tuple, Union
+from typing import List, Text, Tuple, Union
 from collections import OrderedDict
 import torch
 from torch import nn
@@ -391,18 +391,10 @@ class SimpleUpsamplingBlock(nn.Module):
         transpose_convs_activation: Text = "relu",
         feat_concat: bool = True,
         prefix: Text = "",
-        skip_channels: Optional[int] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
 
-        # Determine skip connection channels
-        # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
-        # This allows ConvNext/SwinT to specify actual encoder channels
-        self.skip_channels = (
-            skip_channels if skip_channels is not None else refine_convs_filters
-        )
-
         self.x_in_shape = x_in_shape
         self.current_stride = current_stride
         self.upsampling_stride = upsampling_stride
@@ -477,13 +469,13 @@
                 first_conv_in_channels = refine_convs_filters
             else:
                 if self.up_interpolate:
-                    # With interpolation, input is x_in_shape + skip_channels
-                    # skip_channels may differ from refine_convs_filters for ConvNext/SwinT
-                    first_conv_in_channels = x_in_shape + self.skip_channels
+                    # With interpolation, input is x_in_shape + feature channels
+                    # The feature channels are the same as x_in_shape since they come from the same level
+                    first_conv_in_channels = x_in_shape + refine_convs_filters
                 else:
-                    # With transpose conv, input is transpose_conv_output + skip_channels
+                    # With transpose conv, input is transpose_conv_output + feature channels
                     first_conv_in_channels = (
-                        self.skip_channels + transpose_convs_filters
+                        refine_convs_filters + transpose_convs_filters
                     )
         else:
             if not self.feat_concat:
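The replacement comments state the new assumption directly: when decoder and skip features are concatenated, their channel counts add, and the skip branch is now assumed to carry `refine_convs_filters` channels rather than a separately tracked `skip_channels`. An illustrative sketch of that arithmetic (hypothetical channel counts, not package code):

```python
import torch

x_in_shape, refine_convs_filters = 64, 64  # hypothetical channel counts
upsampled = torch.randn(1, x_in_shape, 32, 32)       # decoder features after interpolation
skip = torch.randn(1, refine_convs_filters, 32, 32)  # encoder skip features
merged = torch.cat([upsampled, skip], dim=1)         # channels add on concat
assert merged.shape[1] == x_in_shape + refine_convs_filters  # 64 + 64 = 128
```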
@@ -590,7 +582,6 @@ class Decoder(nn.Module):
         block_contraction: bool = False,
         up_interpolate: bool = True,
         prefix: str = "dec",
-        encoder_channels: Optional[List[int]] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
@@ -607,7 +598,6 @@
         self.block_contraction = block_contraction
         self.prefix = prefix
         self.stride_to_filters = {}
-        self.encoder_channels = encoder_channels
 
         self.current_strides = []
         self.residuals = 0
@@ -634,13 +624,6 @@
 
             next_stride = current_stride // 2
 
-            # Determine skip channels for this decoder block
-            # If encoder_channels provided, use actual encoder channels
-            # Otherwise fall back to computed filters (for UNet compatibility)
-            skip_channels = None
-            if encoder_channels is not None and block < len(encoder_channels):
-                skip_channels = encoder_channels[block]
-
             if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
                 # This accounts for the case where we dont have any more down block features to concatenate with.
                 # In this case, add a simple upsampling block with a conv layer and with no concatenation
@@ -659,7 +642,6 @@
                         transpose_convs_batch_norm=False,
                         feat_concat=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
-                        skip_channels=skip_channels,
                     )
                 )
             else:
@@ -677,7 +659,6 @@
                         transpose_convs_filters=block_filters_out,
                         transpose_convs_batch_norm=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
-                        skip_channels=skip_channels,
                     )
                 )
 
sleap_nn/architectures/swint.py CHANGED
@@ -309,13 +309,6 @@ class SwinTWrapper(nn.Module):
             self.stem_patch_stride * (2**3) * 2
         )  # stem_stride * down_blocks_stride * final_max_pool_stride
 
-        # Encoder channels for skip connections (reversed to match decoder order)
-        # SwinT channels: embed * 2^i for each stage i, then reversed
-        num_stages = len(self.arch["depths"])
-        encoder_channels = [
-            self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
-        ]
-
        self.dec = Decoder(
             x_in_shape=block_filters,
             current_stride=self.current_stride,
@@ -328,7 +321,6 @@
             block_contraction=self.block_contraction,
             output_stride=output_stride,
             up_interpolate=up_interpolate,
-            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
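For reference, the deleted comprehension produced the per-stage SwinT widths in decoder order (`embed * 2^i` per stage, reversed). Re-running it standalone with hypothetical values:

```python
# Standalone re-run of the removed computation with hypothetical values.
num_stages = 4  # e.g. len(depths) for a four-stage SwinT
embed = 96      # hypothetical embedding dimension
encoder_channels = [embed * (2 ** (num_stages - 1 - i)) for i in range(num_stages)]
print(encoder_channels)  # [768, 384, 192, 96]
```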
sleap_nn/cli.py CHANGED
@@ -1,55 +1,17 @@
-"""Unified CLI for SLEAP-NN using rich-click for styled output."""
+"""Unified CLI for SLEAP-NN using Click."""
 
-import subprocess
-import tempfile
-import shutil
-from datetime import datetime
-
-import rich_click as click
-from click import Command
+import click
 from loguru import logger
 from pathlib import Path
 from omegaconf import OmegaConf, DictConfig
 import sleap_io as sio
 from sleap_nn.predict import run_inference, frame_list
 from sleap_nn.evaluation import run_evaluation
-from sleap_nn.export.cli import export as export_command
-from sleap_nn.export.cli import predict as predict_command
 from sleap_nn.train import run_training
 from sleap_nn import __version__
-from sleap_nn.config.utils import get_model_type_from_cfg
 import hydra
 import sys
-
-# Rich-click configuration for styled help
-click.rich_click.TEXT_MARKUP = "markdown"
-click.rich_click.SHOW_ARGUMENTS = True
-click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
-click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
-click.rich_click.ERRORS_EPILOGUE = (
-    "Try 'sleap-nn [COMMAND] --help' for more information."
-)
-
-
-def is_config_path(arg: str) -> bool:
-    """Check if an argument looks like a config file path.
-
-    Returns True if the arg ends with .yaml or .yml.
-    """
-    return arg.endswith(".yaml") or arg.endswith(".yml")
-
-
-def split_config_path(config_path: str) -> tuple:
-    """Split a full config path into (config_dir, config_name).
-
-    Args:
-        config_path: Full path to a config file.
-
-    Returns:
-        Tuple of (config_dir, config_name) where config_dir is an absolute path.
-    """
-    path = Path(config_path).resolve()
-    return path.parent.as_posix(), path.name
+from click import Command
 
 
 def print_version(ctx, param, value):
@@ -93,165 +55,47 @@ def cli():
 
     Use subcommands to run different workflows:
 
-        train - Run training workflow (auto-handles multi-GPU)
-        track - Run inference/tracking workflow
+        train - Run training workflow
+        track - Run inference/ tracking workflow
         eval - Run evaluation workflow
         system - Display system information and GPU status
     """
     pass
 
 
-def _get_num_devices_from_config(cfg: DictConfig) -> int:
-    """Determine the number of devices from config.
-
-    User preferences take precedence over auto-detection:
-    - trainer_device_indices=[0] → 1 device (user choice)
-    - trainer_devices=1 → 1 device (user choice)
-    - trainer_devices="auto" or unset → auto-detect available GPUs
-
-    Returns:
-        Number of devices to use for training.
-    """
-    import torch
-
-    # User preference: explicit device indices (highest priority)
-    device_indices = OmegaConf.select(
-        cfg, "trainer_config.trainer_device_indices", default=None
-    )
-    if device_indices is not None and len(device_indices) > 0:
-        return len(device_indices)
-
-    # User preference: explicit device count
-    devices = OmegaConf.select(cfg, "trainer_config.trainer_devices", default="auto")
-
-    if isinstance(devices, int):
-        return devices
-
-    # Auto-detect only when user hasn't specified (devices is "auto" or None)
-    if devices in ("auto", None, "None"):
-        accelerator = OmegaConf.select(
-            cfg, "trainer_config.trainer_accelerator", default="auto"
-        )
-
-        if accelerator == "cpu":
-            return 1
-        elif torch.cuda.is_available():
-            return torch.cuda.device_count()
-        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return 1
-        else:
-            return 1
-
-    return 1
-
-
-def _finalize_config(cfg: DictConfig) -> DictConfig:
-    """Finalize configuration by generating run_name if not provided.
-
-    This runs ONCE before subprocess, ensuring all workers get the same run_name.
-    """
-    # Resolve ckpt_dir first
-    ckpt_dir = OmegaConf.select(cfg, "trainer_config.ckpt_dir", default=None)
-    if ckpt_dir is None or ckpt_dir == "" or ckpt_dir == "None":
-        cfg.trainer_config.ckpt_dir = "."
-
-    # Generate run_name if not provided
-    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
-    if run_name is None or run_name == "" or run_name == "None":
-        # Get model type from config
-        model_type = get_model_type_from_cfg(cfg)
-
-        # Count frames from labels
-        train_paths = cfg.data_config.train_labels_path
-        val_paths = OmegaConf.select(cfg, "data_config.val_labels_path", default=None)
-
-        train_count = sum(len(sio.load_slp(p)) for p in train_paths)
-        val_count = 0
-        if val_paths:
-            val_count = sum(len(sio.load_slp(p)) for p in val_paths)
-
-        # Generate full run_name with timestamp
-        timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
-        run_name = f"{timestamp}.{model_type}.n={train_count + val_count}"
-        cfg.trainer_config.run_name = run_name
-
-        logger.info(f"Generated run_name: {run_name}")
-
-    return cfg
-
-
 def show_training_help():
-    """Display training help information with rich formatting."""
-    from rich.console import Console
-    from rich.panel import Panel
-    from rich.markdown import Markdown
-
-    console = Console()
-
-    help_md = """
-## Usage
-
-```
-sleap-nn train <config.yaml> [overrides]
-sleap-nn train --config <path/to/config.yaml> [overrides]
-```
-
-## Common Overrides
-
-| Override | Description |
-|----------|-------------|
-| `trainer_config.max_epochs=100` | Set maximum training epochs |
-| `trainer_config.batch_size=32` | Set batch size |
-| `trainer_config.save_ckpt=true` | Enable checkpoint saving |
-
-## Examples
-
-**Start a new training run:**
-```bash
-sleap-nn train path/to/config.yaml
-sleap-nn train --config path/to/config.yaml
-```
-
-**With overrides:**
-```bash
-sleap-nn train config.yaml trainer_config.max_epochs=100
-```
-
-**Resume training:**
-```bash
-sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
-```
-
-**Legacy usage (still supported):**
-```bash
-sleap-nn train --config-dir /path/to/dir --config-name myrun
-```
-
-## Tips
-
-- Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
-- For Hydra flags and completion, use `--hydra-help`
-- Config documentation: https://nn.sleap.ai/config/
+    """Display training help information."""
+    help_text = """
+    sleap-nn train — Train SLEAP models from a config YAML file.
+
+    Usage:
+        sleap-nn train --config-dir <dir> --config-name <name> [overrides]
+
+    Common overrides:
+        trainer_config.max_epochs=100
+        trainer_config.batch_size=32
+
+    Examples:
+        Start new run:
+            sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun
+        Resume 20 more epochs:
+            sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun \\
+                trainer_config.resume_ckpt_path=<path/to/ckpt> \\
+                trainer_config.max_epochs=20
+
+    Tips:
+    - Use -m/--multirun for sweeps; outputs go under hydra.sweep.dir.
+    - For Hydra flags and completion, use --hydra-help.
+
+    For a detailed list of all available config options, please refer to https://nn.sleap.ai/config/.
     """
-    console.print(
-        Panel(
-            Markdown(help_md),
-            title="[bold cyan]sleap-nn train[/bold cyan]",
-            subtitle="Train SLEAP models from a config YAML file",
-            border_style="cyan",
-        )
-    )
+    click.echo(help_text)
 
 
 @cli.command(cls=TrainCommand)
+@click.option("--config-name", "-c", type=str, help="Configuration file name")
 @click.option(
-    "--config",
-    type=str,
-    help="Path to configuration file (e.g., path/to/config.yaml)",
-)
-@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
-@click.option(
-    "--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
+    "--config-dir", "-d", type=str, default=".", help="Configuration directory path"
 )
 @click.option(
     "--video-paths",
@@ -283,80 +127,26 @@ sleap-nn train --config-dir /path/to/dir --config-name myrun
     "Can be specified multiple times. "
     'Example: --prefix-map "/old/server/path" "/new/local/path"',
 )
-@click.option(
-    "--video-config",
-    type=str,
-    hidden=True,
-    help="Path to video replacement config YAML (internal use for multi-GPU).",
-)
 @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
-def train(
-    config,
-    config_name,
-    config_dir,
-    video_paths,
-    video_path_map,
-    prefix_map,
-    video_config,
-    overrides,
-):
+def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
     """Run training workflow with Hydra config overrides.
 
-    Automatically detects multi-GPU setups and handles run_name synchronization
-    by spawning training in a subprocess with a pre-generated config.
-
     Examples:
-        sleap-nn train path/to/config.yaml
-        sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
-        sleap-nn train config.yaml trainer_config.trainer_devices=4
+        sleap-nn train --config-name myconfig --config-dir /path/to/config_dir/
+        sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
+        sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
     """
-    # Convert overrides to a mutable list
-    overrides = list(overrides)
-
-    # Check if the first positional arg is a config path (not a Hydra override)
-    config_from_positional = None
-    if overrides and is_config_path(overrides[0]):
-        config_from_positional = overrides.pop(0)
-
-    # Resolve config path with priority:
-    # 1. Positional config path (e.g., sleap-nn train config.yaml)
-    # 2. --config flag (e.g., sleap-nn train --config config.yaml)
-    # 3. Legacy --config-dir/--config-name flags
-    if config_from_positional:
-        config_dir, config_name = split_config_path(config_from_positional)
-    elif config:
-        config_dir, config_name = split_config_path(config)
-    elif config_name:
-        config_dir = Path(config_dir).resolve().as_posix()
-    else:
-        # No config provided - show help
+    # Show help if no config name provided
+    if not config_name:
         show_training_help()
         return
 
-    # Check video path options early
-    # If --video-config is provided (from subprocess), load from file
-    if video_config:
-        video_cfg = OmegaConf.load(video_config)
-        video_paths = tuple(video_cfg.video_paths) if video_cfg.video_paths else ()
-        video_path_map = (
-            dict(video_cfg.video_path_map) if video_cfg.video_path_map else None
-        )
-        prefix_map = dict(video_cfg.prefix_map) if video_cfg.prefix_map else None
-
-    has_video_paths = len(video_paths) > 0
-    has_video_path_map = video_path_map is not None
-    has_prefix_map = prefix_map is not None
-    options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
-
-    if options_used > 1:
-        raise click.UsageError(
-            "Cannot use multiple path replacement options. "
-            "Choose one of: --video-paths, --video-path-map, or --prefix-map."
-        )
-
-    # Load config to detect device count
+    # Initialize Hydra manually
+    # resolve the path to the config directory (hydra expects absolute path)
+    config_dir = Path(config_dir).resolve().as_posix()
     with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-        cfg = hydra.compose(config_name=config_name, overrides=overrides)
+        # Compose config with overrides
+        cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
 
     # Validate config
     if not hasattr(cfg, "model_config") or not cfg.model_config:
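Both the removed and the retained code paths resolve the YAML through Hydra's compose API. A standalone sketch of the pattern (hypothetical `configs/` directory and `myrun.yaml` config name):

```python
import hydra
from pathlib import Path
from omegaconf import OmegaConf

# initialize_config_dir requires an absolute path, hence the resolve().
config_dir = Path("configs").resolve().as_posix()  # hypothetical directory
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = hydra.compose(
        config_name="myrun",  # hypothetical name for configs/myrun.yaml
        overrides=["trainer_config.max_epochs=100"],
    )
print(OmegaConf.to_yaml(cfg))
```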
@@ -365,76 +155,6 @@ def train(
         )
         raise click.Abort()
 
-    num_devices = _get_num_devices_from_config(cfg)
-
-    # Check if run_name is already set (means we're a subprocess or user provided it)
-    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
-    run_name_is_set = run_name is not None and run_name != "" and run_name != "None"
-
-    # Multi-GPU path: spawn subprocess with finalized config
-    # Only do this if run_name is NOT set (otherwise we'd loop infinitely or user set it)
-    if num_devices > 1 and not run_name_is_set:
-        logger.info(
-            f"Detected {num_devices} devices, using subprocess for run_name sync..."
-        )
-
-        # Load and finalize config (generate run_name, apply overrides)
-        with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-            cfg = hydra.compose(config_name=config_name, overrides=overrides)
-            cfg = _finalize_config(cfg)
-
-        # Save finalized config to temp file
-        temp_dir = tempfile.mkdtemp(prefix="sleap_nn_train_")
-        temp_config_path = Path(temp_dir) / "training_config.yaml"
-        OmegaConf.save(cfg, temp_config_path)
-        logger.info(f"Saved finalized config to: {temp_config_path}")
-
-        # Save video replacement config if needed (so subprocess doesn't need CLI args)
-        temp_video_config_path = None
-        if options_used == 1:
-            video_replacement_config = {
-                "video_paths": list(video_paths) if has_video_paths else None,
-                "video_path_map": dict(video_path_map) if has_video_path_map else None,
-                "prefix_map": dict(prefix_map) if has_prefix_map else None,
-            }
-            temp_video_config_path = Path(temp_dir) / "video_replacement.yaml"
-            OmegaConf.save(
-                OmegaConf.create(video_replacement_config), temp_video_config_path
-            )
-            logger.info(f"Saved video replacement config to: {temp_video_config_path}")
-
-        # Build subprocess command (no video args - they're in the temp file)
-        cmd = [sys.executable, "-m", "sleap_nn.cli", "train", str(temp_config_path)]
-        if temp_video_config_path:
-            cmd.extend(["--video-config", str(temp_video_config_path)])
-
-        logger.info(f"Launching subprocess: {' '.join(cmd)}")
-
-        try:
-            process = subprocess.Popen(cmd)
-            result = process.wait()
-            if result != 0:
-                logger.error(f"Training failed with exit code {result}")
-                sys.exit(result)
-        except KeyboardInterrupt:
-            logger.info("Training interrupted, terminating subprocess...")
-            process.terminate()
-            try:
-                process.wait(timeout=5)
-            except subprocess.TimeoutExpired:
-                process.kill()
-                process.wait()
-            sys.exit(1)
-        finally:
-            shutil.rmtree(temp_dir, ignore_errors=True)
-            logger.info("Cleaned up temporary files")
-
-        return
-
-    # Single GPU (or subprocess worker): run directly
-    with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-        cfg = hydra.compose(config_name=config_name, overrides=overrides)
-
     logger.info("Input config:")
     logger.info("\n" + OmegaConf.to_yaml(cfg))
 
@@ -442,6 +162,19 @@ def train(
     train_labels = None
     val_labels = None
 
+    # Check that only one replacement option is used
+    # video_paths is a tuple (empty if not used), others are None or dict
+    has_video_paths = len(video_paths) > 0
+    has_video_path_map = video_path_map is not None
+    has_prefix_map = prefix_map is not None
+    options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
+
+    if options_used > 1:
+        raise click.UsageError(
+            "Cannot use multiple path replacement options. "
+            "Choose one of: --video-paths, --video-path-map, or --prefix-map."
+        )
+
     if options_used == 1:
         # Load train labels
         train_labels = [
@@ -459,12 +192,15 @@
 
         # Build replacement arguments based on option used
         if has_video_paths:
+            # List of paths (order must match videos in labels file)
             replace_kwargs = {
                 "new_filenames": [Path(p).as_posix() for p in video_paths]
             }
         elif has_video_path_map:
+            # Dictionary mapping old filenames to new filenames
             replace_kwargs = {"filename_map": video_path_map}
         else:  # has_prefix_map
+            # Dictionary mapping old prefixes to new prefixes
             replace_kwargs = {"prefix_map": prefix_map}
 
         # Apply replacement to train labels
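These `replace_kwargs` feed `sleap_io`'s `Labels.replace_filenames`, which accepts exactly one of `new_filenames`, `filename_map`, or `prefix_map`. A short usage sketch with hypothetical paths:

```python
import sleap_io as sio

labels = sio.load_slp("train.slp")  # hypothetical labels file
# Exactly one of the three keyword forms, mirroring the CLI options:
labels.replace_filenames(prefix_map={"/old/server/path": "/new/local/path"})
```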
@@ -666,7 +402,7 @@ def train(
 @click.option(
     "--queue_maxsize",
     type=int,
-    default=32,
+    default=8,
     help="Maximum size of the frame buffer queue.",
 )
 @click.option(
@@ -681,36 +417,6 @@ def train(
     default=0.2,
     help="Minimum confidence map value to consider a peak as valid.",
 )
-@click.option(
-    "--filter_overlapping",
-    is_flag=True,
-    default=False,
-    help=(
-        "Enable filtering of overlapping instances after inference using greedy NMS. "
-        "Applied independently of tracking. (default: False)"
-    ),
-)
-@click.option(
-    "--filter_overlapping_method",
-    type=click.Choice(["iou", "oks"]),
-    default="iou",
-    help=(
-        "Similarity metric for filtering overlapping instances. "
-        "'iou': bounding box intersection-over-union. "
-        "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
-    ),
-)
-@click.option(
-    "--filter_overlapping_threshold",
-    type=float,
-    default=0.8,
-    help=(
-        "Similarity threshold for filtering overlapping instances. "
-        "Instances with similarity above this threshold are removed, "
-        "keeping the higher-scoring instance. "
-        "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
-    ),
-)
 @click.option(
     "--integral_refinement",
     type=str,
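The removed `--filter_overlapping*` options describe greedy NMS keyed on IoU or OKS similarity. For reference, a minimal bounding-box IoU in the sense the deleted help text uses the term (an illustrative sketch, not the package's implementation):

```python
def bbox_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) tuples."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```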
@@ -843,12 +549,6 @@ def train(
     default=0,
     help="IOU to use when culling instances *after* tracking. (default: 0)",
 )
-@click.option(
-    "--gui",
-    is_flag=True,
-    default=False,
-    help="Output JSON progress for GUI integration instead of Rich progress bar.",
-)
 def track(**kwargs):
     """Run Inference and Tracking workflow."""
     # Convert model_paths from tuple to list
@@ -913,9 +613,5 @@ def system():
     print_system_info()
 
 
-cli.add_command(export_command)
-cli.add_command(predict_command)
-
-
 if __name__ == "__main__":
     cli()