sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +62 -20
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +35 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
sleap_nn/cli.py
CHANGED
|
@@ -7,6 +7,8 @@ from omegaconf import OmegaConf, DictConfig
|
|
|
7
7
|
import sleap_io as sio
|
|
8
8
|
from sleap_nn.predict import run_inference, frame_list
|
|
9
9
|
from sleap_nn.evaluation import run_evaluation
|
|
10
|
+
from sleap_nn.export.cli import export as export_command
|
|
11
|
+
from sleap_nn.export.cli import predict as predict_command
|
|
10
12
|
from sleap_nn.train import run_training
|
|
11
13
|
from sleap_nn import __version__
|
|
12
14
|
import hydra
|
|
@@ -417,6 +419,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
|
|
|
417
419
|
default=0.2,
|
|
418
420
|
help="Minimum confidence map value to consider a peak as valid.",
|
|
419
421
|
)
|
|
422
|
+
@click.option(
|
|
423
|
+
"--filter_overlapping",
|
|
424
|
+
is_flag=True,
|
|
425
|
+
default=False,
|
|
426
|
+
help=(
|
|
427
|
+
"Enable filtering of overlapping instances after inference using greedy NMS. "
|
|
428
|
+
"Applied independently of tracking. (default: False)"
|
|
429
|
+
),
|
|
430
|
+
)
|
|
431
|
+
@click.option(
|
|
432
|
+
"--filter_overlapping_method",
|
|
433
|
+
type=click.Choice(["iou", "oks"]),
|
|
434
|
+
default="iou",
|
|
435
|
+
help=(
|
|
436
|
+
"Similarity metric for filtering overlapping instances. "
|
|
437
|
+
"'iou': bounding box intersection-over-union. "
|
|
438
|
+
"'oks': Object Keypoint Similarity (pose-based). (default: iou)"
|
|
439
|
+
),
|
|
440
|
+
)
|
|
441
|
+
@click.option(
|
|
442
|
+
"--filter_overlapping_threshold",
|
|
443
|
+
type=float,
|
|
444
|
+
default=0.8,
|
|
445
|
+
help=(
|
|
446
|
+
"Similarity threshold for filtering overlapping instances. "
|
|
447
|
+
"Instances with similarity above this threshold are removed, "
|
|
448
|
+
"keeping the higher-scoring instance. "
|
|
449
|
+
"Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
|
|
450
|
+
),
|
|
451
|
+
)
|
|
420
452
|
@click.option(
|
|
421
453
|
"--integral_refinement",
|
|
422
454
|
type=str,
|
|
@@ -613,5 +645,9 @@ def system():
|
|
|
613
645
|
print_system_info()
|
|
614
646
|
|
|
615
647
|
|
|
648
|
+
cli.add_command(export_command)
|
|
649
|
+
cli.add_command(predict_command)
|
|
650
|
+
|
|
651
|
+
|
|
616
652
|
if __name__ == "__main__":
|
|
617
653
|
cli()
|
sleap_nn/evaluation.py
CHANGED
|
@@ -639,11 +639,19 @@ class Evaluator:
|
|
|
639
639
|
mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
|
|
640
640
|
mPCK = mPCK_parts.mean()
|
|
641
641
|
|
|
642
|
+
# Precompute PCK at common thresholds
|
|
643
|
+
idx_5 = np.argmin(np.abs(thresholds - 5))
|
|
644
|
+
idx_10 = np.argmin(np.abs(thresholds - 10))
|
|
645
|
+
pck5 = pcks[:, :, idx_5].mean()
|
|
646
|
+
pck10 = pcks[:, :, idx_10].mean()
|
|
647
|
+
|
|
642
648
|
return {
|
|
643
649
|
"thresholds": thresholds,
|
|
644
650
|
"pcks": pcks,
|
|
645
651
|
"mPCK_parts": mPCK_parts,
|
|
646
652
|
"mPCK": mPCK,
|
|
653
|
+
"PCK@5": pck5,
|
|
654
|
+
"PCK@10": pck10,
|
|
647
655
|
}
|
|
648
656
|
|
|
649
657
|
def visibility_metrics(self):
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Export utilities for sleap-nn."""
|
|
2
|
+
|
|
3
|
+
from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
|
|
4
|
+
from sleap_nn.export.metadata import ExportMetadata
|
|
5
|
+
from sleap_nn.export.predictors import (
|
|
6
|
+
load_exported_model,
|
|
7
|
+
ONNXPredictor,
|
|
8
|
+
TensorRTPredictor,
|
|
9
|
+
)
|
|
10
|
+
from sleap_nn.export.utils import build_bottomup_candidate_template
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"export_model",
|
|
14
|
+
"export_to_onnx",
|
|
15
|
+
"export_to_tensorrt",
|
|
16
|
+
"load_exported_model",
|
|
17
|
+
"ONNXPredictor",
|
|
18
|
+
"TensorRTPredictor",
|
|
19
|
+
"ExportMetadata",
|
|
20
|
+
"build_bottomup_candidate_template",
|
|
21
|
+
]
|