sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +2 -4
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -10
- sleap_nn/config/trainer_config.py +0 -76
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +39 -411
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -74
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +184 -211
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED

@@ -41,16 +41,14 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
-# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
-    colorize=False,
+    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
 )
 
-__version__ = "0.1.0"
+__version__ = "0.1.0a0"
 
 # Public API
 from sleap_nn.evaluation import load_metrics
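The only behavioral change in this hunk is the sink format. A minimal standalone sketch of the new handler setup, where the `_safe_print` sink and `_should_log` filter below are simplified stand-ins for the package's own callables:

    import sys
    from loguru import logger

    def _safe_print(msg):
        # Stand-in for sleap_nn's sink: write the formatted record to stdout.
        sys.stdout.write(msg)

    def _should_log(record):
        # Stand-in filter: accept every record.
        return True

    logger.remove()  # drop loguru's default stderr handler
    logger.add(
        _safe_print,
        level="DEBUG",
        filter=_should_log,
        # 0.1.0a0 format: adds level and call-site context to each record.
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
    )
    logger.info("ready")
    # e.g. 2025-06-01 12:00:00 | INFO | __main__:<module>:21 | ready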
sleap_nn/architectures/convnext.py
CHANGED

@@ -281,10 +281,6 @@ class ConvNextWrapper(nn.Module):
         # Keep the block output filters the same
         x_in_shape = int(self.arch["channels"][-1] * filters_rate)
 
-        # Encoder channels for skip connections (reversed to match decoder order)
-        # The forward pass uses enc_output[::2][::-1] for skip features
-        encoder_channels = self.arch["channels"][::-1]
-
         self.dec = Decoder(
             x_in_shape=x_in_shape,
             current_stride=self.current_stride,
@@ -297,7 +293,6 @@ class ConvNextWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=self.output_stride,
             up_interpolate=up_interpolate,
-            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
sleap_nn/architectures/encoder_decoder.py
CHANGED

@@ -25,7 +25,7 @@ classes.
 See the `EncoderDecoder` base class for requirements for creating new architectures.
 """
 
-from typing import List, Optional, Text, Tuple, Union
+from typing import List, Text, Tuple, Union
 from collections import OrderedDict
 import torch
 from torch import nn
@@ -391,18 +391,10 @@ class SimpleUpsamplingBlock(nn.Module):
         transpose_convs_activation: Text = "relu",
         feat_concat: bool = True,
         prefix: Text = "",
-        skip_channels: Optional[int] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
 
-        # Determine skip connection channels
-        # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
-        # This allows ConvNext/SwinT to specify actual encoder channels
-        self.skip_channels = (
-            skip_channels if skip_channels is not None else refine_convs_filters
-        )
-
         self.x_in_shape = x_in_shape
         self.current_stride = current_stride
         self.upsampling_stride = upsampling_stride
@@ -477,13 +469,13 @@ class SimpleUpsamplingBlock(nn.Module):
                 first_conv_in_channels = refine_convs_filters
             else:
                 if self.up_interpolate:
-                    # With interpolation, input is x_in_shape +
-                    #
-                    first_conv_in_channels = x_in_shape + self.skip_channels
+                    # With interpolation, input is x_in_shape + feature channels
+                    # The feature channels are the same as x_in_shape since they come from the same level
+                    first_conv_in_channels = x_in_shape + refine_convs_filters
                 else:
-                    # With transpose conv, input is transpose_conv_output +
+                    # With transpose conv, input is transpose_conv_output + feature channels
                     first_conv_in_channels = (
-                        self.skip_channels + transpose_convs_filters
+                        refine_convs_filters + transpose_convs_filters
                     )
         else:
             if not self.feat_concat:
@@ -590,7 +582,6 @@ class Decoder(nn.Module):
         block_contraction: bool = False,
         up_interpolate: bool = True,
         prefix: str = "dec",
-        encoder_channels: Optional[List[int]] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
@@ -607,7 +598,6 @@ class Decoder(nn.Module):
         self.block_contraction = block_contraction
         self.prefix = prefix
         self.stride_to_filters = {}
-        self.encoder_channels = encoder_channels
 
         self.current_strides = []
         self.residuals = 0
@@ -634,13 +624,6 @@ class Decoder(nn.Module):
 
             next_stride = current_stride // 2
 
-            # Determine skip channels for this decoder block
-            # If encoder_channels provided, use actual encoder channels
-            # Otherwise fall back to computed filters (for UNet compatibility)
-            skip_channels = None
-            if encoder_channels is not None and block < len(encoder_channels):
-                skip_channels = encoder_channels[block]
-
            if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
                # This accounts for the case where we dont have any more down block features to concatenate with.
                # In this case, add a simple upsampling block with a conv layer and with no concatenation
@@ -659,7 +642,6 @@ class Decoder(nn.Module):
                         transpose_convs_batch_norm=False,
                         feat_concat=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
-                        skip_channels=skip_channels,
                     )
                 )
             else:
@@ -677,7 +659,6 @@ class Decoder(nn.Module):
                         transpose_convs_filters=block_filters_out,
                         transpose_convs_batch_norm=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
-                        skip_channels=skip_channels,
                     )
                 )
 
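The upshot of dropping `skip_channels` is in the channel count fed to the first refinement convolution. A toy restatement of the new rule under both upsampling modes, where the function and argument names are illustrative rather than the module's API:

    def first_conv_in_channels(
        x_in_shape: int,
        refine_convs_filters: int,
        transpose_convs_filters: int,
        up_interpolate: bool,
    ) -> int:
        # New rule: skip features are assumed to carry the same channel count
        # as the decoder path at that level (refine_convs_filters), rather
        # than being passed in explicitly via skip_channels.
        if up_interpolate:
            # Interpolation preserves x_in_shape channels before concatenation.
            return x_in_shape + refine_convs_filters
        # A transposed conv emits transpose_convs_filters channels instead.
        return refine_convs_filters + transpose_convs_filters

    assert first_conv_in_channels(256, 256, 128, up_interpolate=True) == 512
    assert first_conv_in_channels(256, 256, 128, up_interpolate=False) == 384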
sleap_nn/architectures/swint.py
CHANGED

@@ -309,13 +309,6 @@ class SwinTWrapper(nn.Module):
             self.stem_patch_stride * (2**3) * 2
         )  # stem_stride * down_blocks_stride * final_max_pool_stride
 
-        # Encoder channels for skip connections (reversed to match decoder order)
-        # SwinT channels: embed * 2^i for each stage i, then reversed
-        num_stages = len(self.arch["depths"])
-        encoder_channels = [
-            self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
-        ]
-
         self.dec = Decoder(
             x_in_shape=block_filters,
             current_stride=self.current_stride,
@@ -328,7 +321,6 @@ class SwinTWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=output_stride,
             up_interpolate=up_interpolate,
-            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
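For reference, the deleted SwinT skip-channel bookkeeping just reversed the per-stage widths so the deepest stage comes first. Evaluated standalone with typical Swin-T hyperparameters (embed=96, four stages; values assumed for illustration):

    # Deleted computation, run with assumed Swin-T values.
    embed = 96               # illustrative embedding dimension
    depths = [2, 2, 6, 2]    # illustrative per-stage depths

    num_stages = len(depths)
    encoder_channels = [embed * (2 ** (num_stages - 1 - i)) for i in range(num_stages)]
    print(encoder_channels)  # [768, 384, 192, 96] -- deepest stage first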
sleap_nn/cli.py
CHANGED

@@ -1,55 +1,17 @@
-"""Unified CLI for SLEAP-NN using rich-click."""
+"""Unified CLI for SLEAP-NN using Click."""
 
-import subprocess
-import tempfile
-import shutil
-from datetime import datetime
-
-import rich_click as click
-from click import Command
+import click
 from loguru import logger
 from pathlib import Path
 from omegaconf import OmegaConf, DictConfig
 import sleap_io as sio
 from sleap_nn.predict import run_inference, frame_list
 from sleap_nn.evaluation import run_evaluation
-from sleap_nn.export.cli import export as export_command
-from sleap_nn.export.cli import predict as predict_command
 from sleap_nn.train import run_training
 from sleap_nn import __version__
-from sleap_nn.config.utils import get_model_type_from_cfg
 import hydra
 import sys
-
-# Rich-click configuration for styled help
-click.rich_click.TEXT_MARKUP = "markdown"
-click.rich_click.SHOW_ARGUMENTS = True
-click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
-click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
-click.rich_click.ERRORS_EPILOGUE = (
-    "Try 'sleap-nn [COMMAND] --help' for more information."
-)
-
-
-def is_config_path(arg: str) -> bool:
-    """Check if an argument looks like a config file path.
-
-    Returns True if the arg ends with .yaml or .yml.
-    """
-    return arg.endswith(".yaml") or arg.endswith(".yml")
-
-
-def split_config_path(config_path: str) -> tuple:
-    """Split a full config path into (config_dir, config_name).
-
-    Args:
-        config_path: Full path to a config file.
-
-    Returns:
-        Tuple of (config_dir, config_name) where config_dir is an absolute path.
-    """
-    path = Path(config_path).resolve()
-    return path.parent.as_posix(), path.name
+from click import Command
 
 
 def print_version(ctx, param, value):
@@ -93,165 +55,47 @@ def cli():
 
     Use subcommands to run different workflows:
 
-        train - Run training workflow
-        track - Run inference/tracking workflow
+        train - Run training workflow
+        track - Run inference/ tracking workflow
         eval - Run evaluation workflow
         system - Display system information and GPU status
     """
     pass
 
 
-def _get_num_devices_from_config(cfg: DictConfig) -> int:
-    """Determine the number of devices from config.
-
-    User preferences take precedence over auto-detection:
-    - trainer_device_indices=[0] → 1 device (user choice)
-    - trainer_devices=1 → 1 device (user choice)
-    - trainer_devices="auto" or unset → auto-detect available GPUs
-
-    Returns:
-        Number of devices to use for training.
-    """
-    import torch
-
-    # User preference: explicit device indices (highest priority)
-    device_indices = OmegaConf.select(
-        cfg, "trainer_config.trainer_device_indices", default=None
-    )
-    if device_indices is not None and len(device_indices) > 0:
-        return len(device_indices)
-
-    # User preference: explicit device count
-    devices = OmegaConf.select(cfg, "trainer_config.trainer_devices", default="auto")
-
-    if isinstance(devices, int):
-        return devices
-
-    # Auto-detect only when user hasn't specified (devices is "auto" or None)
-    if devices in ("auto", None, "None"):
-        accelerator = OmegaConf.select(
-            cfg, "trainer_config.trainer_accelerator", default="auto"
-        )
-
-        if accelerator == "cpu":
-            return 1
-        elif torch.cuda.is_available():
-            return torch.cuda.device_count()
-        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return 1
-        else:
-            return 1
-
-    return 1
-
-
-def _finalize_config(cfg: DictConfig) -> DictConfig:
-    """Finalize configuration by generating run_name if not provided.
-
-    This runs ONCE before subprocess, ensuring all workers get the same run_name.
-    """
-    # Resolve ckpt_dir first
-    ckpt_dir = OmegaConf.select(cfg, "trainer_config.ckpt_dir", default=None)
-    if ckpt_dir is None or ckpt_dir == "" or ckpt_dir == "None":
-        cfg.trainer_config.ckpt_dir = "."
-
-    # Generate run_name if not provided
-    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
-    if run_name is None or run_name == "" or run_name == "None":
-        # Get model type from config
-        model_type = get_model_type_from_cfg(cfg)
-
-        # Count frames from labels
-        train_paths = cfg.data_config.train_labels_path
-        val_paths = OmegaConf.select(cfg, "data_config.val_labels_path", default=None)
-
-        train_count = sum(len(sio.load_slp(p)) for p in train_paths)
-        val_count = 0
-        if val_paths:
-            val_count = sum(len(sio.load_slp(p)) for p in val_paths)
-
-        # Generate full run_name with timestamp
-        timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
-        run_name = f"{timestamp}.{model_type}.n={train_count + val_count}"
-        cfg.trainer_config.run_name = run_name
-
-        logger.info(f"Generated run_name: {run_name}")
-
-    return cfg
-
-
 def show_training_help():
-    """Display training help information
-
-    [...]
-
-    **Start a new training run:**
-    ```bash
-    sleap-nn train path/to/config.yaml
-    sleap-nn train --config path/to/config.yaml
-    ```
-
-    **With overrides:**
-    ```bash
-    sleap-nn train config.yaml trainer_config.max_epochs=100
-    ```
-
-    **Resume training:**
-    ```bash
-    sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
-    ```
-
-    **Legacy usage (still supported):**
-    ```bash
-    sleap-nn train --config-dir /path/to/dir --config-name myrun
-    ```
-
-    ## Tips
-
-    - Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
-    - For Hydra flags and completion, use `--hydra-help`
-    - Config documentation: https://nn.sleap.ai/config/
+    """Display training help information."""
+    help_text = """
+    sleap-nn train — Train SLEAP models from a config YAML file.
+
+    Usage:
+        sleap-nn train --config-dir <dir> --config-name <name> [overrides]
+
+    Common overrides:
+        trainer_config.max_epochs=100
+        trainer_config.batch_size=32
+
+    Examples:
+        Start new run:
+            sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun
+        Resume 20 more epochs:
+            sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun \\
+                trainer_config.resume_ckpt_path=<path/to/ckpt> \\
+                trainer_config.max_epochs=20
+
+    Tips:
+        - Use -m/--multirun for sweeps; outputs go under hydra.sweep.dir.
+        - For Hydra flags and completion, use --hydra-help.
+
+    For a detailed list of all available config options, please refer to https://nn.sleap.ai/config/.
     """
-
-    Panel(
-        Markdown(help_md),
-        title="[bold cyan]sleap-nn train[/bold cyan]",
-        subtitle="Train SLEAP models from a config YAML file",
-        border_style="cyan",
-    )
-)
+    click.echo(help_text)
 
 
 @cli.command(cls=TrainCommand)
+@click.option("--config-name", "-c", type=str, help="Configuration file name")
 @click.option(
-    "--config",
-    type=str,
-    help="Path to configuration file (e.g., path/to/config.yaml)",
-)
-@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
-@click.option(
-    "--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
+    "--config-dir", "-d", type=str, default=".", help="Configuration directory path"
 )
 @click.option(
     "--video-paths",
@@ -283,80 +127,26 @@ sleap-nn train --config-dir /path/to/dir --config-name myrun
     "Can be specified multiple times. "
     'Example: --prefix-map "/old/server/path" "/new/local/path"',
 )
-@click.option(
-    "--video-config",
-    type=str,
-    hidden=True,
-    help="Path to video replacement config YAML (internal use for multi-GPU).",
-)
 @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
-def train(
-    config,
-    config_name,
-    config_dir,
-    video_paths,
-    video_path_map,
-    prefix_map,
-    video_config,
-    overrides,
-):
+def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
     """Run training workflow with Hydra config overrides.
 
-    Automatically detects multi-GPU setups and handles run_name synchronization
-    by spawning training in a subprocess with a pre-generated config.
-
     Examples:
-        sleap-nn train path/to/
-        sleap-nn train
-        sleap-nn train
+        sleap-nn train --config-name myconfig --config-dir /path/to/config_dir/
+        sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
+        sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
     """
-    #
-
-
-    # Check if the first positional arg is a config path (not a Hydra override)
-    config_from_positional = None
-    if overrides and is_config_path(overrides[0]):
-        config_from_positional = overrides.pop(0)
-
-    # Resolve config path with priority:
-    # 1. Positional config path (e.g., sleap-nn train config.yaml)
-    # 2. --config flag (e.g., sleap-nn train --config config.yaml)
-    # 3. Legacy --config-dir/--config-name flags
-    if config_from_positional:
-        config_dir, config_name = split_config_path(config_from_positional)
-    elif config:
-        config_dir, config_name = split_config_path(config)
-    elif config_name:
-        config_dir = Path(config_dir).resolve().as_posix()
-    else:
-        # No config provided - show help
+    # Show help if no config name provided
+    if not config_name:
         show_training_help()
         return
 
-    #
-    #
-
-    video_cfg = OmegaConf.load(video_config)
-    video_paths = tuple(video_cfg.video_paths) if video_cfg.video_paths else ()
-    video_path_map = (
-        dict(video_cfg.video_path_map) if video_cfg.video_path_map else None
-    )
-    prefix_map = dict(video_cfg.prefix_map) if video_cfg.prefix_map else None
-
-    has_video_paths = len(video_paths) > 0
-    has_video_path_map = video_path_map is not None
-    has_prefix_map = prefix_map is not None
-    options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
-
-    if options_used > 1:
-        raise click.UsageError(
-            "Cannot use multiple path replacement options. "
-            "Choose one of: --video-paths, --video-path-map, or --prefix-map."
-        )
-
-    # Load config to detect device count
+    # Initialize Hydra manually
+    # resolve the path to the config directory (hydra expects absolute path)
+    config_dir = Path(config_dir).resolve().as_posix()
     with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-        cfg = hydra.compose(config_name=config_name, overrides=overrides)
+        # Compose config with overrides
+        cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
 
         # Validate config
         if not hasattr(cfg, "model_config") or not cfg.model_config:
@@ -365,76 +155,6 @@ def train(
             )
             raise click.Abort()
 
-    num_devices = _get_num_devices_from_config(cfg)
-
-    # Check if run_name is already set (means we're a subprocess or user provided it)
-    run_name = OmegaConf.select(cfg, "trainer_config.run_name", default=None)
-    run_name_is_set = run_name is not None and run_name != "" and run_name != "None"
-
-    # Multi-GPU path: spawn subprocess with finalized config
-    # Only do this if run_name is NOT set (otherwise we'd loop infinitely or user set it)
-    if num_devices > 1 and not run_name_is_set:
-        logger.info(
-            f"Detected {num_devices} devices, using subprocess for run_name sync..."
-        )
-
-        # Load and finalize config (generate run_name, apply overrides)
-        with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-            cfg = hydra.compose(config_name=config_name, overrides=overrides)
-        cfg = _finalize_config(cfg)
-
-        # Save finalized config to temp file
-        temp_dir = tempfile.mkdtemp(prefix="sleap_nn_train_")
-        temp_config_path = Path(temp_dir) / "training_config.yaml"
-        OmegaConf.save(cfg, temp_config_path)
-        logger.info(f"Saved finalized config to: {temp_config_path}")
-
-        # Save video replacement config if needed (so subprocess doesn't need CLI args)
-        temp_video_config_path = None
-        if options_used == 1:
-            video_replacement_config = {
-                "video_paths": list(video_paths) if has_video_paths else None,
-                "video_path_map": dict(video_path_map) if has_video_path_map else None,
-                "prefix_map": dict(prefix_map) if has_prefix_map else None,
-            }
-            temp_video_config_path = Path(temp_dir) / "video_replacement.yaml"
-            OmegaConf.save(
-                OmegaConf.create(video_replacement_config), temp_video_config_path
-            )
-            logger.info(f"Saved video replacement config to: {temp_video_config_path}")
-
-        # Build subprocess command (no video args - they're in the temp file)
-        cmd = [sys.executable, "-m", "sleap_nn.cli", "train", str(temp_config_path)]
-        if temp_video_config_path:
-            cmd.extend(["--video-config", str(temp_video_config_path)])
-
-        logger.info(f"Launching subprocess: {' '.join(cmd)}")
-
-        try:
-            process = subprocess.Popen(cmd)
-            result = process.wait()
-            if result != 0:
-                logger.error(f"Training failed with exit code {result}")
-                sys.exit(result)
-        except KeyboardInterrupt:
-            logger.info("Training interrupted, terminating subprocess...")
-            process.terminate()
-            try:
-                process.wait(timeout=5)
-            except subprocess.TimeoutExpired:
-                process.kill()
-                process.wait()
-            sys.exit(1)
-        finally:
-            shutil.rmtree(temp_dir, ignore_errors=True)
-            logger.info("Cleaned up temporary files")
-
-        return
-
-    # Single GPU (or subprocess worker): run directly
-    with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
-        cfg = hydra.compose(config_name=config_name, overrides=overrides)
-
         logger.info("Input config:")
         logger.info("\n" + OmegaConf.to_yaml(cfg))
 
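Both versions ultimately resolve the training config through Hydra's compose API, as the new code above does. A minimal sketch of that pattern, where the config directory, config name, and override string are placeholders:

    from pathlib import Path
    import hydra
    from omegaconf import OmegaConf

    # Hydra expects an absolute config directory.
    config_dir = Path("path/to/config_dir").resolve().as_posix()  # placeholder

    with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
        cfg = hydra.compose(
            config_name="myconfig",  # file name without the .yaml suffix
            overrides=["trainer_config.max_epochs=100"],  # placeholder override
        )
    print(OmegaConf.to_yaml(cfg))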
@@ -442,6 +162,19 @@ def train(
         train_labels = None
         val_labels = None
 
+        # Check that only one replacement option is used
+        # video_paths is a tuple (empty if not used), others are None or dict
+        has_video_paths = len(video_paths) > 0
+        has_video_path_map = video_path_map is not None
+        has_prefix_map = prefix_map is not None
+        options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
+
+        if options_used > 1:
+            raise click.UsageError(
+                "Cannot use multiple path replacement options. "
+                "Choose one of: --video-paths, --video-path-map, or --prefix-map."
+            )
+
         if options_used == 1:
             # Load train labels
             train_labels = [
@@ -459,12 +192,15 @@ def train(
 
         # Build replacement arguments based on option used
         if has_video_paths:
+            # List of paths (order must match videos in labels file)
             replace_kwargs = {
                 "new_filenames": [Path(p).as_posix() for p in video_paths]
             }
         elif has_video_path_map:
+            # Dictionary mapping old filenames to new filenames
             replace_kwargs = {"filename_map": video_path_map}
         else:  # has_prefix_map
+            # Dictionary mapping old prefixes to new prefixes
             replace_kwargs = {"prefix_map": prefix_map}
 
         # Apply replacement to train labels
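The three `replace_kwargs` variants correspond to the keyword arguments of `sleap_io`'s `Labels.replace_filenames`. A sketch of how the composed kwargs get applied, with placeholder paths and the call signature assumed from sleap-io's video-remapping API:

    import sleap_io as sio

    labels = sio.load_slp("path/to/labels.slp")  # placeholder path

    # Exactly one of these is built by the CLI, then splatted into the call:
    replace_kwargs = {"prefix_map": {"/old/server/path": "/new/local/path"}}
    # replace_kwargs = {"filename_map": {"/old/a.mp4": "/new/a.mp4"}}
    # replace_kwargs = {"new_filenames": ["/new/a.mp4", "/new/b.mp4"]}

    labels.replace_filenames(**replace_kwargs)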
@@ -666,7 +402,7 @@ def train(
 @click.option(
     "--queue_maxsize",
     type=int,
-    default=
+    default=8,
     help="Maximum size of the frame buffer queue.",
 )
 @click.option(
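`--queue_maxsize` bounds the frame buffer between video reading and inference, so a slow consumer throttles the reader instead of letting frames pile up in memory. A generic producer/consumer sketch of that behavior (not the package's reader):

    import queue
    import threading

    buf = queue.Queue(maxsize=8)  # same bound as the new default

    def reader():
        for frame_idx in range(100):
            buf.put(frame_idx)  # blocks once 8 frames are already buffered
        buf.put(None)           # sentinel: end of stream

    threading.Thread(target=reader, daemon=True).start()
    while (item := buf.get()) is not None:
        pass  # run inference on the buffered frame here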
@@ -681,36 +417,6 @@ def train(
     default=0.2,
     help="Minimum confidence map value to consider a peak as valid.",
 )
-@click.option(
-    "--filter_overlapping",
-    is_flag=True,
-    default=False,
-    help=(
-        "Enable filtering of overlapping instances after inference using greedy NMS. "
-        "Applied independently of tracking. (default: False)"
-    ),
-)
-@click.option(
-    "--filter_overlapping_method",
-    type=click.Choice(["iou", "oks"]),
-    default="iou",
-    help=(
-        "Similarity metric for filtering overlapping instances. "
-        "'iou': bounding box intersection-over-union. "
-        "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
-    ),
-)
-@click.option(
-    "--filter_overlapping_threshold",
-    type=float,
-    default=0.8,
-    help=(
-        "Similarity threshold for filtering overlapping instances. "
-        "Instances with similarity above this threshold are removed, "
-        "keeping the higher-scoring instance. "
-        "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
-    ),
-)
 @click.option(
     "--integral_refinement",
     type=str,
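The removed options described a greedy NMS pass: sort instances by score, then drop any whose similarity to an already-kept instance exceeds the threshold. A generic sketch of that procedure using bounding-box IoU, illustrative only and not the package's implementation:

    def greedy_nms(instances, threshold=0.8):
        """Keep higher-scoring instances; drop near-duplicates of kept ones."""

        def iou(a, b):
            # a and b are (x0, y0, x1, y1) boxes.
            ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
            ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
            inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
            if inter == 0.0:
                return 0.0
            area_a = (a[2] - a[0]) * (a[3] - a[1])
            area_b = (b[2] - b[0]) * (b[3] - b[1])
            return inter / (area_a + area_b - inter)

        kept = []
        for inst in sorted(instances, key=lambda i: i["score"], reverse=True):
            if all(iou(inst["bbox"], k["bbox"]) <= threshold for k in kept):
                kept.append(inst)
        return kept

    dets = [
        {"score": 0.9, "bbox": (0, 0, 10, 10)},
        {"score": 0.7, "bbox": (1, 1, 10, 10)},   # IoU 0.81 with the first box
        {"score": 0.8, "bbox": (20, 20, 30, 30)},
    ]
    print([d["score"] for d in greedy_nms(dets, threshold=0.5)])  # [0.9, 0.8]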
@@ -843,12 +549,6 @@ def train(
     default=0,
     help="IOU to use when culling instances *after* tracking. (default: 0)",
 )
-@click.option(
-    "--gui",
-    is_flag=True,
-    default=False,
-    help="Output JSON progress for GUI integration instead of Rich progress bar.",
-)
 def track(**kwargs):
     """Run Inference and Tracking workflow."""
     # Convert model_paths from tuple to list
@@ -913,9 +613,5 @@ def system():
     print_system_info()
 
 
-cli.add_command(export_command)
-cli.add_command(predict_command)
-
-
 if __name__ == "__main__":
     cli()