sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +168 -39
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +35 -7
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +69 -22
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +67 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
|
@@ -281,6 +281,10 @@ class ConvNextWrapper(nn.Module):
|
|
|
281
281
|
# Keep the block output filters the same
|
|
282
282
|
x_in_shape = int(self.arch["channels"][-1] * filters_rate)
|
|
283
283
|
|
|
284
|
+
# Encoder channels for skip connections (reversed to match decoder order)
|
|
285
|
+
# The forward pass uses enc_output[::2][::-1] for skip features
|
|
286
|
+
encoder_channels = self.arch["channels"][::-1]
|
|
287
|
+
|
|
284
288
|
self.dec = Decoder(
|
|
285
289
|
x_in_shape=x_in_shape,
|
|
286
290
|
current_stride=self.current_stride,
|
|
@@ -293,6 +297,7 @@ class ConvNextWrapper(nn.Module):
|
|
|
293
297
|
block_contraction=self.block_contraction,
|
|
294
298
|
output_stride=self.output_stride,
|
|
295
299
|
up_interpolate=up_interpolate,
|
|
300
|
+
encoder_channels=encoder_channels,
|
|
296
301
|
)
|
|
297
302
|
|
|
298
303
|
if len(self.dec.decoder_stack):
|
|
@@ -25,7 +25,7 @@ classes.
|
|
|
25
25
|
See the `EncoderDecoder` base class for requirements for creating new architectures.
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
|
-
from typing import List, Text, Tuple, Union
|
|
28
|
+
from typing import List, Optional, Text, Tuple, Union
|
|
29
29
|
from collections import OrderedDict
|
|
30
30
|
import torch
|
|
31
31
|
from torch import nn
|
|
@@ -391,10 +391,18 @@ class SimpleUpsamplingBlock(nn.Module):
|
|
|
391
391
|
transpose_convs_activation: Text = "relu",
|
|
392
392
|
feat_concat: bool = True,
|
|
393
393
|
prefix: Text = "",
|
|
394
|
+
skip_channels: Optional[int] = None,
|
|
394
395
|
) -> None:
|
|
395
396
|
"""Initialize the class."""
|
|
396
397
|
super().__init__()
|
|
397
398
|
|
|
399
|
+
# Determine skip connection channels
|
|
400
|
+
# If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
|
|
401
|
+
# This allows ConvNext/SwinT to specify actual encoder channels
|
|
402
|
+
self.skip_channels = (
|
|
403
|
+
skip_channels if skip_channels is not None else refine_convs_filters
|
|
404
|
+
)
|
|
405
|
+
|
|
398
406
|
self.x_in_shape = x_in_shape
|
|
399
407
|
self.current_stride = current_stride
|
|
400
408
|
self.upsampling_stride = upsampling_stride
|
|
@@ -469,13 +477,13 @@ class SimpleUpsamplingBlock(nn.Module):
|
|
|
469
477
|
first_conv_in_channels = refine_convs_filters
|
|
470
478
|
else:
|
|
471
479
|
if self.up_interpolate:
|
|
472
|
-
# With interpolation, input is x_in_shape +
|
|
473
|
-
#
|
|
474
|
-
first_conv_in_channels = x_in_shape +
|
|
480
|
+
# With interpolation, input is x_in_shape + skip_channels
|
|
481
|
+
# skip_channels may differ from refine_convs_filters for ConvNext/SwinT
|
|
482
|
+
first_conv_in_channels = x_in_shape + self.skip_channels
|
|
475
483
|
else:
|
|
476
|
-
# With transpose conv, input is transpose_conv_output +
|
|
484
|
+
# With transpose conv, input is transpose_conv_output + skip_channels
|
|
477
485
|
first_conv_in_channels = (
|
|
478
|
-
|
|
486
|
+
self.skip_channels + transpose_convs_filters
|
|
479
487
|
)
|
|
480
488
|
else:
|
|
481
489
|
if not self.feat_concat:
|
|
@@ -582,6 +590,7 @@ class Decoder(nn.Module):
|
|
|
582
590
|
block_contraction: bool = False,
|
|
583
591
|
up_interpolate: bool = True,
|
|
584
592
|
prefix: str = "dec",
|
|
593
|
+
encoder_channels: Optional[List[int]] = None,
|
|
585
594
|
) -> None:
|
|
586
595
|
"""Initialize the class."""
|
|
587
596
|
super().__init__()
|
|
@@ -598,6 +607,7 @@ class Decoder(nn.Module):
|
|
|
598
607
|
self.block_contraction = block_contraction
|
|
599
608
|
self.prefix = prefix
|
|
600
609
|
self.stride_to_filters = {}
|
|
610
|
+
self.encoder_channels = encoder_channels
|
|
601
611
|
|
|
602
612
|
self.current_strides = []
|
|
603
613
|
self.residuals = 0
|
|
@@ -624,6 +634,13 @@ class Decoder(nn.Module):
|
|
|
624
634
|
|
|
625
635
|
next_stride = current_stride // 2
|
|
626
636
|
|
|
637
|
+
# Determine skip channels for this decoder block
|
|
638
|
+
# If encoder_channels provided, use actual encoder channels
|
|
639
|
+
# Otherwise fall back to computed filters (for UNet compatibility)
|
|
640
|
+
skip_channels = None
|
|
641
|
+
if encoder_channels is not None and block < len(encoder_channels):
|
|
642
|
+
skip_channels = encoder_channels[block]
|
|
643
|
+
|
|
627
644
|
if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
|
|
628
645
|
# This accounts for the case where we dont have any more down block features to concatenate with.
|
|
629
646
|
# In this case, add a simple upsampling block with a conv layer and with no concatenation
|
|
@@ -642,6 +659,7 @@ class Decoder(nn.Module):
|
|
|
642
659
|
transpose_convs_batch_norm=False,
|
|
643
660
|
feat_concat=False,
|
|
644
661
|
prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
|
|
662
|
+
skip_channels=skip_channels,
|
|
645
663
|
)
|
|
646
664
|
)
|
|
647
665
|
else:
|
|
@@ -659,6 +677,7 @@ class Decoder(nn.Module):
|
|
|
659
677
|
transpose_convs_filters=block_filters_out,
|
|
660
678
|
transpose_convs_batch_norm=False,
|
|
661
679
|
prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
|
|
680
|
+
skip_channels=skip_channels,
|
|
662
681
|
)
|
|
663
682
|
)
|
|
664
683
|
|
sleap_nn/architectures/swint.py
CHANGED
|
@@ -309,6 +309,13 @@ class SwinTWrapper(nn.Module):
|
|
|
309
309
|
self.stem_patch_stride * (2**3) * 2
|
|
310
310
|
) # stem_stride * down_blocks_stride * final_max_pool_stride
|
|
311
311
|
|
|
312
|
+
# Encoder channels for skip connections (reversed to match decoder order)
|
|
313
|
+
# SwinT channels: embed * 2^i for each stage i, then reversed
|
|
314
|
+
num_stages = len(self.arch["depths"])
|
|
315
|
+
encoder_channels = [
|
|
316
|
+
self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
|
|
317
|
+
]
|
|
318
|
+
|
|
312
319
|
self.dec = Decoder(
|
|
313
320
|
x_in_shape=block_filters,
|
|
314
321
|
current_stride=self.current_stride,
|
|
@@ -321,6 +328,7 @@ class SwinTWrapper(nn.Module):
|
|
|
321
328
|
block_contraction=self.block_contraction,
|
|
322
329
|
output_stride=output_stride,
|
|
323
330
|
up_interpolate=up_interpolate,
|
|
331
|
+
encoder_channels=encoder_channels,
|
|
324
332
|
)
|
|
325
333
|
|
|
326
334
|
if len(self.dec.decoder_stack):
|
sleap_nn/cli.py
CHANGED
|
@@ -1,17 +1,49 @@
|
|
|
1
|
-
"""Unified CLI for SLEAP-NN using
|
|
1
|
+
"""Unified CLI for SLEAP-NN using rich-click for styled output."""
|
|
2
2
|
|
|
3
|
-
import click
|
|
3
|
+
import rich_click as click
|
|
4
|
+
from click import Command
|
|
4
5
|
from loguru import logger
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from omegaconf import OmegaConf, DictConfig
|
|
7
8
|
import sleap_io as sio
|
|
8
9
|
from sleap_nn.predict import run_inference, frame_list
|
|
9
10
|
from sleap_nn.evaluation import run_evaluation
|
|
11
|
+
from sleap_nn.export.cli import export as export_command
|
|
12
|
+
from sleap_nn.export.cli import predict as predict_command
|
|
10
13
|
from sleap_nn.train import run_training
|
|
11
14
|
from sleap_nn import __version__
|
|
12
15
|
import hydra
|
|
13
16
|
import sys
|
|
14
|
-
|
|
17
|
+
|
|
18
|
+
# Global rich-click settings: applied once at import time so every command's
# help text and error output below share the same styling.
_rich_settings = click.rich_click
_rich_settings.TEXT_MARKUP = "markdown"
_rich_settings.SHOW_ARGUMENTS = True
_rich_settings.GROUP_ARGUMENTS_OPTIONS = True
_rich_settings.STYLE_ERRORS_SUGGESTION = "magenta italic"
_rich_settings.ERRORS_EPILOGUE = "Try 'sleap-nn [COMMAND] --help' for more information."
del _rich_settings  # keep the module namespace identical to the original
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def is_config_path(arg: str) -> bool:
    """Check if an argument looks like a config file path.

    Used to tell a positional config path (``sleap-nn train config.yaml``)
    apart from Hydra overrides. An argument counts as a config path when it
    ends with ``.yaml`` or ``.yml`` and does not contain ``=``. Anything with
    ``=`` is treated as a ``key=value`` override, even if its value happens
    to be a YAML path (e.g. ``data_config.foo=other.yaml``). Config paths
    that contain ``=`` can still be passed explicitly via ``--config``.

    Args:
        arg: A single command-line argument.

    Returns:
        True if ``arg`` should be treated as a config file path, else False.
    """
    # Value-assigning Hydra overrides look like ``key=value``. Without this
    # check, an override whose value ends in .yaml would be popped and used
    # as the config file itself.
    if "=" in arg:
        return False
    return arg.endswith((".yaml", ".yml"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def split_config_path(config_path: str) -> tuple:
    """Resolve a config file path and break it into directory and file name.

    The ``train`` command hands the directory to ``hydra.initialize_config_dir``
    and the file name to ``hydra.compose``, so the two parts are returned
    separately.

    Args:
        config_path: Path to a config file, absolute or relative to the
            current working directory.

    Returns:
        A ``(config_dir, config_name)`` tuple. ``config_dir`` is the absolute,
        POSIX-style parent directory; ``config_name`` is the final path
        component (file name including extension).
    """
    resolved = Path(config_path).resolve()
    config_dir = resolved.parent.as_posix()
    config_name = resolved.name
    return config_dir, config_name
|
|
15
47
|
|
|
16
48
|
|
|
17
49
|
def print_version(ctx, param, value):
|
|
@@ -64,38 +96,77 @@ def cli():
|
|
|
64
96
|
|
|
65
97
|
|
|
66
98
|
def show_training_help():
    """Print the usage guide for ``sleap-nn train`` as a Rich panel.

    Invoked by ``train`` when no config was given (no positional config path,
    no ``--config`` and no ``--config-name``). The guide is written in
    Markdown and rendered inside a cyan-bordered, titled panel.
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel

    help_md = """
## Usage

```
sleap-nn train <config.yaml> [overrides]
sleap-nn train --config <path/to/config.yaml> [overrides]
```

## Common Overrides

| Override | Description |
|----------|-------------|
| `trainer_config.max_epochs=100` | Set maximum training epochs |
| `trainer_config.batch_size=32` | Set batch size |
| `trainer_config.save_ckpt=true` | Enable checkpoint saving |

## Examples

**Start a new training run:**
```bash
sleap-nn train path/to/config.yaml
sleap-nn train --config path/to/config.yaml
```

**With overrides:**
```bash
sleap-nn train config.yaml trainer_config.max_epochs=100
```

**Resume training:**
```bash
sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
```

**Legacy usage (still supported):**
```bash
sleap-nn train --config-dir /path/to/dir --config-name myrun
```

## Tips

- Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
- For Hydra flags and completion, use `--hydra-help`
- Config documentation: https://nn.sleap.ai/config/
"""

    # Build the framed panel first, then emit it on a fresh console.
    help_panel = Panel(
        Markdown(help_md),
        title="[bold cyan]sleap-nn train[/bold cyan]",
        subtitle="Train SLEAP models from a config YAML file",
        border_style="cyan",
    )
    Console().print(help_panel)
|
|
93
159
|
|
|
94
160
|
|
|
95
161
|
@cli.command(cls=TrainCommand)
|
|
96
|
-
@click.option("--config-name", "-c", type=str, help="Configuration file name")
|
|
97
162
|
@click.option(
|
|
98
|
-
"--config
|
|
163
|
+
"--config",
|
|
164
|
+
type=str,
|
|
165
|
+
help="Path to configuration file (e.g., path/to/config.yaml)",
|
|
166
|
+
)
|
|
167
|
+
@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
|
|
168
|
+
@click.option(
|
|
169
|
+
"--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
|
|
99
170
|
)
|
|
100
171
|
@click.option(
|
|
101
172
|
"--video-paths",
|
|
@@ -128,25 +199,43 @@ For a detailed list of all available config options, please refer to https://nn.
|
|
|
128
199
|
'Example: --prefix-map "/old/server/path" "/new/local/path"',
|
|
129
200
|
)
|
|
130
201
|
@click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
|
|
131
|
-
def train(
|
|
202
|
+
def train(
|
|
203
|
+
config, config_name, config_dir, video_paths, video_path_map, prefix_map, overrides
|
|
204
|
+
):
|
|
132
205
|
"""Run training workflow with Hydra config overrides.
|
|
133
206
|
|
|
134
207
|
Examples:
|
|
135
|
-
sleap-nn train
|
|
208
|
+
sleap-nn train path/to/config.yaml
|
|
209
|
+
sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
|
|
136
210
|
sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
|
|
137
|
-
sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
|
|
138
211
|
"""
|
|
139
|
-
#
|
|
140
|
-
|
|
212
|
+
# Convert overrides to a mutable list
|
|
213
|
+
overrides = list(overrides)
|
|
214
|
+
|
|
215
|
+
# Check if the first positional arg is a config path (not a Hydra override)
|
|
216
|
+
config_from_positional = None
|
|
217
|
+
if overrides and is_config_path(overrides[0]):
|
|
218
|
+
config_from_positional = overrides.pop(0)
|
|
219
|
+
|
|
220
|
+
# Resolve config path with priority:
|
|
221
|
+
# 1. Positional config path (e.g., sleap-nn train config.yaml)
|
|
222
|
+
# 2. --config flag (e.g., sleap-nn train --config config.yaml)
|
|
223
|
+
# 3. Legacy --config-dir/--config-name flags
|
|
224
|
+
if config_from_positional:
|
|
225
|
+
config_dir, config_name = split_config_path(config_from_positional)
|
|
226
|
+
elif config:
|
|
227
|
+
config_dir, config_name = split_config_path(config)
|
|
228
|
+
elif config_name:
|
|
229
|
+
config_dir = Path(config_dir).resolve().as_posix()
|
|
230
|
+
else:
|
|
231
|
+
# No config provided - show help
|
|
141
232
|
show_training_help()
|
|
142
233
|
return
|
|
143
234
|
|
|
144
|
-
# Initialize Hydra manually
|
|
145
|
-
# resolve the path to the config directory (hydra expects absolute path)
|
|
146
|
-
config_dir = Path(config_dir).resolve().as_posix()
|
|
235
|
+
# Initialize Hydra manually (config_dir is already an absolute path)
|
|
147
236
|
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
|
|
148
237
|
# Compose config with overrides
|
|
149
|
-
cfg = hydra.compose(config_name=config_name, overrides=
|
|
238
|
+
cfg = hydra.compose(config_name=config_name, overrides=overrides)
|
|
150
239
|
|
|
151
240
|
# Validate config
|
|
152
241
|
if not hasattr(cfg, "model_config") or not cfg.model_config:
|
|
@@ -417,6 +506,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
|
|
|
417
506
|
default=0.2,
|
|
418
507
|
help="Minimum confidence map value to consider a peak as valid.",
|
|
419
508
|
)
|
|
509
|
+
@click.option(
|
|
510
|
+
"--filter_overlapping",
|
|
511
|
+
is_flag=True,
|
|
512
|
+
default=False,
|
|
513
|
+
help=(
|
|
514
|
+
"Enable filtering of overlapping instances after inference using greedy NMS. "
|
|
515
|
+
"Applied independently of tracking. (default: False)"
|
|
516
|
+
),
|
|
517
|
+
)
|
|
518
|
+
@click.option(
|
|
519
|
+
"--filter_overlapping_method",
|
|
520
|
+
type=click.Choice(["iou", "oks"]),
|
|
521
|
+
default="iou",
|
|
522
|
+
help=(
|
|
523
|
+
"Similarity metric for filtering overlapping instances. "
|
|
524
|
+
"'iou': bounding box intersection-over-union. "
|
|
525
|
+
"'oks': Object Keypoint Similarity (pose-based). (default: iou)"
|
|
526
|
+
),
|
|
527
|
+
)
|
|
528
|
+
@click.option(
|
|
529
|
+
"--filter_overlapping_threshold",
|
|
530
|
+
type=float,
|
|
531
|
+
default=0.8,
|
|
532
|
+
help=(
|
|
533
|
+
"Similarity threshold for filtering overlapping instances. "
|
|
534
|
+
"Instances with similarity above this threshold are removed, "
|
|
535
|
+
"keeping the higher-scoring instance. "
|
|
536
|
+
"Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
|
|
537
|
+
),
|
|
538
|
+
)
|
|
420
539
|
@click.option(
|
|
421
540
|
"--integral_refinement",
|
|
422
541
|
type=str,
|
|
@@ -549,6 +668,12 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
|
|
|
549
668
|
default=0,
|
|
550
669
|
help="IOU to use when culling instances *after* tracking. (default: 0)",
|
|
551
670
|
)
|
|
671
|
+
@click.option(
|
|
672
|
+
"--gui",
|
|
673
|
+
is_flag=True,
|
|
674
|
+
default=False,
|
|
675
|
+
help="Output JSON progress for GUI integration instead of Rich progress bar.",
|
|
676
|
+
)
|
|
552
677
|
def track(**kwargs):
|
|
553
678
|
"""Run Inference and Tracking workflow."""
|
|
554
679
|
# Convert model_paths from tuple to list
|
|
@@ -613,5 +738,9 @@ def system():
|
|
|
613
738
|
print_system_info()
|
|
614
739
|
|
|
615
740
|
|
|
741
|
+
# Attach the export and exported-model predict commands to the main CLI
# group, in the same order as before (export first, then predict).
for _subcommand in (export_command, predict_command):
    cli.add_command(_subcommand)
del _subcommand  # avoid leaving the loop variable in the module namespace


if __name__ == "__main__":
    cli()
|
sleap_nn/evaluation.py
CHANGED
|
@@ -639,11 +639,19 @@ class Evaluator:
|
|
|
639
639
|
mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
|
|
640
640
|
mPCK = mPCK_parts.mean()
|
|
641
641
|
|
|
642
|
+
# Precompute PCK at common thresholds
|
|
643
|
+
idx_5 = np.argmin(np.abs(thresholds - 5))
|
|
644
|
+
idx_10 = np.argmin(np.abs(thresholds - 10))
|
|
645
|
+
pck5 = pcks[:, :, idx_5].mean()
|
|
646
|
+
pck10 = pcks[:, :, idx_10].mean()
|
|
647
|
+
|
|
642
648
|
return {
|
|
643
649
|
"thresholds": thresholds,
|
|
644
650
|
"pcks": pcks,
|
|
645
651
|
"mPCK_parts": mPCK_parts,
|
|
646
652
|
"mPCK": mPCK,
|
|
653
|
+
"PCK@5": pck5,
|
|
654
|
+
"PCK@10": pck10,
|
|
647
655
|
}
|
|
648
656
|
|
|
649
657
|
def visibility_metrics(self):
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Export utilities for sleap-nn."""
|
|
2
|
+
|
|
3
|
+
from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
|
|
4
|
+
from sleap_nn.export.metadata import ExportMetadata
|
|
5
|
+
from sleap_nn.export.predictors import (
|
|
6
|
+
load_exported_model,
|
|
7
|
+
ONNXPredictor,
|
|
8
|
+
TensorRTPredictor,
|
|
9
|
+
)
|
|
10
|
+
from sleap_nn.export.utils import build_bottomup_candidate_template
|
|
11
|
+
|
|
12
|
+
__all__ = [
    # from sleap_nn.export.exporters
    "export_model",
    "export_to_onnx",
    "export_to_tensorrt",
    # from sleap_nn.export.predictors
    "load_exported_model",
    "ONNXPredictor",
    "TensorRTPredictor",
    # from sleap_nn.export.metadata
    "ExportMetadata",
    # from sleap_nn.export.utils
    "build_bottomup_candidate_template",
]
|