sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED

@@ -41,11 +41,18 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
+# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {
+    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+    colorize=False,
 )
 
-__version__ = "0.0.5"
+__version__ = "0.1.0"
+
+# Public API
+from sleap_nn.evaluation import load_metrics
+
+__all__ = ["load_metrics", "__version__"]
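Note: the new public API re-exports load_metrics at the package root. A minimal usage sketch (the model directory path is hypothetical; the exact signature of load_metrics lives in sleap_nn/evaluation.py):

    import sleap_nn

    print(sleap_nn.__version__)  # "0.1.0"

    # load_metrics is re-exported from sleap_nn.evaluation.
    # The argument shown here is a hypothetical model directory.
    metrics = sleap_nn.load_metrics("models/my_run")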
sleap_nn/architectures/convnext.py
CHANGED

@@ -281,6 +281,10 @@ class ConvNextWrapper(nn.Module):
         # Keep the block output filters the same
         x_in_shape = int(self.arch["channels"][-1] * filters_rate)
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # The forward pass uses enc_output[::2][::-1] for skip features
+        encoder_channels = self.arch["channels"][::-1]
+
         self.dec = Decoder(
             x_in_shape=x_in_shape,
             current_stride=self.current_stride,
@@ -293,6 +297,7 @@ class ConvNextWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=self.output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
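Note: for intuition about the reversal, a minimal sketch assuming the standard ConvNeXt-Tiny stage widths (the diff itself does not list them):

    # Assumed ConvNeXt-Tiny stage widths, shallowest to deepest stage.
    channels = [96, 192, 384, 768]

    # Reversed so the first decoder block consumes the deepest encoder feature,
    # matching the enc_output[::2][::-1] ordering used in the forward pass.
    encoder_channels = channels[::-1]
    print(encoder_channels)  # [768, 384, 192, 96]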
sleap_nn/architectures/encoder_decoder.py
CHANGED

@@ -25,7 +25,7 @@ classes.
 See the `EncoderDecoder` base class for requirements for creating new architectures.
 """
 
-from typing import List, Text, Tuple, Union
+from typing import List, Optional, Text, Tuple, Union
 from collections import OrderedDict
 import torch
 from torch import nn
@@ -391,10 +391,18 @@ class SimpleUpsamplingBlock(nn.Module):
         transpose_convs_activation: Text = "relu",
         feat_concat: bool = True,
         prefix: Text = "",
+        skip_channels: Optional[int] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
 
+        # Determine skip connection channels
+        # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
+        # This allows ConvNext/SwinT to specify actual encoder channels
+        self.skip_channels = (
+            skip_channels if skip_channels is not None else refine_convs_filters
+        )
+
         self.x_in_shape = x_in_shape
         self.current_stride = current_stride
         self.upsampling_stride = upsampling_stride
@@ -469,13 +477,13 @@
                 first_conv_in_channels = refine_convs_filters
             else:
                 if self.up_interpolate:
-                    # With interpolation, input is x_in_shape +
-                    #
-                    first_conv_in_channels = x_in_shape +
+                    # With interpolation, input is x_in_shape + skip_channels
+                    # skip_channels may differ from refine_convs_filters for ConvNext/SwinT
+                    first_conv_in_channels = x_in_shape + self.skip_channels
                 else:
-                    # With transpose conv, input is transpose_conv_output +
+                    # With transpose conv, input is transpose_conv_output + skip_channels
                     first_conv_in_channels = (
-
+                        self.skip_channels + transpose_convs_filters
                     )
         else:
             if not self.feat_concat:
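Note: a minimal sketch of the resulting channel arithmetic, with illustrative numbers not taken from the diff:

    # Illustrative channel counts for one decoder block.
    x_in_shape = 768            # channels of the feature map being upsampled
    refine_convs_filters = 384  # filters used by the refinement convs
    skip_channels = 384         # actual encoder channels at this resolution

    # Mirrors the fallback added in SimpleUpsamplingBlock.__init__:
    resolved = skip_channels if skip_channels is not None else refine_convs_filters

    # With interpolation, the first refine conv sees the upsampled map
    # concatenated with the skip feature:
    first_conv_in_channels = x_in_shape + resolved  # 768 + 384 = 1152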
@@ -582,6 +590,7 @@ class Decoder(nn.Module):
         block_contraction: bool = False,
         up_interpolate: bool = True,
         prefix: str = "dec",
+        encoder_channels: Optional[List[int]] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
@@ -598,6 +607,7 @@
         self.block_contraction = block_contraction
         self.prefix = prefix
         self.stride_to_filters = {}
+        self.encoder_channels = encoder_channels
 
         self.current_strides = []
         self.residuals = 0
@@ -624,6 +634,13 @@
 
         next_stride = current_stride // 2
 
+        # Determine skip channels for this decoder block
+        # If encoder_channels provided, use actual encoder channels
+        # Otherwise fall back to computed filters (for UNet compatibility)
+        skip_channels = None
+        if encoder_channels is not None and block < len(encoder_channels):
+            skip_channels = encoder_channels[block]
+
         if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
             # This accounts for the case where we dont have any more down block features to concatenate with.
             # In this case, add a simple upsampling block with a conv layer and with no concatenation
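Note: a minimal sketch of the per-block lookup with illustrative values; blocks deeper than the encoder stage count keep skip_channels = None and fall back to the computed filters, preserving UNet behavior:

    encoder_channels = [768, 384, 192, 96]  # e.g., reversed ConvNeXt widths

    for block in range(6):  # assume more decoder blocks than encoder stages
        skip_channels = None
        if encoder_channels is not None and block < len(encoder_channels):
            skip_channels = encoder_channels[block]
        print(block, skip_channels)
    # Prints: 0 768, 1 384, 2 192, 3 96, 4 None, 5 None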
@@ -642,6 +659,7 @@
                         transpose_convs_batch_norm=False,
                         feat_concat=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
             else:
@@ -659,6 +677,7 @@
                         transpose_convs_filters=block_filters_out,
                         transpose_convs_batch_norm=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
 
sleap_nn/architectures/swint.py
CHANGED

@@ -309,6 +309,13 @@ class SwinTWrapper(nn.Module):
             self.stem_patch_stride * (2**3) * 2
         )  # stem_stride * down_blocks_stride * final_max_pool_stride
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # SwinT channels: embed * 2^i for each stage i, then reversed
+        num_stages = len(self.arch["depths"])
+        encoder_channels = [
+            self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
+        ]
+
         self.dec = Decoder(
             x_in_shape=block_filters,
             current_stride=self.current_stride,
@@ -321,6 +328,7 @@
             block_contraction=self.block_contraction,
             output_stride=output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
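Note: to make the SwinT formula concrete, a sketch assuming typical Swin-T hyperparameters (embed = 96, four stages; the diff does not list these values):

    embed = 96             # assumed embedding dimension
    depths = [2, 2, 6, 2]  # assumed stage depths; only the count matters here
    num_stages = len(depths)

    encoder_channels = [
        embed * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
    ]
    print(encoder_channels)  # [768, 384, 192, 96]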