sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/predict.py
CHANGED
|
@@ -67,16 +67,13 @@ def run_inference(
|
|
|
67
67
|
only_predicted_frames: bool = False,
|
|
68
68
|
no_empty_frames: bool = False,
|
|
69
69
|
batch_size: int = 4,
|
|
70
|
-
queue_maxsize: int =
|
|
70
|
+
queue_maxsize: int = 8,
|
|
71
71
|
video_index: Optional[int] = None,
|
|
72
72
|
video_dataset: Optional[str] = None,
|
|
73
73
|
video_input_format: str = "channels_last",
|
|
74
74
|
frames: Optional[list] = None,
|
|
75
75
|
crop_size: Optional[int] = None,
|
|
76
76
|
peak_threshold: Union[float, List[float]] = 0.2,
|
|
77
|
-
filter_overlapping: bool = False,
|
|
78
|
-
filter_overlapping_method: str = "iou",
|
|
79
|
-
filter_overlapping_threshold: float = 0.8,
|
|
80
77
|
integral_refinement: Optional[str] = "integral",
|
|
81
78
|
integral_patch_size: int = 5,
|
|
82
79
|
return_confmaps: bool = False,
|
|
@@ -113,7 +110,6 @@ def run_inference(
|
|
|
113
110
|
tracking_pre_cull_iou_threshold: float = 0,
|
|
114
111
|
tracking_clean_instance_count: int = 0,
|
|
115
112
|
tracking_clean_iou_threshold: float = 0,
|
|
116
|
-
gui: bool = False,
|
|
117
113
|
):
|
|
118
114
|
"""Entry point to run inference on trained SLEAP-NN models.
|
|
119
115
|
|
|
@@ -151,7 +147,7 @@ def run_inference(
|
|
|
151
147
|
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
152
148
|
no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
|
|
153
149
|
batch_size: (int) Number of samples per batch. Default: 4.
|
|
154
|
-
queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
|
|
150
|
+
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
|
|
155
151
|
video_index: (int) Integer index of video in .slp file to predict on. To be used with
|
|
156
152
|
an .slp path as an alternative to specifying the video path.
|
|
157
153
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -164,15 +160,6 @@ def run_inference(
|
|
|
164
160
|
centroid and centered-instance model, where the first element corresponds
|
|
165
161
|
to centroid model peak finding threshold and the second element is for
|
|
166
162
|
centered-instance model peak finding.
|
|
167
|
-
filter_overlapping: (bool) If True, removes overlapping instances after
|
|
168
|
-
inference using greedy NMS. Applied independently of tracking.
|
|
169
|
-
Default: False.
|
|
170
|
-
filter_overlapping_method: (str) Similarity metric for filtering overlapping
|
|
171
|
-
instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
|
|
172
|
-
Default: "iou".
|
|
173
|
-
filter_overlapping_threshold: (float) Similarity threshold for filtering.
|
|
174
|
-
Instances with similarity > threshold are removed (keeping higher-scoring).
|
|
175
|
-
Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
|
|
176
163
|
integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
|
|
177
164
|
If `"integral"`, peaks will be refined with integral regression.
|
|
178
165
|
Default: `"integral"`.
|
|
@@ -263,8 +250,6 @@ def run_inference(
|
|
|
263
250
|
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
|
|
264
251
|
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
|
|
265
252
|
tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
|
|
266
|
-
gui: (bool) If True, outputs JSON progress lines for GUI integration instead
|
|
267
|
-
of Rich progress bars. Default: False.
|
|
268
253
|
|
|
269
254
|
Returns:
|
|
270
255
|
Returns `sio.Labels` object if `make_labels` is True. Else this function returns
|
|
@@ -448,6 +433,13 @@ def run_inference(
|
|
|
448
433
|
else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
449
434
|
)
|
|
450
435
|
|
|
436
|
+
if integral_refinement is not None and device == "mps": # TODO
|
|
437
|
+
# kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
|
|
438
|
+
logger.info(
|
|
439
|
+
"Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
|
|
440
|
+
)
|
|
441
|
+
integral_refinement = None
|
|
442
|
+
|
|
451
443
|
logger.info(f"Using device: {device}")
|
|
452
444
|
|
|
453
445
|
# initializes the inference model
|
|
@@ -466,9 +458,6 @@ def run_inference(
|
|
|
466
458
|
anchor_part=anchor_part,
|
|
467
459
|
)
|
|
468
460
|
|
|
469
|
-
# Set GUI mode for progress output
|
|
470
|
-
predictor.gui = gui
|
|
471
|
-
|
|
472
461
|
if (
|
|
473
462
|
tracking
|
|
474
463
|
and not isinstance(predictor, BottomUpMultiClassPredictor)
|
|
@@ -564,20 +553,6 @@ def run_inference(
|
|
|
564
553
|
make_labels=make_labels,
|
|
565
554
|
)
|
|
566
555
|
|
|
567
|
-
# Filter overlapping instances (independent of tracking)
|
|
568
|
-
if filter_overlapping and make_labels:
|
|
569
|
-
from sleap_nn.inference.postprocessing import filter_overlapping_instances
|
|
570
|
-
|
|
571
|
-
output = filter_overlapping_instances(
|
|
572
|
-
output,
|
|
573
|
-
threshold=filter_overlapping_threshold,
|
|
574
|
-
method=filter_overlapping_method,
|
|
575
|
-
)
|
|
576
|
-
logger.info(
|
|
577
|
-
f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
|
|
578
|
-
f"threshold: {filter_overlapping_threshold}"
|
|
579
|
-
)
|
|
580
|
-
|
|
581
556
|
if tracking:
|
|
582
557
|
lfs = [x for x in output]
|
|
583
558
|
if tracking_clean_instance_count > 0:
|
|
@@ -632,9 +607,6 @@ def run_inference(
|
|
|
632
607
|
# Build inference parameters for provenance
|
|
633
608
|
inference_params = {
|
|
634
609
|
"peak_threshold": peak_threshold,
|
|
635
|
-
"filter_overlapping": filter_overlapping,
|
|
636
|
-
"filter_overlapping_method": filter_overlapping_method,
|
|
637
|
-
"filter_overlapping_threshold": filter_overlapping_threshold,
|
|
638
610
|
"integral_refinement": integral_refinement,
|
|
639
611
|
"integral_patch_size": integral_patch_size,
|
|
640
612
|
"batch_size": batch_size,
|
sleap_nn/train.py
CHANGED
|
@@ -118,70 +118,6 @@ def run_training(
|
|
|
118
118
|
logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
|
|
119
119
|
logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
|
|
120
120
|
|
|
121
|
-
# Log test metrics to wandb summary
|
|
122
|
-
if (
|
|
123
|
-
d_name.startswith("test")
|
|
124
|
-
and trainer.config.trainer_config.use_wandb
|
|
125
|
-
):
|
|
126
|
-
import wandb
|
|
127
|
-
|
|
128
|
-
if wandb.run is not None:
|
|
129
|
-
summary_metrics = {
|
|
130
|
-
f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
|
|
131
|
-
f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
|
|
132
|
-
"oks_voc.mAP"
|
|
133
|
-
],
|
|
134
|
-
f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
|
|
135
|
-
"oks_voc.mAR"
|
|
136
|
-
],
|
|
137
|
-
f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
|
|
138
|
-
f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
|
|
139
|
-
f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
|
|
140
|
-
f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
|
|
141
|
-
"avg"
|
|
142
|
-
],
|
|
143
|
-
f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
|
|
144
|
-
"p50"
|
|
145
|
-
],
|
|
146
|
-
f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
|
|
147
|
-
"p95"
|
|
148
|
-
],
|
|
149
|
-
f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
|
|
150
|
-
"p99"
|
|
151
|
-
],
|
|
152
|
-
f"eval/{d_name}/visibility_precision": metrics[
|
|
153
|
-
"visibility_metrics"
|
|
154
|
-
]["precision"],
|
|
155
|
-
f"eval/{d_name}/visibility_recall": metrics[
|
|
156
|
-
"visibility_metrics"
|
|
157
|
-
]["recall"],
|
|
158
|
-
}
|
|
159
|
-
for key, value in summary_metrics.items():
|
|
160
|
-
wandb.run.summary[key] = value
|
|
161
|
-
|
|
162
|
-
# Finish wandb run and cleanup after all evaluation is complete
|
|
163
|
-
if trainer.config.trainer_config.use_wandb:
|
|
164
|
-
import wandb
|
|
165
|
-
import shutil
|
|
166
|
-
|
|
167
|
-
if wandb.run is not None:
|
|
168
|
-
wandb.finish()
|
|
169
|
-
|
|
170
|
-
# Delete local wandb logs if configured
|
|
171
|
-
wandb_config = trainer.config.trainer_config.wandb
|
|
172
|
-
should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
|
|
173
|
-
wandb_config.delete_local_logs is None
|
|
174
|
-
and wandb_config.wandb_mode != "offline"
|
|
175
|
-
)
|
|
176
|
-
if should_delete_wandb_logs:
|
|
177
|
-
wandb_dir = run_path / "wandb"
|
|
178
|
-
if wandb_dir.exists():
|
|
179
|
-
logger.info(
|
|
180
|
-
f"Deleting local wandb logs at {wandb_dir}... "
|
|
181
|
-
"(set trainer_config.wandb.delete_local_logs=false to disable)"
|
|
182
|
-
)
|
|
183
|
-
shutil.rmtree(wandb_dir, ignore_errors=True)
|
|
184
|
-
|
|
185
121
|
|
|
186
122
|
def train(
|
|
187
123
|
train_labels_path: Optional[List[str]] = None,
|
|
@@ -203,9 +139,9 @@ def train(
|
|
|
203
139
|
crop_size: Optional[int] = None,
|
|
204
140
|
min_crop_size: Optional[int] = 100,
|
|
205
141
|
crop_padding: Optional[int] = None,
|
|
206
|
-
use_augmentations_train: bool =
|
|
142
|
+
use_augmentations_train: bool = False,
|
|
207
143
|
intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
|
|
208
|
-
geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] =
|
|
144
|
+
geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
|
|
209
145
|
init_weight: str = "default",
|
|
210
146
|
pretrained_backbone_weights: Optional[str] = None,
|
|
211
147
|
pretrained_head_weights: Optional[str] = None,
|
|
@@ -306,7 +242,7 @@ def train(
|
|
|
306
242
|
crop size. If `None`, padding is auto-computed based on augmentation settings.
|
|
307
243
|
Only used when `crop_size` is `None`. Default: None.
|
|
308
244
|
use_augmentations_train: True if the data augmentation should be applied to the
|
|
309
|
-
training data, else False. Default:
|
|
245
|
+
training data, else False. Default: False.
|
|
310
246
|
intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
|
|
311
247
|
or list of strings from the above allowed values. To have custom values, pass
|
|
312
248
|
a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
|
|
@@ -318,8 +254,7 @@ def train(
|
|
|
318
254
|
or list of strings from the above allowed values. To have custom values, pass
|
|
319
255
|
a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
|
|
320
256
|
For eg: {
|
|
321
|
-
"
|
|
322
|
-
"rotation_max": 45,
|
|
257
|
+
"rotation": 45,
|
|
323
258
|
"affine_p": 1.0
|
|
324
259
|
}
|
|
325
260
|
init_weight: model weights initialization method. "default" uses kaiming uniform
|