sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py CHANGED
@@ -41,11 +41,18 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
+# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
+    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+    colorize=False,
 )
 
-__version__ = "0.0.5"
+__version__ = "0.1.0"
+
+# Public API
+from sleap_nn.evaluation import load_metrics
+
+__all__ = ["load_metrics", "__version__"]
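With this change, `load_metrics` becomes importable from the package root. A minimal usage sketch (the model directory path is hypothetical, and the argument shape of `load_metrics` is assumed from its import path, not confirmed by the diff):

    import sleap_nn

    print(sleap_nn.__version__)  # "0.1.0"

    # Hypothetical path to a trained model directory with saved metrics.
    metrics = sleap_nn.load_metrics("models/my_trained_model")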
sleap_nn/architectures/convnext.py CHANGED
@@ -281,6 +281,10 @@ class ConvNextWrapper(nn.Module):
         # Keep the block output filters the same
         x_in_shape = int(self.arch["channels"][-1] * filters_rate)
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # The forward pass uses enc_output[::2][::-1] for skip features
+        encoder_channels = self.arch["channels"][::-1]
+
         self.dec = Decoder(
             x_in_shape=x_in_shape,
             current_stride=self.current_stride,
@@ -293,6 +297,7 @@ class ConvNextWrapper(nn.Module):
             block_contraction=self.block_contraction,
             output_stride=self.output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
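The reversal matters because the decoder runs deepest-first. A minimal sketch with assumed ConvNeXt-Tiny-style stage widths (values illustrative, not taken from the diff):

    # Assumed stage widths, shallowest stage first.
    channels = [96, 192, 384, 768]
    encoder_channels = channels[::-1]  # [768, 384, 192, 96]
    # Decoder block 0 concatenates the deepest skip (768 ch), block 1 the next (384), etc.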
sleap_nn/architectures/encoder_decoder.py CHANGED
@@ -25,7 +25,7 @@ classes.
 See the `EncoderDecoder` base class for requirements for creating new architectures.
 """
 
-from typing import List, Text, Tuple, Union
+from typing import List, Optional, Text, Tuple, Union
 from collections import OrderedDict
 import torch
 from torch import nn
@@ -391,10 +391,18 @@ class SimpleUpsamplingBlock(nn.Module):
         transpose_convs_activation: Text = "relu",
         feat_concat: bool = True,
         prefix: Text = "",
+        skip_channels: Optional[int] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
 
+        # Determine skip connection channels
+        # If skip_channels is provided, use it; otherwise fall back to refine_convs_filters
+        # This allows ConvNext/SwinT to specify actual encoder channels
+        self.skip_channels = (
+            skip_channels if skip_channels is not None else refine_convs_filters
+        )
+
         self.x_in_shape = x_in_shape
         self.current_stride = current_stride
         self.upsampling_stride = upsampling_stride
@@ -469,13 +477,13 @@
                 first_conv_in_channels = refine_convs_filters
             else:
                 if self.up_interpolate:
-                    # With interpolation, input is x_in_shape + feature channels
-                    # The feature channels are the same as x_in_shape since they come from the same level
-                    first_conv_in_channels = x_in_shape + refine_convs_filters
+                    # With interpolation, input is x_in_shape + skip_channels
+                    # skip_channels may differ from refine_convs_filters for ConvNext/SwinT
+                    first_conv_in_channels = x_in_shape + self.skip_channels
                 else:
-                    # With transpose conv, input is transpose_conv_output + feature channels
+                    # With transpose conv, input is transpose_conv_output + skip_channels
                     first_conv_in_channels = (
-                        refine_convs_filters + transpose_convs_filters
+                        self.skip_channels + transpose_convs_filters
                     )
         else:
             if not self.feat_concat:
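A sketch of the resulting channel arithmetic, with illustrative values (all numbers assumed; names mirror the diff):

    x_in_shape = 768            # upsampled feature channels at this level (assumed)
    skip_channels = 384         # actual encoder channels, e.g. from ConvNeXt/SwinT (assumed)
    refine_convs_filters = 256  # decoder filter count (assumed)

    up_interpolate = True
    if up_interpolate:
        # Interpolation preserves x_in_shape before the skip is concatenated.
        first_conv_in_channels = x_in_shape + skip_channels  # 1152
    else:
        # A transpose conv outputs transpose_convs_filters channels instead.
        transpose_convs_filters = 256
        first_conv_in_channels = skip_channels + transpose_convs_filters  # 640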
@@ -582,6 +590,7 @@ class Decoder(nn.Module):
         block_contraction: bool = False,
         up_interpolate: bool = True,
         prefix: str = "dec",
+        encoder_channels: Optional[List[int]] = None,
     ) -> None:
         """Initialize the class."""
         super().__init__()
@@ -598,6 +607,7 @@ class Decoder(nn.Module):
         self.block_contraction = block_contraction
         self.prefix = prefix
         self.stride_to_filters = {}
+        self.encoder_channels = encoder_channels
 
         self.current_strides = []
         self.residuals = 0
@@ -624,6 +634,13 @@
 
             next_stride = current_stride // 2
 
+            # Determine skip channels for this decoder block
+            # If encoder_channels provided, use actual encoder channels
+            # Otherwise fall back to computed filters (for UNet compatibility)
+            skip_channels = None
+            if encoder_channels is not None and block < len(encoder_channels):
+                skip_channels = encoder_channels[block]
+
             if self.stem_blocks > 0 and block >= down_blocks + self.stem_blocks:
                 # This accounts for the case where we dont have any more down block features to concatenate with.
                 # In this case, add a simple upsampling block with a conv layer and with no concatenation
@@ -642,6 +659,7 @@
                         transpose_convs_batch_norm=False,
                         feat_concat=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
             else:
@@ -659,6 +677,7 @@
                         transpose_convs_filters=block_filters_out,
                         transpose_convs_batch_norm=False,
                         prefix=f"{self.prefix}{block}_s{current_stride}_to_s{next_stride}",
+                        skip_channels=skip_channels,
                     )
                 )
 
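Taken together with `SimpleUpsamplingBlock.skip_channels` above, the per-block fallback behaves roughly like this sketch (values illustrative, not library code):

    encoder_channels = [768, 384, 192, 96]  # ConvNeXt/SwinT case; None for UNet
    refine_convs_filters = 256              # assumed fallback width

    for block in range(5):
        skip_channels = None
        if encoder_channels is not None and block < len(encoder_channels):
            skip_channels = encoder_channels[block]
        # Inside SimpleUpsamplingBlock:
        effective = skip_channels if skip_channels is not None else refine_convs_filters
        # blocks 0-3 -> 768, 384, 192, 96; block 4 falls back to 256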
sleap_nn/architectures/swint.py CHANGED
@@ -309,6 +309,13 @@ class SwinTWrapper(nn.Module):
             self.stem_patch_stride * (2**3) * 2
         ) # stem_stride * down_blocks_stride * final_max_pool_stride
 
+        # Encoder channels for skip connections (reversed to match decoder order)
+        # SwinT channels: embed * 2^i for each stage i, then reversed
+        num_stages = len(self.arch["depths"])
+        encoder_channels = [
+            self.arch["embed"] * (2 ** (num_stages - 1 - i)) for i in range(num_stages)
+        ]
+
         self.dec = Decoder(
             x_in_shape=block_filters,
             current_stride=self.current_stride,
@@ -321,6 +328,7 @@
             block_contraction=self.block_contraction,
             output_stride=output_stride,
             up_interpolate=up_interpolate,
+            encoder_channels=encoder_channels,
         )
 
         if len(self.dec.decoder_stack):
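Plugging in the standard Swin-T defaults (embed=96, depths=[2, 2, 6, 2]; assumed here for illustration) gives:

    embed = 96
    depths = [2, 2, 6, 2]
    num_stages = len(depths)  # 4
    encoder_channels = [embed * (2 ** (num_stages - 1 - i)) for i in range(num_stages)]
    # [768, 384, 192, 96] -- deepest stage first, matching the decoder's order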