sleap-nn 0.1.0a3__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +132 -39
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +6 -7
- sleap_nn/training/callbacks.py +7 -2
- sleap_nn/training/model_trainer.py +32 -0
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +2 -1
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +16 -16
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED

sleap_nn/architectures/convnext.py
CHANGED

@@ -281,6 +281,10 @@ class ConvNextWrapper(nn.Module):
         # Keep the block output filters the same
         x_in_shape = int(self.arch["channels"][-1] * filters_rate)
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # The forward pass uses enc_output[::2][::-1] for skip features
+        encoder_channels = self.arch["channels"][::-1]
+
         self.dec = Decoder(
             x_in_shape=x_in_shape,
             current_stride=self.current_stride,

@@ -293,6 +297,7 @@ class ConvNextWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=self.output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):

sleap_nn/architectures/encoder_decoder.py
CHANGED

@@ -25,7 +25,7 @@ classes.
 See the `EncoderDecoder` base class for requirements for creating new architectures.
 """
 
-from typing import List, Text, Tuple, Union
+from typing import List, Optional, Text, Tuple, Union
 from collections import OrderedDict
 import torch
 from torch import nn

@@ -391,10 +391,18 @@ class SimpleUpsamplingBlock(nn.Module):
         transpose_convs_activation: Text = "relu",
         feat_concat: bool = True,
         prefix: Text = "",
+        skip_channels: Optional[int] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
 
+        # Determine skip connection channels
+        # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
+        # This allows ConvNext/SwinT to specify actual encoder channels
+        self.skip_channels = (
+            skip_channels if skip_channels is not None else refine_convs_filters
+        )
+
         self.x_in_shape = x_in_shape
         self.current_stride = current_stride
         self.upsampling_stride = upsampling_stride

@@ -469,13 +477,13 @@ class SimpleUpsamplingBlock(nn.Module):
             first_conv_in_channels = refine_convs_filters
         else:
             if self.up_interpolate:
-                # With interpolation, input is x_in_shape +
-                #
-                first_conv_in_channels = x_in_shape +
+                # With interpolation, input is x_in_shape + skip_channels
+                # skip_channels may differ from refine_convs_filters for ConvNext/SwinT
+                first_conv_in_channels = x_in_shape + self.skip_channels
             else:
-                # With transpose conv, input is transpose_conv_output +
+                # With transpose conv, input is transpose_conv_output + skip_channels
                 first_conv_in_channels = (
-
+                    self.skip_channels + transpose_convs_filters
                 )
         else:
             if not self.feat_concat:

@@ -582,6 +590,7 @@ class Decoder(nn.Module):
         block_contraction: bool = False,
         up_interpolate: bool = True,
         prefix: str = "dec",
+        encoder_channels: Optional[List[int]] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()

@@ -598,6 +607,7 @@ class Decoder(nn.Module):
         self.block_contraction = block_contraction
         self.prefix = prefix
         self.stride_to_filters = {}
+        self.encoder_channels = encoder_channels
 
         self.current_strides = []
         self.residuals = 0

@@ -624,6 +634,13 @@ class Decoder(nn.Module):
 
             next_stride = current_stride // 2
 
+            # Determine skip channels for this decoder block
+            # If encoder_channels provided, use actual encoder channels
+            # Otherwise fall back to computed filters (for UNet compatibility)
+            skip_channels = None
+            if encoder_channels is not None and block < len(encoder_channels):
+                skip_channels = encoder_channels[block]
+
             if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
                 # This accounts for the case where we dont have any more down block features to concatenate with.
                 # In this case, add a simple upsampling block with a conv layer and with no concatenation

@@ -642,6 +659,7 @@ class Decoder(nn.Module):
                         transpose_convs_batch_norm=False,
                         feat_concat=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
             else:

@@ -659,6 +677,7 @@ class Decoder(nn.Module):
                         transpose_convs_filters=block_filters_out,
                         transpose_convs_batch_norm=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
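Taken together, the `encoder_channels`/`skip_channels` plumbing above changes how the decoder sizes its first convolution after each skip concatenation: the real encoder channel counts are used when the backbone provides them, and the old computed filter count is used otherwise. The standalone sketch below (a hypothetical helper written for illustration, assuming ConvNeXt-Tiny-style stage widths [96, 192, 384, 768]) mirrors that fallback rule.

```python
from typing import List, Optional


def resolve_skip_channels(
    encoder_channels: Optional[List[int]], block: int, refine_convs_filters: int
) -> int:
    """Pick the skip-connection width the way the patched Decoder does: use the
    real encoder channels when provided, else fall back to refine_convs_filters
    (the pre-existing UNet behavior)."""
    if encoder_channels is not None and block < len(encoder_channels):
        return encoder_channels[block]
    return refine_convs_filters


# ConvNeXt-style stage widths, reversed to match decoder order as in the diff.
encoder_channels = [96, 192, 384, 768][::-1]  # -> [768, 384, 192, 96]

for block in range(4):
    skip = resolve_skip_channels(encoder_channels, block, refine_convs_filters=64)
    print(f"decoder block {block}: skip_channels={skip}")

# UNet call sites pass encoder_channels=None and keep the old behavior:
print(resolve_skip_channels(None, 0, refine_convs_filters=64))  # -> 64
```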
sleap_nn/architectures/swint.py
CHANGED

@@ -309,6 +309,13 @@ class SwinTWrapper(nn.Module):
             self.stem_patch_stride * (2**3) * 2
         )  # stem_stride * down_blocks_stride * final_max_pool_stride
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # SwinT channels: embed * 2^i for each stage i, then reversed
+        num_stages = len(self.arch["depths"])
+        encoder_channels = [
+            self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
+        ]
+
         self.dec = Decoder(
             x_in_shape=block_filters,
             current_stride=self.current_stride,

@@ -321,6 +328,7 @@ class SwinTWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
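The list comprehension added here just enumerates the per-stage SwinT widths in reverse. A quick worked example, assuming Swin-T-like hyperparameters (embed=96, depths=[2, 2, 6, 2]) that are illustrative values rather than anything read from this diff:

```python
def swint_encoder_channels(embed: int, depths: list) -> list:
    """Per-stage SwinT channel counts (embed * 2**i), reversed to decoder order."""
    num_stages = len(depths)
    return [embed * (2 ** (num_stages - 1 - i)) for i in range(num_stages)]


print(swint_encoder_channels(embed=96, depths=[2, 2, 6, 2]))
# -> [768, 384, 192, 96]: deepest stage first, matching the decoder's upsampling order.
```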
sleap_nn/cli.py
CHANGED

@@ -1,6 +1,7 @@
-"""Unified CLI for SLEAP-NN using
+"""Unified CLI for SLEAP-NN using rich-click for styled output."""
 
-import click
+import rich_click as click
+from click import Command
 from loguru import logger
 from pathlib import Path
 from omegaconf import OmegaConf, DictConfig

@@ -13,7 +14,36 @@ from sleap_nn.train import run_training
 from sleap_nn import __version__
 import hydra
 import sys
-
+
+# Rich-click configuration for styled help
+click.rich_click.TEXT_MARKUP = "markdown"
+click.rich_click.SHOW_ARGUMENTS = True
+click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
+click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
+click.rich_click.ERRORS_EPILOGUE = (
+    "Try 'sleap-nn [COMMAND] --help' for more information."
+)
+
+
+def is_config_path(arg: str) -> bool:
+    """Check if an argument looks like a config file path.
+
+    Returns True if the arg ends with .yaml or .yml.
+    """
+    return arg.endswith(".yaml") or arg.endswith(".yml")
+
+
+def split_config_path(config_path: str) -> tuple:
+    """Split a full config path into (config_dir, config_name).
+
+    Args:
+        config_path: Full path to a config file.
+
+    Returns:
+        Tuple of (config_dir, config_name) where config_dir is an absolute path.
+    """
+    path = Path(config_path).resolve()
+    return path.parent.as_posix(), path.name
 
 
 def print_version(ctx, param, value):

@@ -66,38 +96,77 @@ def cli():
 
 
 def show_training_help():
-    """Display training help information."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Display training help information with rich formatting."""
+    from rich.console import Console
+    from rich.panel import Panel
+    from rich.markdown import Markdown
+
+    console = Console()
+
+    help_md = """
+## Usage
+
+```
+sleap-nn train <config.yaml> [overrides]
+sleap-nn train --config <path/to/config.yaml> [overrides]
+```
+
+## Common Overrides
+
+| Override | Description |
+|----------|-------------|
+| `trainer_config.max_epochs=100` | Set maximum training epochs |
+| `trainer_config.batch_size=32` | Set batch size |
+| `trainer_config.save_ckpt=true` | Enable checkpoint saving |
+
+## Examples
+
+**Start a new training run:**
+```bash
+sleap-nn train path/to/config.yaml
+sleap-nn train --config path/to/config.yaml
+```
+
+**With overrides:**
+```bash
+sleap-nn train config.yaml trainer_config.max_epochs=100
+```
+
+**Resume training:**
+```bash
+sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
+```
+
+**Legacy usage (still supported):**
+```bash
+sleap-nn train --config-dir /path/to/dir --config-name myrun
+```
+
+## Tips
+
+- Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
+- For Hydra flags and completion, use `--hydra-help`
+- Config documentation: https://nn.sleap.ai/config/
 """
-
+    console.print(
+        Panel(
+            Markdown(help_md),
+            title="[bold cyan]sleap-nn train[/bold cyan]",
+            subtitle="Train SLEAP models from a config YAML file",
+            border_style="cyan",
+        )
+    )
 
 
 @cli.command(cls=TrainCommand)
-@click.option("--config-name", "-c", type=str, help="Configuration file name")
 @click.option(
-    "--config
+    "--config",
+    type=str,
+    help="Path to configuration file (e.g., path/to/config.yaml)",
+)
+@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
+@click.option(
+    "--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
 )
 @click.option(
     "--video-paths",

@@ -130,25 +199,43 @@ For a detailed list of all available config options, please refer to https://nn.
     'Example: --prefix-map "/old/server/path" "/new/local/path"',
 )
 @click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
-def train(
+def train(
+    config, config_name, config_dir, video_paths, video_path_map, prefix_map, overrides
+):
     """Run training workflow with Hydra config overrides.
 
     Examples:
-        sleap-nn train
+        sleap-nn train path/to/config.yaml
+        sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
         sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
-        sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
     """
-    #
-
+    # Convert overrides to a mutable list
+    overrides = list(overrides)
+
+    # Check if the first positional arg is a config path (not a Hydra override)
+    config_from_positional = None
+    if overrides and is_config_path(overrides[0]):
+        config_from_positional = overrides.pop(0)
+
+    # Resolve config path with priority:
+    # 1. Positional config path (e.g., sleap-nn train config.yaml)
+    # 2. --config flag (e.g., sleap-nn train --config config.yaml)
+    # 3. Legacy --config-dir/--config-name flags
+    if config_from_positional:
+        config_dir, config_name = split_config_path(config_from_positional)
+    elif config:
+        config_dir, config_name = split_config_path(config)
+    elif config_name:
+        config_dir = Path(config_dir).resolve().as_posix()
+    else:
+        # No config provided - show help
         show_training_help()
         return
 
-    # Initialize Hydra manually
-    # resolve the path to the config directory (hydra expects absolute path)
-    config_dir = Path(config_dir).resolve().as_posix()
+    # Initialize Hydra manually (config_dir is already an absolute path)
     with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
         # Compose config with overrides
-        cfg = hydra.compose(config_name=config_name, overrides=
+        cfg = hydra.compose(config_name=config_name, overrides=overrides)
 
         # Validate config
         if not hasattr(cfg, "model_config") or not cfg.model_config:

@@ -581,6 +668,12 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
     default=0,
     help="IOU to use when culling instances *after* tracking. (default: 0)",
 )
+@click.option(
+    "--gui",
+    is_flag=True,
+    default=False,
+    help="Output JSON progress for GUI integration instead of Rich progress bar.",
+)
 def track(**kwargs):
     """Run Inference and Tracking workflow."""
     # Convert model_paths from tuple to list
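With this change, `sleap-nn train` accepts a bare YAML path as the first positional argument and strips it out of the Hydra overrides before composing the config. The snippet below rehearses that resolution logic in isolation; the helper bodies follow the hunk above, and the printed directory is simply whatever the file resolves to on the machine running it.

```python
from pathlib import Path


def is_config_path(arg: str) -> bool:
    """True if the argument looks like a YAML config file rather than a Hydra override."""
    return arg.endswith(".yaml") or arg.endswith(".yml")


def split_config_path(config_path: str) -> tuple:
    """Split a config file path into (absolute config_dir, config_name) for Hydra."""
    path = Path(config_path).resolve()
    return path.parent.as_posix(), path.name


# What `sleap-nn train configs/run.yaml trainer_config.max_epochs=100` delivers
# to the click command as the `overrides` argument:
overrides = ["configs/run.yaml", "trainer_config.max_epochs=100"]

if overrides and is_config_path(overrides[0]):
    config_dir, config_name = split_config_path(overrides.pop(0))
    print(config_dir)   # absolute path ending in .../configs
    print(config_name)  # run.yaml
    print(overrides)    # ['trainer_config.max_epochs=100'] -- only real Hydra overrides remain
```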
sleap_nn/inference/peak_finding.py
CHANGED

@@ -3,9 +3,8 @@
 from typing import Optional, Tuple
 
 import kornia as K
-import numpy as np
 import torch
-
+import torch.nn.functional as F
 
 from sleap_nn.data.instance_cropping import make_centered_bboxes
 

@@ -13,7 +12,11 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
 def crop_bboxes(
     images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
 ) -> torch.Tensor:
-    """Crop bounding boxes from a batch of images.
+    """Crop bounding boxes from a batch of images using fast tensor indexing.
+
+    This uses tensor unfold operations to extract patches, which is significantly
+    faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
+    transform computations.
 
     Args:
         images: Tensor of shape (samples, channels, height, width) of a batch of images.

@@ -27,7 +30,7 @@ def crop_bboxes(
             box should be cropped from.
 
     Returns:
-        A tensor of shape (n_bboxes, crop_height, crop_width
+        A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
         dtype as the input image. The crop size is inferred from the bounding box
         coordinates.
 

@@ -42,26 +45,53 @@ def crop_bboxes(
 
     See also: `make_centered_bboxes`
     """
+    n_crops = bboxes.shape[0]
+    if n_crops == 0:
+        # Return empty tensor; use default crop size since we can't infer from bboxes
+        return torch.empty(
+            0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
+        )
+
     # Compute bounding box size to use for crops.
-    height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
-    width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
-    box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
+    height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
+    width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1
 
     # Store original dtype for conversion back after cropping.
     original_dtype = images.dtype
+    device = images.device
+    n_samples, channels, img_h, img_w = images.shape
+    half_h, half_w = height // 2, width // 2
 
-    #
-
-
-    images_to_crop = images_to_crop.float()
-
-    # Crop.
-    crops = crop_and_resize(
-        images_to_crop,  # (n_boxes, channels, height, width)
-        boxes=bboxes,
-        size=box_size,
+    # Pad images for edge handling.
+    images_padded = F.pad(
+        images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
     )
 
+    # Extract all possible patches using unfold (creates a view, no copy).
+    # Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
+    patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
+
+    # Get crop centers from bboxes.
+    # The bbox top-left is at index 0, with (x, y) coordinates.
+    # We need the center of the crop (peak location), which is top-left + half_size.
+    # Ensure bboxes are on the same device as images for index computation.
+    bboxes_on_device = bboxes.to(device)
+    crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
+    crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
+
+    # Clamp indices to valid bounds to handle edge cases where centroids
+    # might be at or beyond image boundaries.
+    crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
+    crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
+
+    # Select crops using advanced indexing.
+    # Convert sample_inds to tensor if it's a list.
+    if not isinstance(sample_inds, torch.Tensor):
+        sample_inds = torch.tensor(sample_inds, device=device)
+    sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
+    crops = patches[sample_inds_long, :, crop_y, crop_x]
+    # Shape: (n_crops, channels, height, width)
+
     # Cast back to original dtype and return.
     crops = crops.to(original_dtype)
     return crops
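The rewritten `crop_bboxes` replaces kornia's `crop_and_resize` with zero-padding plus `Tensor.unfold`, so each fixed-size crop becomes a plain indexing operation instead of a perspective warp. A minimal, self-contained demonstration of that trick (the helper name and toy sizes are mine, not from the package):

```python
import torch
import torch.nn.functional as F


def crop_patches(images: torch.Tensor, centers_xy: torch.Tensor, size: int) -> torch.Tensor:
    """Extract size x size patches centered on integer (x, y) points, zero-padding at edges."""
    half = size // 2
    padded = F.pad(images.float(), (half, half, half, half))
    # Every possible patch as a zero-copy view: (n, c, H, W, size, size).
    patches = padded.unfold(2, size, 1).unfold(3, size, 1)
    x = centers_xy[:, 0].long().clamp(0, patches.shape[3] - 1)
    y = centers_xy[:, 1].long().clamp(0, patches.shape[2] - 1)
    sample = torch.zeros(len(centers_xy), dtype=torch.long)  # all crops from image 0 here
    return patches[sample, :, y, x]  # advanced indexing -> (n_crops, c, size, size)


img = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
crops = crop_patches(img, centers_xy=torch.tensor([[4, 3], [0, 0]]), size=3)
print(crops.shape)                                    # torch.Size([2, 1, 3, 3])
print(torch.equal(crops[0, 0], img[0, 0, 2:5, 3:6]))  # True: 3x3 block centered at (x=4, y=3)
```

Because `unfold` returns a view, no per-crop tensors are materialized until the final indexing step, which is where the speedup over a transform-based crop comes from.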
sleap_nn/inference/predictors.py
CHANGED

@@ -56,6 +56,8 @@ from rich.progress import (
     MofNCompleteColumn,
 )
 from time import time
+import json
+import sys
 
 
 def _filter_user_labeled_frames(

@@ -133,6 +135,8 @@ class Predictor(ABC):
         `backbone_config`. This determines the downsampling factor applied by the backbone,
         and is used to ensure that input images are padded or resized to be compatible
         with the model's architecture. Default: 16.
+        gui: If True, outputs JSON progress lines for GUI integration instead of
+            Rich progress bars. Default: False.
     """
 
     preprocess: bool = True

@@ -152,6 +156,7 @@ class Predictor(ABC):
     ] = None
     instances_key: bool = False
     max_stride: int = 16
+    gui: bool = False
 
     @classmethod
     def from_model_paths(

@@ -381,6 +386,102 @@ class Predictor(ABC):
                     v[n] = v[n].cpu().numpy()
         return output
 
+    def _process_batch(self) -> tuple:
+        """Process a single batch of frames from the pipeline.
+
+        Returns:
+            Tuple of (imgs, fidxs, vidxs, org_szs, instances, eff_scales, done)
+            where done is True if the pipeline has finished.
+        """
+        imgs = []
+        fidxs = []
+        vidxs = []
+        org_szs = []
+        instances = []
+        eff_scales = []
+        done = False
+
+        for _ in range(self.batch_size):
+            frame = self.pipeline.frame_buffer.get()
+            if frame["image"] is None:
+                done = True
+                break
+            frame["image"], eff_scale = apply_sizematcher(
+                frame["image"],
+                self.preprocess_config["max_height"],
+                self.preprocess_config["max_width"],
+            )
+            if self.instances_key:
+                frame["instances"] = frame["instances"] * eff_scale
+            if self.preprocess_config["ensure_rgb"] and frame["image"].shape[-3] != 3:
+                frame["image"] = frame["image"].repeat(1, 3, 1, 1)
+            elif (
+                self.preprocess_config["ensure_grayscale"]
+                and frame["image"].shape[-3] != 1
+            ):
+                frame["image"] = F.rgb_to_grayscale(
+                    frame["image"], num_output_channels=1
+                )
+
+            eff_scales.append(torch.tensor(eff_scale))
+            imgs.append(frame["image"].unsqueeze(dim=0))
+            fidxs.append(frame["frame_idx"])
+            vidxs.append(frame["video_idx"])
+            org_szs.append(frame["orig_size"].unsqueeze(dim=0))
+            if self.instances_key:
+                instances.append(frame["instances"].unsqueeze(dim=0))
+
+        return imgs, fidxs, vidxs, org_szs, instances, eff_scales, done
+
+    def _run_inference_on_batch(
+        self, imgs, fidxs, vidxs, org_szs, instances, eff_scales
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Run inference on a prepared batch of frames.
+
+        Args:
+            imgs: List of image tensors.
+            fidxs: List of frame indices.
+            vidxs: List of video indices.
+            org_szs: List of original sizes.
+            instances: List of instance tensors.
+            eff_scales: List of effective scales.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
+        imgs = torch.concatenate(imgs, dim=0)
+        fidxs = torch.tensor(fidxs, dtype=torch.int32)
+        vidxs = torch.tensor(vidxs, dtype=torch.int32)
+        org_szs = torch.concatenate(org_szs, dim=0)
+        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
+        if self.instances_key:
+            instances = torch.concatenate(instances, dim=0)
+        ex = {
+            "image": imgs,
+            "frame_idx": fidxs,
+            "video_idx": vidxs,
+            "orig_size": org_szs,
+            "eff_scale": eff_scales,
+        }
+        if self.instances_key:
+            ex["instances"] = instances
+        if self.preprocess:
+            scale = self.preprocess_config["scale"]
+            if scale != 1.0:
+                if self.instances_key:
+                    ex["image"], ex["instances"] = apply_resizer(
+                        ex["image"], ex["instances"]
+                    )
+                else:
+                    ex["image"] = resize_image(ex["image"], scale)
+            ex["image"] = apply_pad_to_stride(ex["image"], self.max_stride)
+        outputs_list = self.inference_model(ex)
+        if outputs_list is not None:
+            for output in outputs_list:
+                output = self._convert_tensors_to_numpy(output)
+                yield output
+
     def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
         """Create a generator that yields batches of inference results.
 

@@ -400,114 +501,14 @@ class Predictor(ABC):
         # Loop over data batches.
         self.pipeline.start()
         total_frames = self.pipeline.total_len()
-        done = False
 
         try:
-            with Progress(
-                "{task.description}",
-                BarColumn(),
-                "[progress.percentage]{task.percentage:>3.0f}%",
-                MofNCompleteColumn(),
-                "ETA:",
-                TimeRemainingColumn(),
-                "Elapsed:",
-                TimeElapsedColumn(),
-                RateColumn(),
-                auto_refresh=False,
-                refresh_per_second=4,  # Change to self.report_rate if needed
-                speed_estimate_period=5,
-            ) as progress:
-                task = progress.add_task("Predicting...", total=total_frames)
-                last_report = time()
-
-                done = False
-                while not done:
-                    imgs = []
-                    fidxs = []
-                    vidxs = []
-                    org_szs = []
-                    instances = []
-                    eff_scales = []
-                    for _ in range(self.batch_size):
-                        frame = self.pipeline.frame_buffer.get()
-                        if frame["image"] is None:
-                            done = True
-                            break
-                        frame["image"], eff_scale = apply_sizematcher(
-                            frame["image"],
-                            self.preprocess_config["max_height"],
-                            self.preprocess_config["max_width"],
-                        )
-                        if self.instances_key:
-                            frame["instances"] = frame["instances"] * eff_scale
-                        if (
-                            self.preprocess_config["ensure_rgb"]
-                            and frame["image"].shape[-3] != 3
-                        ):
-                            frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-                        elif (
-                            self.preprocess_config["ensure_grayscale"]
-                            and frame["image"].shape[-3] != 1
-                        ):
-                            frame["image"] = F.rgb_to_grayscale(
-                                frame["image"], num_output_channels=1
-                            )
-
-                        eff_scales.append(torch.tensor(eff_scale))
-                        imgs.append(frame["image"].unsqueeze(dim=0))
-                        fidxs.append(frame["frame_idx"])
-                        vidxs.append(frame["video_idx"])
-                        org_szs.append(frame["orig_size"].unsqueeze(dim=0))
-                        if self.instances_key:
-                            instances.append(frame["instances"].unsqueeze(dim=0))
-                    if imgs:
-                        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
-                        imgs = torch.concatenate(imgs, dim=0)
-                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-                        vidxs = torch.tensor(vidxs, dtype=torch.int32)
-                        org_szs = torch.concatenate(org_szs, dim=0)
-                        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
-                        if self.instances_key:
-                            instances = torch.concatenate(instances, dim=0)
-                        ex = {
-                            "image": imgs,
-                            "frame_idx": fidxs,
-                            "video_idx": vidxs,
-                            "orig_size": org_szs,
-                            "eff_scale": eff_scales,
-                        }
-                        if self.instances_key:
-                            ex["instances"] = instances
-                        if self.preprocess:
-                            scale = self.preprocess_config["scale"]
-                            if scale != 1.0:
-                                if self.instances_key:
-                                    ex["image"], ex["instances"] = apply_resizer(
-                                        ex["image"], ex["instances"]
-                                    )
-                                else:
-                                    ex["image"] = resize_image(ex["image"], scale)
-                            ex["image"] = apply_pad_to_stride(
-                                ex["image"], self.max_stride
-                            )
-                        outputs_list = self.inference_model(ex)
-                        if outputs_list is not None:
-                            for output in outputs_list:
-                                output = self._convert_tensors_to_numpy(output)
-                                yield output
-
-                        # Advance progress
-                        num_frames = (
-                            len(ex["frame_idx"])
-                            if "frame_idx" in ex
-                            else self.batch_size
-                        )
-                        progress.update(task, advance=num_frames)
-
-                        # Manually refresh progress bar
-                        if time() - last_report > 0.25:
-                            progress.refresh()
-                            last_report = time()
+            if self.gui:
+                # GUI mode: emit JSON progress lines
+                yield from self._predict_generator_gui(total_frames)
+            else:
+                # Normal mode: use Rich progress bar
+                yield from self._predict_generator_rich(total_frames)
 
         except KeyboardInterrupt:
             logger.info("Inference interrupted by user")

@@ -520,6 +521,112 @@ class Predictor(ABC):
 
         self.pipeline.join()
 
+    def _predict_generator_gui(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for GUI mode with JSON progress output.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        start_time = time()
+        frames_processed = 0
+        last_report = time()
+        done = False
+
+        while not done:
+            imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                self._process_batch()
+            )
+
+            if imgs:
+                yield from self._run_inference_on_batch(
+                    imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                )
+
+                # Update progress
+                num_frames = len(fidxs)
+                frames_processed += num_frames
+
+                # Emit JSON progress (throttled to ~4Hz)
+                if time() - last_report > 0.25:
+                    elapsed = time() - start_time
+                    rate = frames_processed / elapsed if elapsed > 0 else 0
+                    remaining = total_frames - frames_processed
+                    eta = remaining / rate if rate > 0 else 0
+
+                    progress_data = {
+                        "n_processed": frames_processed,
+                        "n_total": total_frames,
+                        "rate": round(rate, 1),
+                        "eta": round(eta, 1),
+                    }
+                    print(json.dumps(progress_data), flush=True)
+                    last_report = time()
+
+        # Final progress emit to ensure 100% is shown
+        elapsed = time() - start_time
+        progress_data = {
+            "n_processed": total_frames,
+            "n_total": total_frames,
+            "rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
+            "eta": 0,
+        }
+        print(json.dumps(progress_data), flush=True)
+
+    def _predict_generator_rich(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for normal mode with Rich progress bar.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        with Progress(
+            "{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            MofNCompleteColumn(),
+            "ETA:",
+            TimeRemainingColumn(),
+            "Elapsed:",
+            TimeElapsedColumn(),
+            RateColumn(),
+            auto_refresh=False,
+            refresh_per_second=4,
+            speed_estimate_period=5,
+        ) as progress:
+            task = progress.add_task("Predicting...", total=total_frames)
+            last_report = time()
+            done = False
+
+            while not done:
+                imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                    self._process_batch()
+                )
+
+                if imgs:
+                    yield from self._run_inference_on_batch(
+                        imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                    )
+
+                    # Advance progress
+                    num_frames = len(fidxs)
+                    progress.update(task, advance=num_frames)
+
+                    # Manually refresh progress bar
+                    if time() - last_report > 0.25:
+                        progress.refresh()
+                        last_report = time()
+
+        self.pipeline.join()
+
     def predict(
         self,
         make_labels: bool = True,
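In GUI mode the predictor prints one JSON object per progress update on stdout instead of drawing a Rich bar. A consumer-side sketch (illustrative only, not part of the package) of how a GUI process might filter those lines when reading the subprocess output:

```python
import json
from typing import Optional


def parse_progress_line(line: str) -> Optional[dict]:
    """Return the progress dict if the line matches the fields emitted by
    _predict_generator_gui (n_processed, n_total, rate, eta), else None."""
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and {"n_processed", "n_total"} <= data.keys():
        return data
    return None


print(parse_progress_line('{"n_processed": 128, "n_total": 512, "rate": 43.2, "eta": 8.9}'))
# -> {'n_processed': 128, 'n_total': 512, 'rate': 43.2, 'eta': 8.9}
print(parse_progress_line("Using device: cuda"))  # -> None (ordinary log output)
```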
sleap_nn/predict.py
CHANGED

@@ -113,6 +113,7 @@ def run_inference(
     tracking_pre_cull_iou_threshold: float = 0,
     tracking_clean_instance_count: int = 0,
     tracking_clean_iou_threshold: float = 0,
+    gui: bool = False,
 ):
     """Entry point to run inference on trained SLEAP-NN models.
 

@@ -262,6 +263,8 @@ def run_inference(
         tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
         tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
         tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
+        gui: (bool) If True, outputs JSON progress lines for GUI integration instead
+            of Rich progress bars. Default: False.
 
     Returns:
         Returns `sio.Labels` object if `make_labels` is True. Else this function returns

@@ -445,13 +448,6 @@ def run_inference(
         else "mps" if torch.backends.mps.is_available() else "cpu"
     )
 
-    if integral_refinement is not None and device == "mps":  # TODO
-        # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
-        logger.info(
-            "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
-        )
-        integral_refinement = None
-
     logger.info(f"Using device: {device}")
 
     # initializes the inference model

@@ -470,6 +466,9 @@ def run_inference(
         anchor_part=anchor_part,
     )
 
+    # Set GUI mode for progress output
+    predictor.gui = gui
+
     if (
         tracking
         and not isinstance(predictor, BottomUpMultiClassPredictor)
sleap_nn/training/callbacks.py
CHANGED

@@ -85,10 +85,15 @@ class CSVLoggerCallback(Callback):
             if key == "epoch":
                 log_data["epoch"] = trainer.current_epoch
             elif key == "learning_rate":
-                # Handle
+                # Handle multiple formats:
+                # 1. Direct "learning_rate" key
+                # 2. "train/lr" key (current format from lightning modules)
+                # 3. "lr-*" keys from LearningRateMonitor (legacy)
                 value = metrics.get(key, None)
                 if value is None:
-
+                    value = metrics.get("train/lr", None)
+                    if value is None:
+                        # Look for lr-* keys from LearningRateMonitor (legacy)
                     for metric_key in metrics.keys():
                         if metric_key.startswith("lr-"):
                             value = metrics[metric_key]
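The callback change amounts to a three-step lookup for the learning-rate column. Restated as a standalone helper (hypothetical, for illustration only):

```python
def resolve_learning_rate(metrics: dict):
    """Lookup order used by the patched CSVLoggerCallback:
    'learning_rate' -> 'train/lr' -> first key starting with 'lr-' (legacy)."""
    value = metrics.get("learning_rate")
    if value is None:
        value = metrics.get("train/lr")
    if value is None:
        for key in metrics:
            if key.startswith("lr-"):
                value = metrics[key]
                break
    return value


print(resolve_learning_rate({"train/lr": 1e-4, "train/loss": 0.3}))  # 0.0001
print(resolve_learning_rate({"lr-AdamW": 3e-4}))                     # 0.0003
print(resolve_learning_rate({"train/loss": 0.3}))                    # None
```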
sleap_nn/training/model_trainer.py
CHANGED

@@ -849,6 +849,7 @@ class ModelTrainer:
             "train/time",
             "val/time",
         ]
+        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",

@@ -857,6 +858,37 @@ class ModelTrainer:
             csv_log_keys.extend(
                 [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
             )
+        if self.model_type == "bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/paf_loss",
+                    "val/confmaps_loss",
+                    "val/paf_loss",
+                ]
+            )
+        if self.model_type == "multi_class_bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classmap_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classmap_loss",
+                    "val/class_accuracy",
+                ]
+            )
+        if self.model_type == "multi_class_topdown":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classvector_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classvector_loss",
+                    "val/class_accuracy",
+                ]
+            )
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
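The 31 added lines enumerate the extra per-model CSV columns explicitly. The same mapping, condensed into a throwaway helper so the pattern is easier to see (the helper is mine; the key names come from the diff):

```python
def model_specific_csv_keys(model_type: str) -> list:
    """Extra CSV log columns per model type, matching the keys added above."""
    extra = {
        "bottomup": ["confmaps_loss", "paf_loss"],
        "multi_class_bottomup": ["confmaps_loss", "classmap_loss", "class_accuracy"],
        "multi_class_topdown": ["confmaps_loss", "classvector_loss", "class_accuracy"],
    }.get(model_type, [])
    return [f"train/{name}" for name in extra] + [f"val/{name}" for name in extra]


print(model_specific_csv_keys("bottomup"))
# ['train/confmaps_loss', 'train/paf_loss', 'val/confmaps_loss', 'val/paf_loss']
```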
{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0a3
+Version: 0.1.0a4
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause

@@ -32,6 +32,7 @@ Requires-Dist: hydra-core
 Requires-Dist: jupyter
 Requires-Dist: jupyterlab
 Requires-Dist: pyzmq
+Requires-Dist: rich-click>=1.9.5
 Provides-Extra: torch
 Requires-Dist: torch; extra == "torch"
 Requires-Dist: torchvision>=0.20.0; extra == "torch"
{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD
CHANGED

@@ -1,18 +1,18 @@
 sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
-sleap_nn/__init__.py,sha256=
-sleap_nn/cli.py,sha256=
+sleap_nn/__init__.py,sha256=W6NBNc9X-Rt5XX9EQSOZ0X2wfj-G4dWlPfkwt-wCUqw,1362
+sleap_nn/cli.py,sha256=NFwxO3Fby_RgRroF7dNXsvWnBeILuVEylaMsLCLUMlY,24792
 sleap_nn/evaluation.py,sha256=SRO3qNOyyGoNBLLA2OKIUhvwyk0oI2ax1rtYmccx6m0,33785
 sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
-sleap_nn/predict.py,sha256=
+sleap_nn/predict.py,sha256=tN3vuP_fGCme7fLXd2b9DvItSr_pemzw8FUtIbkkU_U,36513
 sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
 sleap_nn/train.py,sha256=PEaK2B0S7DoImf8vt2cvJQS-n2NBw_pUJHmXy0J4NT0,30712
 sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
 sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
-sleap_nn/architectures/convnext.py,sha256=
-sleap_nn/architectures/encoder_decoder.py,sha256=
+sleap_nn/architectures/convnext.py,sha256=Ba9SFQHBdfz8gcMYZPMItuW-FyQuHBgUU0M8MWhaHuY,14210
+sleap_nn/architectures/encoder_decoder.py,sha256=1cBk9WU0zkXC2aK9XZy6VKHEe2hJEpIa-rwCxNgObZg,29292
 sleap_nn/architectures/heads.py,sha256=5E-7kQ-b2gsL0EviQ8z3KS1DAAMT4F2ZnEzx7eSG5gg,21001
 sleap_nn/architectures/model.py,sha256=1_dsP_4T9fsEVJjDt3er0haMKtbeM6w6JC6tc2jD0Gw,7139
-sleap_nn/architectures/swint.py,sha256=
+sleap_nn/architectures/swint.py,sha256=hlShh1Br0eTijir2U3np8sAaNJa12Xny0VzPx8HSaRo,15060
 sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcgo,11723
 sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
 sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42

@@ -58,9 +58,9 @@ sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShO
 sleap_nn/inference/bottomup.py,sha256=3s90aRlpIcRnSNe-R5-qiuX3S48kCWMpCl8YuNnTEDI,17084
 sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
 sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
-sleap_nn/inference/peak_finding.py,sha256=
+sleap_nn/inference/peak_finding.py,sha256=l6PKGw7KiVxzd00cesUZsbttPfjP1NBy8WmxWQtBlak,14595
 sleap_nn/inference/postprocessing.py,sha256=ZM_OH7_WIprieaujZ2Rk_34JhSDDzCry6Pq2YM_u5sg,8998
-sleap_nn/inference/predictors.py,sha256=
+sleap_nn/inference/predictors.py,sha256=xZyuH2bmsj_NAXcaswDFWqqmYS57v4QtZIWdsFqb3Sc,160709
 sleap_nn/inference/provenance.py,sha256=0BekXyvpLMb0Vv6DjpctlLduG9RN-Q8jt5zDm783eZE,11204
 sleap_nn/inference/single_instance.py,sha256=rOns_5TsJ1rb-lwmHG3ZY-pOhXGN2D-SfW9RmBxxzcI,4089
 sleap_nn/inference/topdown.py,sha256=Ha0Nwx-XCH_rebIuIGhP0qW68QpjLB3XRr9rxt05JLs,35108

@@ -73,14 +73,14 @@ sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j
 sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
 sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
 sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
-sleap_nn/training/callbacks.py,sha256=
+sleap_nn/training/callbacks.py,sha256=7WRT2pmQQ-hRdq9n7iHC_e0zH-vDphYfe0KHdD-UGg4,38216
 sleap_nn/training/lightning_modules.py,sha256=z98NBTrNy-GfCw4zatummJhVUO1fdjv_kPweAKcoaXc,108394
 sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
-sleap_nn/training/model_trainer.py,sha256=
+sleap_nn/training/model_trainer.py,sha256=okXTouoXzRcHcflRCdwR3NwUwSdX-ex1-rZOZHYCZLk,59964
 sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
+sleap_nn-0.1.0a4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a4.dist-info/METADATA,sha256=kA66dtTSVKAdFJcnvsSEMkrT3TRyGzHAcAsCIHzoqbE,6178
+sleap_nn-0.1.0a4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a4.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a4.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a4.dist-info/RECORD,,
{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL
File without changes

{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt
File without changes

{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE
File without changes

{sleap_nn-0.1.0a3.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt
File without changes