sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +168 -39
  6. sleap_nn/evaluation.py +8 -0
  7. sleap_nn/export/__init__.py +21 -0
  8. sleap_nn/export/cli.py +1778 -0
  9. sleap_nn/export/exporters/__init__.py +51 -0
  10. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  11. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  12. sleap_nn/export/metadata.py +225 -0
  13. sleap_nn/export/predictors/__init__.py +63 -0
  14. sleap_nn/export/predictors/base.py +22 -0
  15. sleap_nn/export/predictors/onnx.py +154 -0
  16. sleap_nn/export/predictors/tensorrt.py +312 -0
  17. sleap_nn/export/utils.py +307 -0
  18. sleap_nn/export/wrappers/__init__.py +25 -0
  19. sleap_nn/export/wrappers/base.py +96 -0
  20. sleap_nn/export/wrappers/bottomup.py +243 -0
  21. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  22. sleap_nn/export/wrappers/centered_instance.py +56 -0
  23. sleap_nn/export/wrappers/centroid.py +58 -0
  24. sleap_nn/export/wrappers/single_instance.py +83 -0
  25. sleap_nn/export/wrappers/topdown.py +180 -0
  26. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  27. sleap_nn/inference/peak_finding.py +47 -17
  28. sleap_nn/inference/postprocessing.py +284 -0
  29. sleap_nn/inference/predictors.py +213 -106
  30. sleap_nn/predict.py +35 -7
  31. sleap_nn/train.py +64 -0
  32. sleap_nn/training/callbacks.py +69 -22
  33. sleap_nn/training/lightning_modules.py +332 -30
  34. sleap_nn/training/model_trainer.py +67 -67
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
  36. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
  37. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
  38. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
  39. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
  40. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py CHANGED
@@ -50,7 +50,7 @@ logger.add(
50
50
  colorize=False,
51
51
  )
52
52
 
53
- __version__ = "0.1.0a2"
53
+ __version__ = "0.1.0a4"
54
54
 
55
55
  # Public API
56
56
  from sleap_nn.evaluation import load_metrics
@@ -281,6 +281,10 @@ class ConvNextWrapper(nn.Module):
281
281
  # Keep the block output filters the same
282
282
  x_in_shape = int(self.arch["channels"][-1] * filters_rate)
283
283
 
284
+ # Encoder channels for skip connections (reversed to match decoder order)
285
+ # The forward pass uses enc_output[::2][::-1] for skip features
286
+ encoder_channels = self.arch["channels"][::-1]
287
+
284
288
  self.dec = Decoder(
285
289
  x_in_shape=x_in_shape,
286
290
  current_stride=self.current_stride,
@@ -293,6 +297,7 @@ class ConvNextWrapper(nn.Module):
293
297
  block_contraction=self.block_contraction,
294
298
  output_stride=self.output_stride,
295
299
  up_interpolate=up_interpolate,
300
+ encoder_channels=encoder_channels,
296
301
  )
297
302
 
298
303
  if len(self.dec.decoder_stack):
@@ -25,7 +25,7 @@ classes.
25
25
  See the `EncoderDecoder` base class for requirements for creating new architectures.
26
26
  """
27
27
 
28
- from typing import List, Text, Tuple, Union
28
+ from typing import List, Optional, Text, Tuple, Union
29
29
  from collections import OrderedDict
30
30
  import torch
31
31
  from torch import nn
@@ -391,10 +391,18 @@ class SimpleUpsamplingBlock(nn.Module):
391
391
  transpose_convs_activation: Text = "relu",
392
392
  feat_concat: bool = True,
393
393
  prefix: Text = "",
394
+ skip_channels: Optional[int] = None,
394
395
  ) -> None:
395
396
  """Initialize the class."""
396
397
  super().__init__()
397
398
 
399
+ # Determine skip connection channels
400
+ # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
401
+ # This allows ConvNext/SwinT to specify actual encoder channels
402
+ self.skip_channels = (
403
+ skip_channels if skip_channels is not None else refine_convs_filters
404
+ )
405
+
398
406
  self.x_in_shape = x_in_shape
399
407
  self.current_stride = current_stride
400
408
  self.upsampling_stride = upsampling_stride
@@ -469,13 +477,13 @@ class SimpleUpsamplingBlock(nn.Module):
469
477
  first_conv_in_channels = refine_convs_filters
470
478
  else:
471
479
  if self.up_interpolate:
472
- # With interpolation, input is x_in_shape + feature channels
473
- # The feature channels are the same as x_in_shape since they come from the same level
474
- first_conv_in_channels = x_in_shape + refine_convs_filters
480
+ # With interpolation, input is x_in_shape + skip_channels
481
+ # skip_channels may differ from refine_convs_filters for ConvNext/SwinT
482
+ first_conv_in_channels = x_in_shape + self.skip_channels
475
483
  else:
476
- # With transpose conv, input is transpose_conv_output + feature channels
484
+ # With transpose conv, input is transpose_conv_output + skip_channels
477
485
  first_conv_in_channels = (
478
- refine_convs_filters + transpose_convs_filters
486
+ self.skip_channels + transpose_convs_filters
479
487
  )
480
488
  else:
481
489
  if not self.feat_concat:
@@ -582,6 +590,7 @@ class Decoder(nn.Module):
582
590
  block_contraction: bool = False,
583
591
  up_interpolate: bool = True,
584
592
  prefix: str = "dec",
593
+ encoder_channels: Optional[List[int]] = None,
585
594
  ) -> None:
586
595
  """Initialize the class."""
587
596
  super().__init__()
@@ -598,6 +607,7 @@ class Decoder(nn.Module):
598
607
  self.block_contraction = block_contraction
599
608
  self.prefix = prefix
600
609
  self.stride_to_filters = {}
610
+ self.encoder_channels = encoder_channels
601
611
 
602
612
  self.current_strides = []
603
613
  self.residuals = 0
@@ -624,6 +634,13 @@ class Decoder(nn.Module):
624
634
 
625
635
  next_stride = current_stride // 2
626
636
 
637
+ # Determine skip channels for this decoder block
638
+ # If encoder_channels provided, use actual encoder channels
639
+ # Otherwise fall back to computed filters (for UNet compatibility)
640
+ skip_channels = None
641
+ if encoder_channels is not None and block < len(encoder_channels):
642
+ skip_channels = encoder_channels[block]
643
+
627
644
  if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
628
645
  # This accounts for the case where we dont have any more down block features to concatenate with.
629
646
  # In this case, add a simple upsampling block with a conv layer and with no concatenation
@@ -642,6 +659,7 @@ class Decoder(nn.Module):
642
659
  transpose_convs_batch_norm=False,
643
660
  feat_concat=False,
644
661
  prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
662
+ skip_channels=skip_channels,
645
663
  )
646
664
  )
647
665
  else:
@@ -659,6 +677,7 @@ class Decoder(nn.Module):
659
677
  transpose_convs_filters=block_filters_out,
660
678
  transpose_convs_batch_norm=False,
661
679
  prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
680
+ skip_channels=skip_channels,
662
681
  )
663
682
  )
664
683
 
@@ -309,6 +309,13 @@ class SwinTWrapper(nn.Module):
309
309
  self.stem_patch_stride * (2**3) * 2
310
310
  ) # stem_stride * down_blocks_stride * final_max_pool_stride
311
311
 
312
+ # Encoder channels for skip connections (reversed to match decoder order)
313
+ # SwinT channels: embed * 2^i for each stage i, then reversed
314
+ num_stages = len(self.arch["depths"])
315
+ encoder_channels = [
316
+ self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
317
+ ]
318
+
312
319
  self.dec = Decoder(
313
320
  x_in_shape=block_filters,
314
321
  current_stride=self.current_stride,
@@ -321,6 +328,7 @@ class SwinTWrapper(nn.Module):
321
328
  block_contraction=self.block_contraction,
322
329
  output_stride=output_stride,
323
330
  up_interpolate=up_interpolate,
331
+ encoder_channels=encoder_channels,
324
332
  )
325
333
 
326
334
  if len(self.dec.decoder_stack):
sleap_nn/cli.py CHANGED
@@ -1,17 +1,49 @@
1
- """Unified CLI for SLEAP-NN using Click."""
1
+ """Unified CLI for SLEAP-NN using rich-click for styled output."""
2
2
 
3
- import click
3
+ import rich_click as click
4
+ from click import Command
4
5
  from loguru import logger
5
6
  from pathlib import Path
6
7
  from omegaconf import OmegaConf, DictConfig
7
8
  import sleap_io as sio
8
9
  from sleap_nn.predict import run_inference, frame_list
9
10
  from sleap_nn.evaluation import run_evaluation
11
+ from sleap_nn.export.cli import export as export_command
12
+ from sleap_nn.export.cli import predict as predict_command
10
13
  from sleap_nn.train import run_training
11
14
  from sleap_nn import __version__
12
15
  import hydra
13
16
  import sys
14
- from click import Command
17
+
18
+ # Rich-click configuration for styled help
19
+ click.rich_click.TEXT_MARKUP = "markdown"
20
+ click.rich_click.SHOW_ARGUMENTS = True
21
+ click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
22
+ click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
23
+ click.rich_click.ERRORS_EPILOGUE = (
24
+ "Try 'sleap-nn [COMMAND] --help' for more information."
25
+ )
26
+
27
+
28
+ def is_config_path(arg: str) -> bool:
29
+ """Check if an argument looks like a config file path.
30
+
31
+ Returns True if the arg ends with .yaml or .yml.
32
+ """
33
+ return arg.endswith(".yaml") or arg.endswith(".yml")
34
+
35
+
36
+ def split_config_path(config_path: str) -> tuple:
37
+ """Split a full config path into (config_dir, config_name).
38
+
39
+ Args:
40
+ config_path: Full path to a config file.
41
+
42
+ Returns:
43
+ Tuple of (config_dir, config_name) where config_dir is an absolute path.
44
+ """
45
+ path = Path(config_path).resolve()
46
+ return path.parent.as_posix(), path.name
15
47
 
16
48
 
17
49
  def print_version(ctx, param, value):
@@ -64,38 +96,77 @@ def cli():
64
96
 
65
97
 
66
98
  def show_training_help():
67
- """Display training help information."""
68
- help_text = """
69
- sleap-nn train — Train SLEAP models from a config YAML file.
70
-
71
- Usage:
72
- sleap-nn train --config-dir <dir> --config-name <name> [overrides]
73
-
74
- Common overrides:
75
- trainer_config.max_epochs=100
76
- trainer_config.batch_size=32
77
-
78
- Examples:
79
- Start new run:
80
- sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun
81
- Resume 20 more epochs:
82
- sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun \\
83
- trainer_config.resume_ckpt_path=<path/to/ckpt> \\
84
- trainer_config.max_epochs=20
85
-
86
- Tips:
87
- - Use -m/--multirun for sweeps; outputs go under hydra.sweep.dir.
88
- - For Hydra flags and completion, use --hydra-help.
89
-
90
- For a detailed list of all available config options, please refer to https://nn.sleap.ai/config/.
99
+ """Display training help information with rich formatting."""
100
+ from rich.console import Console
101
+ from rich.panel import Panel
102
+ from rich.markdown import Markdown
103
+
104
+ console = Console()
105
+
106
+ help_md = """
107
+ ## Usage
108
+
109
+ ```
110
+ sleap-nn train <config.yaml> [overrides]
111
+ sleap-nn train --config <path/to/config.yaml> [overrides]
112
+ ```
113
+
114
+ ## Common Overrides
115
+
116
+ | Override | Description |
117
+ |----------|-------------|
118
+ | `trainer_config.max_epochs=100` | Set maximum training epochs |
119
+ | `trainer_config.batch_size=32` | Set batch size |
120
+ | `trainer_config.save_ckpt=true` | Enable checkpoint saving |
121
+
122
+ ## Examples
123
+
124
+ **Start a new training run:**
125
+ ```bash
126
+ sleap-nn train path/to/config.yaml
127
+ sleap-nn train --config path/to/config.yaml
128
+ ```
129
+
130
+ **With overrides:**
131
+ ```bash
132
+ sleap-nn train config.yaml trainer_config.max_epochs=100
133
+ ```
134
+
135
+ **Resume training:**
136
+ ```bash
137
+ sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
138
+ ```
139
+
140
+ **Legacy usage (still supported):**
141
+ ```bash
142
+ sleap-nn train --config-dir /path/to/dir --config-name myrun
143
+ ```
144
+
145
+ ## Tips
146
+
147
+ - Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
148
+ - For Hydra flags and completion, use `--hydra-help`
149
+ - Config documentation: https://nn.sleap.ai/config/
91
150
  """
92
- click.echo(help_text)
151
+ console.print(
152
+ Panel(
153
+ Markdown(help_md),
154
+ title="[bold cyan]sleap-nn train[/bold cyan]",
155
+ subtitle="Train SLEAP models from a config YAML file",
156
+ border_style="cyan",
157
+ )
158
+ )
93
159
 
94
160
 
95
161
  @cli.command(cls=TrainCommand)
96
- @click.option("--config-name", "-c", type=str, help="Configuration file name")
97
162
  @click.option(
98
- "--config-dir", "-d", type=str, default=".", help="Configuration directory path"
163
+ "--config",
164
+ type=str,
165
+ help="Path to configuration file (e.g., path/to/config.yaml)",
166
+ )
167
+ @click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
168
+ @click.option(
169
+ "--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
99
170
  )
100
171
  @click.option(
101
172
  "--video-paths",
@@ -128,25 +199,43 @@ For a detailed list of all available config options, please refer to https://nn.
128
199
  'Example: --prefix-map "/old/server/path" "/new/local/path"',
129
200
  )
130
201
  @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
131
- def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
202
+ def train(
203
+ config, config_name, config_dir, video_paths, video_path_map, prefix_map, overrides
204
+ ):
132
205
  """Run training workflow with Hydra config overrides.
133
206
 
134
207
  Examples:
135
- sleap-nn train --config-name myconfig --config-dir /path/to/config_dir/
208
+ sleap-nn train path/to/config.yaml
209
+ sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
136
210
  sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
137
- sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
138
211
  """
139
- # Show help if no config name provided
140
- if not config_name:
212
+ # Convert overrides to a mutable list
213
+ overrides = list(overrides)
214
+
215
+ # Check if the first positional arg is a config path (not a Hydra override)
216
+ config_from_positional = None
217
+ if overrides and is_config_path(overrides[0]):
218
+ config_from_positional = overrides.pop(0)
219
+
220
+ # Resolve config path with priority:
221
+ # 1. Positional config path (e.g., sleap-nn train config.yaml)
222
+ # 2. --config flag (e.g., sleap-nn train --config config.yaml)
223
+ # 3. Legacy --config-dir/--config-name flags
224
+ if config_from_positional:
225
+ config_dir, config_name = split_config_path(config_from_positional)
226
+ elif config:
227
+ config_dir, config_name = split_config_path(config)
228
+ elif config_name:
229
+ config_dir = Path(config_dir).resolve().as_posix()
230
+ else:
231
+ # No config provided - show help
141
232
  show_training_help()
142
233
  return
143
234
 
144
- # Initialize Hydra manually
145
- # resolve the path to the config directory (hydra expects absolute path)
146
- config_dir = Path(config_dir).resolve().as_posix()
235
+ # Initialize Hydra manually (config_dir is already an absolute path)
147
236
  with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
148
237
  # Compose config with overrides
149
- cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
238
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
150
239
 
151
240
  # Validate config
152
241
  if not hasattr(cfg, "model_config") or not cfg.model_config:
@@ -417,6 +506,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
417
506
  default=0.2,
418
507
  help="Minimum confidence map value to consider a peak as valid.",
419
508
  )
509
+ @click.option(
510
+ "--filter_overlapping",
511
+ is_flag=True,
512
+ default=False,
513
+ help=(
514
+ "Enable filtering of overlapping instances after inference using greedy NMS. "
515
+ "Applied independently of tracking. (default: False)"
516
+ ),
517
+ )
518
+ @click.option(
519
+ "--filter_overlapping_method",
520
+ type=click.Choice(["iou", "oks"]),
521
+ default="iou",
522
+ help=(
523
+ "Similarity metric for filtering overlapping instances. "
524
+ "'iou': bounding box intersection-over-union. "
525
+ "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
526
+ ),
527
+ )
528
+ @click.option(
529
+ "--filter_overlapping_threshold",
530
+ type=float,
531
+ default=0.8,
532
+ help=(
533
+ "Similarity threshold for filtering overlapping instances. "
534
+ "Instances with similarity above this threshold are removed, "
535
+ "keeping the higher-scoring instance. "
536
+ "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
537
+ ),
538
+ )
420
539
  @click.option(
421
540
  "--integral_refinement",
422
541
  type=str,
@@ -549,6 +668,12 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
549
668
  default=0,
550
669
  help="IOU to use when culling instances *after* tracking. (default: 0)",
551
670
  )
671
+ @click.option(
672
+ "--gui",
673
+ is_flag=True,
674
+ default=False,
675
+ help="Output JSON progress for GUI integration instead of Rich progress bar.",
676
+ )
552
677
  def track(**kwargs):
553
678
  """Run Inference and Tracking workflow."""
554
679
  # Convert model_paths from tuple to list
@@ -613,5 +738,9 @@ def system():
613
738
  print_system_info()
614
739
 
615
740
 
741
+ cli.add_command(export_command)
742
+ cli.add_command(predict_command)
743
+
744
+
616
745
  if __name__ == "__main__":
617
746
  cli()
sleap_nn/evaluation.py CHANGED
@@ -639,11 +639,19 @@ class Evaluator:
639
639
  mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
640
640
  mPCK = mPCK_parts.mean()
641
641
 
642
+ # Precompute PCK at common thresholds
643
+ idx_5 = np.argmin(np.abs(thresholds - 5))
644
+ idx_10 = np.argmin(np.abs(thresholds - 10))
645
+ pck5 = pcks[:, :, idx_5].mean()
646
+ pck10 = pcks[:, :, idx_10].mean()
647
+
642
648
  return {
643
649
  "thresholds": thresholds,
644
650
  "pcks": pcks,
645
651
  "mPCK_parts": mPCK_parts,
646
652
  "mPCK": mPCK,
653
+ "PCK@5": pck5,
654
+ "PCK@10": pck10,
647
655
  }
648
656
 
649
657
  def visibility_metrics(self):
@@ -0,0 +1,21 @@
1
+ """Export utilities for sleap-nn."""
2
+
3
+ from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
4
+ from sleap_nn.export.metadata import ExportMetadata
5
+ from sleap_nn.export.predictors import (
6
+ load_exported_model,
7
+ ONNXPredictor,
8
+ TensorRTPredictor,
9
+ )
10
+ from sleap_nn.export.utils import build_bottomup_candidate_template
11
+
12
+ __all__ = [
13
+ "export_model",
14
+ "export_to_onnx",
15
+ "export_to_tensorrt",
16
+ "load_exported_model",
17
+ "ONNXPredictor",
18
+ "TensorRTPredictor",
19
+ "ExportMetadata",
20
+ "build_bottomup_candidate_template",
21
+ ]