sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/cli.py
CHANGED
|
@@ -1,15 +1,63 @@
|
|
|
1
|
-
"""Unified CLI for SLEAP-NN using
|
|
1
|
+
"""Unified CLI for SLEAP-NN using rich-click for styled output."""
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import subprocess
|
|
4
|
+
import tempfile
|
|
5
|
+
import shutil
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
|
|
8
|
+
import rich_click as click
|
|
9
|
+
from click import Command
|
|
4
10
|
from loguru import logger
|
|
5
11
|
from pathlib import Path
|
|
6
12
|
from omegaconf import OmegaConf, DictConfig
|
|
13
|
+
import sleap_io as sio
|
|
7
14
|
from sleap_nn.predict import run_inference, frame_list
|
|
8
15
|
from sleap_nn.evaluation import run_evaluation
|
|
16
|
+
from sleap_nn.export.cli import export as export_command
|
|
17
|
+
from sleap_nn.export.cli import predict as predict_command
|
|
9
18
|
from sleap_nn.train import run_training
|
|
19
|
+
from sleap_nn import __version__
|
|
20
|
+
from sleap_nn.config.utils import get_model_type_from_cfg
|
|
10
21
|
import hydra
|
|
11
22
|
import sys
|
|
12
|
-
|
|
23
|
+
|
|
24
|
+
# Rich-click configuration for styled help
# These assignments set module-level attributes on ``click.rich_click`` at
# import time, i.e. global rendering settings for the help/error output of the
# commands defined in this module.
# NOTE(review): "markdown" means command docstrings/help strings below are
# interpreted as Markdown by rich-click — keep that in mind when editing them.
click.rich_click.TEXT_MARKUP = "markdown"
click.rich_click.SHOW_ARGUMENTS = True
click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
# Text presumably appended to usage-error output; points users at per-command help.
click.rich_click.ERRORS_EPILOGUE = (
    "Try 'sleap-nn [COMMAND] --help' for more information."
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def is_config_path(arg: str) -> bool:
    """Return whether ``arg`` names a YAML config file.

    The check is purely suffix-based: any string ending in ``.yaml`` or
    ``.yml`` qualifies, including a Hydra ``key=value`` override whose value
    happens to end with one of those suffixes.

    Args:
        arg: A positional command-line argument.

    Returns:
        True if ``arg`` ends with ``.yaml`` or ``.yml``, otherwise False.
    """
    yaml_suffixes = (".yaml", ".yml")
    return arg.endswith(yaml_suffixes)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def split_config_path(config_path: str) -> tuple:
    """Break a config file path into its directory and file name.

    The path is resolved first, so relative paths are anchored at the current
    working directory and symlinks are followed.

    Args:
        config_path: Full (or relative) path to a config file.

    Returns:
        ``(config_dir, config_name)``, where ``config_dir`` is the absolute
        parent directory in POSIX form and ``config_name`` is the file name
        including its extension.
    """
    resolved = Path(config_path).resolve()
    directory = resolved.parent.as_posix()
    filename = resolved.name
    return directory, filename
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def print_version(ctx, param, value):
    """Click callback for the eager ``--version`` flag.

    Echoes ``sleap-nn <version>`` and exits the Click context. Does nothing
    when the flag was not given or when Click is parsing resiliently.

    Args:
        ctx: Click context.
        param: The Click parameter (unused).
        value: The flag's value.
    """
    should_print = value and not ctx.resilient_parsing
    if should_print:
        click.echo(f"sleap-nn {__version__}")
        ctx.exit()
|
|
13
61
|
|
|
14
62
|
|
|
15
63
|
class TrainCommand(Command):
|
|
@@ -20,73 +68,295 @@ class TrainCommand(Command):
|
|
|
20
68
|
show_training_help()
|
|
21
69
|
|
|
22
70
|
|
|
71
|
+
def parse_path_map(ctx, param, value):
    """Click callback that turns repeated ``OLD NEW`` pairs into a mapping.

    Args:
        ctx: Click context (unused).
        param: Click parameter (unused).
        value: Sequence of ``(old_path, new_path)`` pairs collected from a
            ``nargs=2, multiple=True`` option.

    Returns:
        ``None`` when no pairs were supplied. Otherwise, a dict mapping each old
        path (kept verbatim) to its new path normalized with
        ``Path.as_posix()``. If an old path repeats, the last pair wins.
    """
    if not value:
        return None
    return {old: Path(new).as_posix() for old, new in value}
|
|
79
|
+
|
|
80
|
+
|
|
23
81
|
@click.group()
@click.option(
    "--version",
    "-v",
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show version and exit.",
)
def cli():
    """SLEAP-NN: Neural network backend for training and inference for animal pose estimation.

    Use subcommands to run different workflows:

        train - Run training workflow (auto-handles multi-GPU)
        track - Run inference/tracking workflow
        eval - Run evaluation workflow
        system - Display system information and GPU status
        export - Export trained models for deployment
        predict - Run inference with exported models
    """
    # The group itself does no work; subcommands do. `--version` is processed
    # eagerly by `print_version`, which exits after printing.
    pass
|
|
34
102
|
|
|
35
103
|
|
|
104
|
+
def _get_num_devices_from_config(cfg: DictConfig) -> int:
    """Determine the number of devices from config.

    User preferences take precedence over auto-detection:
      - trainer_device_indices=[0, 1]        → 2 devices (user choice)
      - trainer_devices=N (N != -1)          → N devices (user choice)
      - trainer_devices=[0, 1] or "0,1"      → number of listed devices
      - trainer_devices="N"                  → N devices
      - trainer_devices=-1 / "-1" / "auto" / unset → auto-detect available GPUs

    Args:
        cfg: Composed training configuration.

    Returns:
        Number of devices to use for training (used by ``train`` to decide
        whether to launch the run_name-synchronizing subprocess).
    """
    import torch

    def _detect_available() -> int:
        # Respect an explicit CPU accelerator before probing hardware.
        accelerator = OmegaConf.select(
            cfg, "trainer_config.trainer_accelerator", default="auto"
        )
        if accelerator == "cpu":
            return 1
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        # MPS and CPU-only machines both train on a single device.
        return 1

    # User preference: explicit device indices (highest priority)
    device_indices = OmegaConf.select(
        cfg, "trainer_config.trainer_device_indices", default=None
    )
    if device_indices is not None and len(device_indices) > 0:
        return len(device_indices)

    # User preference: explicit device count / spec
    devices = OmegaConf.select(cfg, "trainer_config.trainer_devices", default="auto")

    if devices is None:
        return _detect_available()

    if isinstance(devices, int):
        # Lightning treats -1 as "all available devices"; returning -1 here would
        # make `train` skip the multi-GPU run_name synchronization.
        return _detect_available() if devices == -1 else devices

    if isinstance(devices, str):
        text = devices.strip()
        if text in ("auto", "None", "-1"):
            return _detect_available()
        if "," in text:
            # Comma-separated device ids, e.g. "0,1".
            return max(1, sum(1 for part in text.split(",") if part.strip()))
        if text.isdigit():
            return int(text)
        return 1

    # List-like specs (e.g. a ListConfig such as [0, 1]) select explicit devices.
    try:
        return max(1, len(devices))
    except TypeError:
        return 1
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _finalize_config(cfg: DictConfig) -> DictConfig:
    """Fill in ``ckpt_dir`` and ``run_name`` defaults on a composed config.

    This runs once in the parent process before the multi-GPU training
    subprocess is launched, so every worker sees the same ``run_name``.

    - ``trainer_config.ckpt_dir`` becomes ``"."`` when it is missing, empty, or
      the string ``"None"``.
    - ``trainer_config.run_name``, when missing/empty/``"None"``, becomes
      ``"<yymmdd_HHMMSS>.<model_type>.n=<count>"``. Here ``<count>`` is the summed
      ``len()`` of every train label file plus any validation label files.

    Args:
        cfg: Composed training configuration; modified in place.

    Returns:
        The same ``cfg`` object.
    """
    unset_values = (None, "", "None")

    if OmegaConf.select(cfg, "trainer_config.ckpt_dir", default=None) in unset_values:
        cfg.trainer_config.ckpt_dir = "."

    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
    if run_name in unset_values:
        model_type = get_model_type_from_cfg(cfg)

        def _count_frames(label_paths) -> int:
            # Loads each labels file just to take its length.
            return sum(len(sio.load_slp(label_path)) for label_path in label_paths)

        total_frames = _count_frames(cfg.data_config.train_labels_path)
        val_label_paths = OmegaConf.select(
            cfg, "data_config.val_labels_path", default=None
        )
        if val_label_paths:
            total_frames += _count_frames(val_label_paths)

        stamp = datetime.now().strftime("%y%m%d_%H%M%S")
        run_name = f"{stamp}.{model_type}.n={total_frames}"
        cfg.trainer_config.run_name = run_name

        logger.info(f"Generated run_name: {run_name}")

    return cfg
|
|
181
|
+
|
|
182
|
+
|
|
36
183
|
def show_training_help():
    """Display training help information with rich formatting.

    Prints a Markdown usage guide for ``sleap-nn train`` inside a cyan Rich
    panel, using a fresh ``rich.console.Console``. Shown when ``train`` is
    invoked without any config.
    """
    from rich.console import Console
    from rich.panel import Panel
    from rich.markdown import Markdown

    console = Console()

    # Kept at column 0 so Markdown does not treat the lines as indented code.
    help_md = """
## Usage

```
sleap-nn train <config.yaml> [overrides]
sleap-nn train --config <path/to/config.yaml> [overrides]
```

Overrides use Hydra `key=value` syntax and are applied on top of the config file.

## Common Overrides

| Override | Description |
|----------|-------------|
| `trainer_config.max_epochs=100` | Set maximum training epochs |
| `trainer_config.batch_size=32` | Set batch size |
| `trainer_config.save_ckpt=true` | Enable checkpoint saving |
| `trainer_config.trainer_devices=4` | Set number of devices (multi-GPU is handled automatically) |

## Examples

**Start a new training run:**
```bash
sleap-nn train path/to/config.yaml
sleap-nn train --config path/to/config.yaml
```

**With overrides:**
```bash
sleap-nn train config.yaml trainer_config.max_epochs=100
```

**Resume training:**
```bash
sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
```

**Remap video paths (e.g. after moving data between machines):**
```bash
sleap-nn train --prefix-map /old/server/path /new/local/path config.yaml
sleap-nn train --video-path-map /old/vid.mp4 /new/vid.mp4 config.yaml
sleap-nn train --video-paths /path/to/vid1.mp4 --video-paths /path/to/vid2.mp4 config.yaml
```

**Legacy usage (still supported):**
```bash
sleap-nn train --config-dir /path/to/dir --config-name myrun
```

## Tips

- `--video-paths`, `--video-path-map`, and `--prefix-map` are mutually exclusive; use only one per run
- `--video-paths` must list videos in the same order as the labels file
- With more than one device, training is launched in a subprocess with a pre-generated `trainer_config.run_name` so all workers share the same run name
- Config documentation: https://nn.sleap.ai/config/
"""
    console.print(
        Panel(
            Markdown(help_md),
            title="[bold cyan]sleap-nn train[/bold cyan]",
            subtitle="Train SLEAP models from a config YAML file",
            border_style="cyan",
        )
    )
|
|
63
244
|
|
|
64
245
|
|
|
65
246
|
@cli.command(cls=TrainCommand)
|
|
66
|
-
@click.option("--config-name", "-c", type=str, help="Configuration file name")
|
|
67
247
|
@click.option(
|
|
68
|
-
"--config
|
|
248
|
+
"--config",
|
|
249
|
+
type=str,
|
|
250
|
+
help="Path to configuration file (e.g., path/to/config.yaml)",
|
|
251
|
+
)
|
|
252
|
+
@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
|
|
253
|
+
@click.option(
|
|
254
|
+
"--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
|
|
255
|
+
)
|
|
256
|
+
@click.option(
|
|
257
|
+
"--video-paths",
|
|
258
|
+
"-v",
|
|
259
|
+
multiple=True,
|
|
260
|
+
help="Video paths to replace existing paths in the labels file. "
|
|
261
|
+
"Order must match the order of videos in the labels file. "
|
|
262
|
+
"Can be specified multiple times. "
|
|
263
|
+
"Example: --video-paths /path/to/vid1.mp4 --video-paths /path/to/vid2.mp4",
|
|
264
|
+
)
|
|
265
|
+
@click.option(
|
|
266
|
+
"--video-path-map",
|
|
267
|
+
nargs=2,
|
|
268
|
+
multiple=True,
|
|
269
|
+
callback=parse_path_map,
|
|
270
|
+
metavar="OLD NEW",
|
|
271
|
+
help="Map old video path to new path. Takes two arguments: old path and new path. "
|
|
272
|
+
"Can be specified multiple times. "
|
|
273
|
+
'Example: --video-path-map "/old/vid.mp`4" "/new/vid.mp4"',
|
|
274
|
+
)
|
|
275
|
+
@click.option(
|
|
276
|
+
"--prefix-map",
|
|
277
|
+
nargs=2,
|
|
278
|
+
multiple=True,
|
|
279
|
+
callback=parse_path_map,
|
|
280
|
+
metavar="OLD NEW",
|
|
281
|
+
help="Map old path prefix to new prefix. Takes two arguments: old prefix and new prefix. "
|
|
282
|
+
"Updates ALL videos that share the same prefix. Useful when moving data between machines. "
|
|
283
|
+
"Can be specified multiple times. "
|
|
284
|
+
'Example: --prefix-map "/old/server/path" "/new/local/path"',
|
|
285
|
+
)
|
|
286
|
+
@click.option(
|
|
287
|
+
"--video-config",
|
|
288
|
+
type=str,
|
|
289
|
+
hidden=True,
|
|
290
|
+
help="Path to video replacement config YAML (internal use for multi-GPU).",
|
|
69
291
|
)
|
|
70
292
|
@click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
|
|
71
|
-
def train(
|
|
293
|
+
def train(
|
|
294
|
+
config,
|
|
295
|
+
config_name,
|
|
296
|
+
config_dir,
|
|
297
|
+
video_paths,
|
|
298
|
+
video_path_map,
|
|
299
|
+
prefix_map,
|
|
300
|
+
video_config,
|
|
301
|
+
overrides,
|
|
302
|
+
):
|
|
72
303
|
"""Run training workflow with Hydra config overrides.
|
|
73
304
|
|
|
305
|
+
Automatically detects multi-GPU setups and handles run_name synchronization
|
|
306
|
+
by spawning training in a subprocess with a pre-generated config.
|
|
307
|
+
|
|
74
308
|
Examples:
|
|
75
|
-
sleap-nn train
|
|
76
|
-
sleap-nn train
|
|
77
|
-
sleap-nn train
|
|
309
|
+
sleap-nn train path/to/config.yaml
|
|
310
|
+
sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
|
|
311
|
+
sleap-nn train config.yaml trainer_config.trainer_devices=4
|
|
78
312
|
"""
|
|
79
|
-
#
|
|
80
|
-
|
|
313
|
+
# Convert overrides to a mutable list
|
|
314
|
+
overrides = list(overrides)
|
|
315
|
+
|
|
316
|
+
# Check if the first positional arg is a config path (not a Hydra override)
|
|
317
|
+
config_from_positional = None
|
|
318
|
+
if overrides and is_config_path(overrides[0]):
|
|
319
|
+
config_from_positional = overrides.pop(0)
|
|
320
|
+
|
|
321
|
+
# Resolve config path with priority:
|
|
322
|
+
# 1. Positional config path (e.g., sleap-nn train config.yaml)
|
|
323
|
+
# 2. --config flag (e.g., sleap-nn train --config config.yaml)
|
|
324
|
+
# 3. Legacy --config-dir/--config-name flags
|
|
325
|
+
if config_from_positional:
|
|
326
|
+
config_dir, config_name = split_config_path(config_from_positional)
|
|
327
|
+
elif config:
|
|
328
|
+
config_dir, config_name = split_config_path(config)
|
|
329
|
+
elif config_name:
|
|
330
|
+
config_dir = Path(config_dir).resolve().as_posix()
|
|
331
|
+
else:
|
|
332
|
+
# No config provided - show help
|
|
81
333
|
show_training_help()
|
|
82
334
|
return
|
|
83
335
|
|
|
84
|
-
#
|
|
85
|
-
#
|
|
86
|
-
|
|
336
|
+
# Check video path options early
|
|
337
|
+
# If --video-config is provided (from subprocess), load from file
|
|
338
|
+
if video_config:
|
|
339
|
+
video_cfg = OmegaConf.load(video_config)
|
|
340
|
+
video_paths = tuple(video_cfg.video_paths) if video_cfg.video_paths else ()
|
|
341
|
+
video_path_map = (
|
|
342
|
+
dict(video_cfg.video_path_map) if video_cfg.video_path_map else None
|
|
343
|
+
)
|
|
344
|
+
prefix_map = dict(video_cfg.prefix_map) if video_cfg.prefix_map else None
|
|
345
|
+
|
|
346
|
+
has_video_paths = len(video_paths) > 0
|
|
347
|
+
has_video_path_map = video_path_map is not None
|
|
348
|
+
has_prefix_map = prefix_map is not None
|
|
349
|
+
options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
|
|
350
|
+
|
|
351
|
+
if options_used > 1:
|
|
352
|
+
raise click.UsageError(
|
|
353
|
+
"Cannot use multiple path replacement options. "
|
|
354
|
+
"Choose one of: --video-paths, --video-path-map, or --prefix-map."
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
# Load config to detect device count
|
|
87
358
|
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
|
|
88
|
-
|
|
89
|
-
cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
|
|
359
|
+
cfg = hydra.compose(config_name=config_name, overrides=overrides)
|
|
90
360
|
|
|
91
361
|
# Validate config
|
|
92
362
|
if not hasattr(cfg, "model_config") or not cfg.model_config:
|
|
@@ -95,9 +365,118 @@ def train(config_name, config_dir, overrides):
|
|
|
95
365
|
)
|
|
96
366
|
raise click.Abort()
|
|
97
367
|
|
|
368
|
+
num_devices = _get_num_devices_from_config(cfg)
|
|
369
|
+
|
|
370
|
+
# Check if run_name is already set (means we're a subprocess or user provided it)
|
|
371
|
+
run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
|
|
372
|
+
run_name_is_set = run_name is not None and run_name != "" and run_name != "None"
|
|
373
|
+
|
|
374
|
+
# Multi-GPU path: spawn subprocess with finalized config
|
|
375
|
+
# Only do this if run_name is NOT set (otherwise we'd loop infinitely or user set it)
|
|
376
|
+
if num_devices > 1 and not run_name_is_set:
|
|
377
|
+
logger.info(
|
|
378
|
+
f"Detected {num_devices} devices, using subprocess for run_name sync..."
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
# Load and finalize config (generate run_name, apply overrides)
|
|
382
|
+
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
|
|
383
|
+
cfg = hydra.compose(config_name=config_name, overrides=overrides)
|
|
384
|
+
cfg = _finalize_config(cfg)
|
|
385
|
+
|
|
386
|
+
# Save finalized config to temp file
|
|
387
|
+
temp_dir = tempfile.mkdtemp(prefix="sleap_nn_train_")
|
|
388
|
+
temp_config_path = Path(temp_dir) / "training_config.yaml"
|
|
389
|
+
OmegaConf.save(cfg, temp_config_path)
|
|
390
|
+
logger.info(f"Saved finalized config to: {temp_config_path}")
|
|
391
|
+
|
|
392
|
+
# Save video replacement config if needed (so subprocess doesn't need CLI args)
|
|
393
|
+
temp_video_config_path = None
|
|
394
|
+
if options_used == 1:
|
|
395
|
+
video_replacement_config = {
|
|
396
|
+
"video_paths": list(video_paths) if has_video_paths else None,
|
|
397
|
+
"video_path_map": dict(video_path_map) if has_video_path_map else None,
|
|
398
|
+
"prefix_map": dict(prefix_map) if has_prefix_map else None,
|
|
399
|
+
}
|
|
400
|
+
temp_video_config_path = Path(temp_dir) / "video_replacement.yaml"
|
|
401
|
+
OmegaConf.save(
|
|
402
|
+
OmegaConf.create(video_replacement_config), temp_video_config_path
|
|
403
|
+
)
|
|
404
|
+
logger.info(f"Saved video replacement config to: {temp_video_config_path}")
|
|
405
|
+
|
|
406
|
+
# Build subprocess command (no video args - they're in the temp file)
|
|
407
|
+
cmd = [sys.executable, "-m", "sleap_nn.cli", "train", str(temp_config_path)]
|
|
408
|
+
if temp_video_config_path:
|
|
409
|
+
cmd.extend(["--video-config", str(temp_video_config_path)])
|
|
410
|
+
|
|
411
|
+
logger.info(f"Launching subprocess: {' '.join(cmd)}")
|
|
412
|
+
|
|
413
|
+
try:
|
|
414
|
+
process = subprocess.Popen(cmd)
|
|
415
|
+
result = process.wait()
|
|
416
|
+
if result != 0:
|
|
417
|
+
logger.error(f"Training failed with exit code {result}")
|
|
418
|
+
sys.exit(result)
|
|
419
|
+
except KeyboardInterrupt:
|
|
420
|
+
logger.info("Training interrupted, terminating subprocess...")
|
|
421
|
+
process.terminate()
|
|
422
|
+
try:
|
|
423
|
+
process.wait(timeout=5)
|
|
424
|
+
except subprocess.TimeoutExpired:
|
|
425
|
+
process.kill()
|
|
426
|
+
process.wait()
|
|
427
|
+
sys.exit(1)
|
|
428
|
+
finally:
|
|
429
|
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
|
430
|
+
logger.info("Cleaned up temporary files")
|
|
431
|
+
|
|
432
|
+
return
|
|
433
|
+
|
|
434
|
+
# Single GPU (or subprocess worker): run directly
|
|
435
|
+
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
|
|
436
|
+
cfg = hydra.compose(config_name=config_name, overrides=overrides)
|
|
437
|
+
|
|
98
438
|
logger.info("Input config:")
|
|
99
439
|
logger.info("\n" + OmegaConf.to_yaml(cfg))
|
|
100
|
-
|
|
440
|
+
|
|
441
|
+
# Handle video path replacement options
|
|
442
|
+
train_labels = None
|
|
443
|
+
val_labels = None
|
|
444
|
+
|
|
445
|
+
if options_used == 1:
|
|
446
|
+
# Load train labels
|
|
447
|
+
train_labels = [
|
|
448
|
+
sio.load_slp(path) for path in cfg.data_config.train_labels_path
|
|
449
|
+
]
|
|
450
|
+
|
|
451
|
+
# Load val labels if they exist
|
|
452
|
+
if (
|
|
453
|
+
cfg.data_config.val_labels_path is not None
|
|
454
|
+
and len(cfg.data_config.val_labels_path) > 0
|
|
455
|
+
):
|
|
456
|
+
val_labels = [
|
|
457
|
+
sio.load_slp(path) for path in cfg.data_config.val_labels_path
|
|
458
|
+
]
|
|
459
|
+
|
|
460
|
+
# Build replacement arguments based on option used
|
|
461
|
+
if has_video_paths:
|
|
462
|
+
replace_kwargs = {
|
|
463
|
+
"new_filenames": [Path(p).as_posix() for p in video_paths]
|
|
464
|
+
}
|
|
465
|
+
elif has_video_path_map:
|
|
466
|
+
replace_kwargs = {"filename_map": video_path_map}
|
|
467
|
+
else: # has_prefix_map
|
|
468
|
+
replace_kwargs = {"prefix_map": prefix_map}
|
|
469
|
+
|
|
470
|
+
# Apply replacement to train labels
|
|
471
|
+
for labels in train_labels:
|
|
472
|
+
labels.replace_filenames(**replace_kwargs)
|
|
473
|
+
|
|
474
|
+
# Apply replacement to val labels if they exist
|
|
475
|
+
if val_labels:
|
|
476
|
+
for labels in val_labels:
|
|
477
|
+
labels.replace_filenames(**replace_kwargs)
|
|
478
|
+
|
|
479
|
+
run_training(config=cfg, train_labels=train_labels, val_labels=val_labels)
|
|
101
480
|
|
|
102
481
|
|
|
103
482
|
@cli.command()
|
|
@@ -209,6 +588,18 @@ def train(config_name, config_dir, overrides):
|
|
|
209
588
|
default=False,
|
|
210
589
|
help="Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling.",
|
|
211
590
|
)
|
|
591
|
+
@click.option(
|
|
592
|
+
"--exclude_user_labeled",
|
|
593
|
+
is_flag=True,
|
|
594
|
+
default=False,
|
|
595
|
+
help="Skip frames that have user-labeled instances. Useful when predicting on entire video but skipping already-labeled frames.",
|
|
596
|
+
)
|
|
597
|
+
@click.option(
|
|
598
|
+
"--only_predicted_frames",
|
|
599
|
+
is_flag=True,
|
|
600
|
+
default=False,
|
|
601
|
+
help="Only run inference on frames that already have predictions. Requires .slp input file. Useful for re-predicting with a different model.",
|
|
602
|
+
)
|
|
212
603
|
@click.option(
|
|
213
604
|
"--no_empty_frames",
|
|
214
605
|
is_flag=True,
|
|
@@ -275,14 +666,14 @@ def train(config_name, config_dir, overrides):
|
|
|
275
666
|
@click.option(
|
|
276
667
|
"--queue_maxsize",
|
|
277
668
|
type=int,
|
|
278
|
-
default=
|
|
669
|
+
default=32,
|
|
279
670
|
help="Maximum size of the frame buffer queue.",
|
|
280
671
|
)
|
|
281
672
|
@click.option(
|
|
282
673
|
"--crop_size",
|
|
283
674
|
type=int,
|
|
284
675
|
default=None,
|
|
285
|
-
help="Crop size. If not provided, the crop size from training_config.yaml is used.",
|
|
676
|
+
help="Crop size. If not provided, the crop size from training_config.yaml is used. If `input_scale` is provided, then the cropped image will be resized according to `input_scale`.",
|
|
286
677
|
)
|
|
287
678
|
@click.option(
|
|
288
679
|
"--peak_threshold",
|
|
@@ -290,6 +681,36 @@ def train(config_name, config_dir, overrides):
|
|
|
290
681
|
default=0.2,
|
|
291
682
|
help="Minimum confidence map value to consider a peak as valid.",
|
|
292
683
|
)
|
|
684
|
+
@click.option(
|
|
685
|
+
"--filter_overlapping",
|
|
686
|
+
is_flag=True,
|
|
687
|
+
default=False,
|
|
688
|
+
help=(
|
|
689
|
+
"Enable filtering of overlapping instances after inference using greedy NMS. "
|
|
690
|
+
"Applied independently of tracking. (default: False)"
|
|
691
|
+
),
|
|
692
|
+
)
|
|
693
|
+
@click.option(
|
|
694
|
+
"--filter_overlapping_method",
|
|
695
|
+
type=click.Choice(["iou", "oks"]),
|
|
696
|
+
default="iou",
|
|
697
|
+
help=(
|
|
698
|
+
"Similarity metric for filtering overlapping instances. "
|
|
699
|
+
"'iou': bounding box intersection-over-union. "
|
|
700
|
+
"'oks': Object Keypoint Similarity (pose-based). (default: iou)"
|
|
701
|
+
),
|
|
702
|
+
)
|
|
703
|
+
@click.option(
|
|
704
|
+
"--filter_overlapping_threshold",
|
|
705
|
+
type=float,
|
|
706
|
+
default=0.8,
|
|
707
|
+
help=(
|
|
708
|
+
"Similarity threshold for filtering overlapping instances. "
|
|
709
|
+
"Instances with similarity above this threshold are removed, "
|
|
710
|
+
"keeping the higher-scoring instance. "
|
|
711
|
+
"Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
|
|
712
|
+
),
|
|
713
|
+
)
|
|
293
714
|
@click.option(
|
|
294
715
|
"--integral_refinement",
|
|
295
716
|
type=str,
|
|
@@ -422,6 +843,12 @@ def train(config_name, config_dir, overrides):
|
|
|
422
843
|
default=0,
|
|
423
844
|
help="IOU to use when culling instances *after* tracking. (default: 0)",
|
|
424
845
|
)
|
|
846
|
+
@click.option(
|
|
847
|
+
"--gui",
|
|
848
|
+
is_flag=True,
|
|
849
|
+
default=False,
|
|
850
|
+
help="Output JSON progress for GUI integration instead of Rich progress bar.",
|
|
851
|
+
)
|
|
425
852
|
def track(**kwargs):
|
|
426
853
|
"""Run Inference and Tracking workflow."""
|
|
427
854
|
# Convert model_paths from tuple to list
|
|
@@ -474,5 +901,21 @@ def eval(**kwargs):
|
|
|
474
901
|
run_evaluation(**kwargs)
|
|
475
902
|
|
|
476
903
|
|
|
904
|
+
@cli.command()
def system():
    """Display system information and GPU status.

    Shows Python version, platform, PyTorch version, CUDA availability,
    driver version with compatibility check, GPU details, and package versions.
    """
    # Imported inside the command so the diagnostics module is only loaded
    # when `sleap-nn system` actually runs.
    from sleap_nn.system_info import print_system_info as _report_system

    _report_system()
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
# Attach the `export` and `predict` commands defined in sleap_nn.export.cli to
# the top-level group, in that order.
for _subcommand in (export_command, predict_command):
    cli.add_command(_subcommand)
del _subcommand
|
|
918
|
+
|
|
919
|
+
|
|
477
920
|
# Entry point for `python -m sleap_nn.cli`. The multi-GPU path of `train`
# relaunches itself this way in a subprocess.
if __name__ == "__main__":
    cli()
|