sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/training/utils.py
CHANGED
@@ -1,13 +1,19 @@
 """Miscellaneous utility functions for training."""

+from dataclasses import dataclass, field
+from io import BytesIO
 import numpy as np
+import matplotlib
+
+matplotlib.use(
+    "Agg"
+)  # Use non-interactive backend to avoid tkinter issues on Windows CI
 import matplotlib.pyplot as plt
 from loguru import logger
 from torch import nn
 import torch.distributed as dist
-import matplotlib
 import seaborn as sns
-from typing import List
+from typing import List, Optional
 import shutil
 import os
 import subprocess
@@ -236,3 +242,362 @@ def plot_peaks(
         )
     )
     return handles
+
+
+@dataclass
+class VisualizationData:
+    """Container for visualization data from a single sample.
+
+    This dataclass decouples data extraction from rendering, allowing the same
+    data to be rendered to different output targets (matplotlib, wandb, etc.).
+
+    Attributes:
+        image: Input image as (H, W, C) numpy array, normalized to [0, 1].
+        pred_confmaps: Predicted confidence maps as (H, W, nodes) array, values in [0, 1].
+        pred_peaks: Predicted keypoints as (instances, nodes, 2) or (nodes, 2) array.
+        pred_peak_values: Confidence values as (instances, nodes) or (nodes,) array.
+        gt_instances: Ground truth keypoints, same shape as pred_peaks.
+        node_names: List of node/keypoint names, e.g., ["head", "thorax", ...].
+        output_scale: Ratio of confmap size to image size (confmap_h / image_h).
+        is_paired: Whether GT and predictions can be paired for error visualization.
+        pred_pafs: Part affinity fields for bottom-up models, optional.
+        pred_class_maps: Class maps for multi-class models, optional.
+    """
+
+    image: np.ndarray
+    pred_confmaps: np.ndarray
+    pred_peaks: np.ndarray
+    pred_peak_values: np.ndarray
+    gt_instances: np.ndarray
+    node_names: List[str] = field(default_factory=list)
+    output_scale: float = 1.0
+    is_paired: bool = True
+    pred_pafs: Optional[np.ndarray] = None
+    pred_class_maps: Optional[np.ndarray] = None
+
+
+class MatplotlibRenderer:
+    """Renders VisualizationData to matplotlib figures."""
+
+    def render(self, data: VisualizationData) -> matplotlib.figure.Figure:
+        """Render visualization data to a matplotlib figure.
+
+        Args:
+            data: VisualizationData containing image, confmaps, peaks, etc.
+
+        Returns:
+            A matplotlib Figure object.
+        """
+        img = data.image
+        scale = 1.0
+        if img.shape[0] < 512:
+            scale = 2.0
+        if img.shape[0] < 256:
+            scale = 4.0
+
+        fig = plot_img(img, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
+        return fig
+
+    def render_pafs(self, data: VisualizationData) -> matplotlib.figure.Figure:
+        """Render PAF magnitude visualization.
+
+        Args:
+            data: VisualizationData with pred_pafs populated.
+
+        Returns:
+            A matplotlib Figure object showing PAF magnitudes.
+        """
+        if data.pred_pafs is None:
+            raise ValueError("pred_pafs is None, cannot render PAFs")
+
+        img = data.image
+        scale = 1.0
+        if img.shape[0] < 512:
+            scale = 2.0
+        if img.shape[0] < 256:
+            scale = 4.0
+
+        # Compute PAF magnitude
+        pafs = data.pred_pafs  # (H, W, 2*edges) or (H, W, edges, 2)
+        if pafs.ndim == 3:
+            n_edges = pafs.shape[-1] // 2
+            pafs = pafs.reshape(pafs.shape[0], pafs.shape[1], n_edges, 2)
+        magnitude = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
+        magnitude = magnitude.max(axis=-1)  # Max over edges
+
+        fig = plot_img(img, dpi=72 * scale, scale=scale)
+        ax = plt.gca()
+
+        # Calculate PAF output scale from actual PAF dimensions, not confmap output_scale
+        # PAFs may have a different output_stride than confmaps
+        paf_output_scale = magnitude.shape[0] / img.shape[0]
+
+        ax.imshow(
+            magnitude,
+            alpha=0.5,
+            origin="upper",
+            cmap="viridis",
+            extent=[
+                -0.5,
+                magnitude.shape[1] / paf_output_scale - 0.5,
+                magnitude.shape[0] / paf_output_scale - 0.5,
+                -0.5,
+            ],
+        )
+        return fig
+
+
+class WandBRenderer:
+    """Renders VisualizationData to wandb.Image objects.
+
+    Supports multiple rendering modes:
+    - "direct": Pre-render with matplotlib, convert to wandb.Image
+    - "boxes": Use wandb boxes for interactive keypoint visualization
+    - "masks": Use wandb masks for confidence map overlay
+    """
+
+    def __init__(
+        self,
+        mode: str = "direct",
+        box_size: float = 5.0,
+        confmap_threshold: float = 0.1,
+        min_size: int = 512,
+    ):
+        """Initialize the renderer.
+
+        Args:
+            mode: Rendering mode - "direct", "boxes", or "masks".
+            box_size: Size of keypoint boxes in pixels (for "boxes" mode).
+            confmap_threshold: Threshold for confmap mask (for "masks" mode).
+            min_size: Minimum image dimension. Smaller images will be upscaled.
+        """
+        self.mode = mode
+        self.box_size = box_size
+        self.confmap_threshold = confmap_threshold
+        self.min_size = min_size
+        self._mpl_renderer = MatplotlibRenderer()
+
+    def render(
+        self, data: VisualizationData, caption: Optional[str] = None
+    ) -> "wandb.Image":
+        """Render visualization data to a wandb.Image.
+
+        Args:
+            data: VisualizationData containing image, confmaps, peaks, etc.
+            caption: Optional caption for the image.
+
+        Returns:
+            A wandb.Image object.
+        """
+        import wandb
+
+        if self.mode == "boxes":
+            return self._render_with_boxes(data, caption)
+        elif self.mode == "masks":
+            return self._render_with_masks(data, caption)
+        else:  # "direct"
+            return self._render_direct(data, caption)
+
+    def _get_scale_factor(self, img_h: int, img_w: int) -> int:
+        """Calculate scale factor to ensure minimum image size."""
+        min_dim = min(img_h, img_w)
+        if min_dim >= self.min_size:
+            return 1
+        return int(np.ceil(self.min_size / min_dim))
+
+    def _render_direct(
+        self, data: VisualizationData, caption: Optional[str] = None
+    ) -> "wandb.Image":
+        """Pre-render with matplotlib, return as wandb.Image."""
+        import wandb
+        from PIL import Image
+
+        fig = self._mpl_renderer.render(data)
+
+        # Convert figure to PIL Image
+        buf = BytesIO()
+        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+        buf.seek(0)
+        plt.close(fig)
+
+        pil_image = Image.open(buf)
+        return wandb.Image(pil_image, caption=caption)
+
+    def _render_with_boxes(
+        self, data: VisualizationData, caption: Optional[str] = None
+    ) -> "wandb.Image":
+        """Use wandb boxes for interactive keypoint visualization."""
+        import wandb
+        from PIL import Image
+
+        # Prepare class labels from node names
+        class_labels = {i: name for i, name in enumerate(data.node_names)}
+        if not class_labels:
+            class_labels = {i: f"node_{i}" for i in range(data.pred_peaks.shape[-2])}
+
+        # Convert image to uint8
+        img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
+        # Handle single-channel images: squeeze (H, W, 1) -> (H, W)
+        if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
+            img_uint8 = img_uint8.squeeze(axis=2)
+        img_h, img_w = img_uint8.shape[:2]
+
+        # Scale up small images for better visibility in wandb
+        scale = self._get_scale_factor(img_h, img_w)
+        if scale > 1:
+            pil_img = Image.fromarray(img_uint8)
+            pil_img = pil_img.resize(
+                (img_w * scale, img_h * scale), resample=Image.BILINEAR
+            )
+            img_uint8 = np.array(pil_img)
+
+        # Build ground truth boxes (use percent domain for proper thumbnail scaling)
+        gt_box_data = self._peaks_to_boxes(
+            data.gt_instances, data.node_names, img_w, img_h, is_gt=True
+        )
+
+        # Build prediction boxes
+        pred_box_data = self._peaks_to_boxes(
+            data.pred_peaks,
+            data.node_names,
+            img_w,
+            img_h,
+            peak_values=data.pred_peak_values,
+            is_gt=False,
+        )
+
+        return wandb.Image(
+            img_uint8,
+            boxes={
+                "ground_truth": {"box_data": gt_box_data, "class_labels": class_labels},
+                "predictions": {
+                    "box_data": pred_box_data,
+                    "class_labels": class_labels,
+                },
+            },
+            caption=caption,
+        )
+
+    def _peaks_to_boxes(
+        self,
+        peaks: np.ndarray,
+        node_names: List[str],
+        img_w: int,
+        img_h: int,
+        peak_values: Optional[np.ndarray] = None,
+        is_gt: bool = False,
+    ) -> List[dict]:
+        """Convert peaks array to wandb box_data format.
+
+        Args:
+            peaks: Keypoints as (instances, nodes, 2) or (nodes, 2).
+            node_names: List of node names.
+            img_w: Image width in pixels.
+            img_h: Image height in pixels.
+            peak_values: Optional confidence values.
+            is_gt: Whether these are ground truth points.
+
+        Returns:
+            List of box dictionaries for wandb.
+        """
+        box_data = []
+
+        # Normalize shape to (instances, nodes, 2)
+        if peaks.ndim == 2:
+            peaks = peaks[np.newaxis, ...]
+            if peak_values is not None and peak_values.ndim == 1:
+                peak_values = peak_values[np.newaxis, ...]
+
+        # Convert box_size from pixels to percent
+        box_w_pct = self.box_size / img_w
+        box_h_pct = self.box_size / img_h
+
+        for inst_idx, instance in enumerate(peaks):
+            for node_idx, (x, y) in enumerate(instance):
+                if np.isnan(x) or np.isnan(y):
+                    continue
+
+                node_name = (
+                    node_names[node_idx]
+                    if node_idx < len(node_names)
+                    else f"node_{node_idx}"
+                )
+
+                # Convert pixel coordinates to percent (0-1 range)
+                x_pct = float(x) / img_w
+                y_pct = float(y) / img_h
+
+                box = {
+                    "position": {
+                        "middle": [x_pct, y_pct],
+                        "width": box_w_pct,
+                        "height": box_h_pct,
+                    },
+                    "domain": "percent",
+                    "class_id": node_idx,
+                }
+
+                if is_gt:
+                    box["box_caption"] = f"GT: {node_name}"
+                else:
+                    if peak_values is not None:
+                        conf = float(peak_values[inst_idx, node_idx])
+                        box["box_caption"] = f"{node_name} ({conf:.2f})"
+                        box["scores"] = {"confidence": conf}
+                    else:
+                        box["box_caption"] = node_name
+
+                box_data.append(box)
+
+        return box_data
+
+    def _render_with_masks(
+        self, data: VisualizationData, caption: Optional[str] = None
+    ) -> "wandb.Image":
+        """Use wandb masks for confidence map overlay.
+
+        Uses argmax approach: each pixel shows the dominant node.
+        """
+        import wandb
+
+        # Prepare class labels (0 = background, 1+ = nodes)
+        class_labels = {0: "background"}
+        for i, name in enumerate(data.node_names):
+            class_labels[i + 1] = name
+        if not data.node_names:
+            n_nodes = data.pred_confmaps.shape[-1]
+            for i in range(n_nodes):
+                class_labels[i + 1] = f"node_{i}"
+
+        # Create argmax mask from confmaps
+        confmaps = data.pred_confmaps  # (H/stride, W/stride, nodes)
+        max_vals = confmaps.max(axis=-1)
+        argmax_map = confmaps.argmax(axis=-1) + 1  # +1 for background offset
+        argmax_map[max_vals < self.confmap_threshold] = 0  # Background
+
+        # Convert image to uint8
+        img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
+        # Handle single-channel images: (H, W, 1) -> (H, W)
+        if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
+            img_uint8 = img_uint8.squeeze(axis=2)
+        img_h, img_w = img_uint8.shape[:2]
+
+        # Resize mask to match image dimensions (confmaps are H/stride, W/stride)
+        from PIL import Image
+
+        mask_pil = Image.fromarray(argmax_map.astype(np.uint8))
+        mask_pil = mask_pil.resize((img_w, img_h), resample=Image.NEAREST)
+        argmax_map = np.array(mask_pil)
+
+        return wandb.Image(
+            img_uint8,
+            masks={
+                "confidence_maps": {
+                    "mask_data": argmax_map.astype(np.uint8),
+                    "class_labels": class_labels,
+                }
+            },
+            caption=caption,
+        )
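For orientation, here is a minimal usage sketch (not part of the diff) of the visualization classes added above. It assumes sleap-nn 0.1.0a0 is installed along with matplotlib, wandb, and Pillow, and that `VisualizationData`, `MatplotlibRenderer`, and `WandBRenderer` are importable from `sleap_nn.training.utils` as the hunk defines them; the array shapes, node names, and output file name are invented for illustration.

```python
# Illustrative sketch only: synthetic arrays stand in for real model outputs.
import numpy as np

from sleap_nn.training.utils import (
    MatplotlibRenderer,
    VisualizationData,
    WandBRenderer,
)

rng = np.random.default_rng(0)

# A 256x256 grayscale frame with 3-node confidence maps at output stride 2.
data = VisualizationData(
    image=rng.random((256, 256, 1)).astype(np.float32),
    pred_confmaps=rng.random((128, 128, 3)).astype(np.float32),
    pred_peaks=(rng.random((2, 3, 2)) * 256).astype(np.float32),
    pred_peak_values=rng.random((2, 3)).astype(np.float32),
    gt_instances=(rng.random((2, 3, 2)) * 256).astype(np.float32),
    node_names=["head", "thorax", "abdomen"],
    output_scale=0.5,  # confmap_h / image_h
)

# Render to a matplotlib figure and save it locally.
fig = MatplotlibRenderer().render(data)
fig.savefig("viz_sample.png")

# Or build a wandb.Image with interactive keypoint boxes (requires wandb + Pillow).
wandb_image = WandBRenderer(mode="boxes").render(data, caption="epoch 0, sample 0")
```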
{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.
+Version: 0.1.0a0
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io
+Requires-Dist: sleap-io<0.7.0,>=0.6.0
 Requires-Dist: numpy
 Requires-Dist: lightning
 Requires-Dist: kornia
@@ -34,35 +34,19 @@ Requires-Dist: jupyterlab
 Requires-Dist: pyzmq
 Provides-Extra: torch
 Requires-Dist: torch; extra == "torch"
-Requires-Dist: torchvision
+Requires-Dist: torchvision>=0.20.0; extra == "torch"
 Provides-Extra: torch-cpu
 Requires-Dist: torch; extra == "torch-cpu"
-Requires-Dist: torchvision
+Requires-Dist: torchvision>=0.20.0; extra == "torch-cpu"
 Provides-Extra: torch-cuda118
 Requires-Dist: torch; extra == "torch-cuda118"
-Requires-Dist: torchvision
+Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda118"
 Provides-Extra: torch-cuda128
 Requires-Dist: torch; extra == "torch-cuda128"
-Requires-Dist: torchvision
-Provides-Extra:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: black; extra == "dev"
-Requires-Dist: pydocstyle; extra == "dev"
-Requires-Dist: toml; extra == "dev"
-Requires-Dist: twine; extra == "dev"
-Requires-Dist: build; extra == "dev"
-Requires-Dist: ipython; extra == "dev"
-Requires-Dist: ruff; extra == "dev"
-Provides-Extra: docs
-Requires-Dist: mkdocs; extra == "docs"
-Requires-Dist: mkdocs-material; extra == "docs"
-Requires-Dist: mkdocs-jupyter; extra == "docs"
-Requires-Dist: mike; extra == "docs"
-Requires-Dist: mkdocstrings[python]; extra == "docs"
-Requires-Dist: mkdocs-gen-files; extra == "docs"
-Requires-Dist: mkdocs-literate-nav; extra == "docs"
-Requires-Dist: mkdocs-section-index; extra == "docs"
+Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
+Provides-Extra: torch-cuda130
+Requires-Dist: torch; extra == "torch-cuda130"
+Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
 Dynamic: license-file

 # sleap-nn
@@ -120,22 +104,28 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
 > Replace `...` with the rest of your install command as needed.

 - Sync all dependencies based on your correct wheel using `uv sync`. `uv sync` creates a `.venv` (virtual environment) inside your current working directory. This environment is only active within that directory and can't be directly accessed from outside. To use all installed packages, you must run commands with `uv run` (e.g., `uv run sleap-nn train ...` or `uv run pytest ...`).
-- **Windows/Linux with NVIDIA GPU (CUDA
+- **Windows/Linux with NVIDIA GPU (CUDA 13.0):**

 ```bash
-uv sync --extra
+uv sync --extra torch-cuda130
 ```

 - **Windows/Linux with NVIDIA GPU (CUDA 12.8):**

 ```bash
-uv sync --extra
+uv sync --extra torch-cuda128
 ```
-
-- **
+
+- **Windows/Linux with NVIDIA GPU (CUDA 11.8):**
+
+```bash
+uv sync --extra torch-cuda118
+```
+
+- **macOS with Apple Silicon (M1, M2, M3, M4) or CPU-only (no GPU or unsupported GPU):**
 Note: Even if torch-cpu is used on macOS, the MPS backend will be available.
 ```bash
-uv sync --extra
+uv sync --extra torch-cpu
 ```

 4. **Run tests**
@@ -152,6 +142,6 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
 > **Upgrading All Dependencies**
 > To ensure you have the latest versions of all dependencies, use the `--upgrade` flag with `uv sync`:
 > ```bash
-> uv sync --
+> uv sync --upgrade
 > ```
 > This will upgrade all installed packages in your environment to the latest available versions compatible with your `pyproject.toml`.
{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
 sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
-sleap_nn/__init__.py,sha256=
-sleap_nn/cli.py,sha256=
-sleap_nn/evaluation.py,sha256=
+sleap_nn/__init__.py,sha256=DzQeiZIFUmfhpf6mk4j1AKAY2bofVMyIa31xbiSu-ls,1317
+sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
+sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
 sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
-sleap_nn/predict.py,sha256=
-sleap_nn/
+sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
+sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
+sleap_nn/train.py,sha256=fWx_b1HqkadQ-GM_VEM1frCd8WkzJLqRARBNn8UoUbo,27181
 sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
 sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
 sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
@@ -15,49 +16,50 @@ sleap_nn/architectures/swint.py,sha256=S66Wd0j8Hp-rGlv1C60WSw3AwGyAyGetgfwpL0nIK
 sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcgo,11723
 sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
 sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
-sleap_nn/config/data_config.py,sha256=
-sleap_nn/config/get_config.py,sha256=
+sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
+sleap_nn/config/get_config.py,sha256=vN_aOPTj9F-QBqGGfVSv8_aFSAYl-RfXY0pdbdcqjcM,42021
 sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
-sleap_nn/config/trainer_config.py,sha256=
+sleap_nn/config/trainer_config.py,sha256=PaoNtRSNc2xgzwN955aR9kTZL8IxCWdevGljLxS6jOk,28073
 sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
 sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
 sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
-sleap_nn/data/augmentation.py,sha256=
+sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
 sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
-sleap_nn/data/custom_datasets.py,sha256=
+sleap_nn/data/custom_datasets.py,sha256=2qAyLeiCPI9uudFFP7zlj6d_tbxc5OVzpnzT23mRkVw,98472
 sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
 sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
 sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
-sleap_nn/data/instance_cropping.py,sha256=
-sleap_nn/data/normalization.py,sha256=
-sleap_nn/data/providers.py,sha256=
+sleap_nn/data/instance_cropping.py,sha256=2dYq5OTwkFN1PdMjoxyuMuHq1OEe03m3Vzqvcs_dkPE,8304
+sleap_nn/data/normalization.py,sha256=5xEvcguG-fvAGObl4nWPZ9TEM5gvv0uYPGDuni34XII,2930
+sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14272
 sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
 sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
-sleap_nn/inference/__init__.py,sha256=
+sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
 sleap_nn/inference/bottomup.py,sha256=NqN-G8TzAOsvCoL3bttEjA1iGsuveLOnOCXIUeFCdSA,13684
 sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
 sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
-sleap_nn/inference/peak_finding.py,sha256=
-sleap_nn/inference/predictors.py,sha256=
+sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
+sleap_nn/inference/predictors.py,sha256=U114RlgOXKGm5iz1lnTfE3aN9S0WCh6gWhVP3KVewfc,158046
+sleap_nn/inference/provenance.py,sha256=0BekXyvpLMb0Vv6DjpctlLduG9RN-Q8jt5zDm783eZE,11204
 sleap_nn/inference/single_instance.py,sha256=rOns_5TsJ1rb-lwmHG3ZY-pOhXGN2D-SfW9RmBxxzcI,4089
-sleap_nn/inference/topdown.py,sha256=
+sleap_nn/inference/topdown.py,sha256=Ha0Nwx-XCH_rebIuIGhP0qW68QpjLB3XRr9rxt05JLs,35108
 sleap_nn/inference/utils.py,sha256=JnaJK4S_qLtHkWOSkHf4oRZjOmgnU9BGADQnntgGxxs,4689
 sleap_nn/tracking/__init__.py,sha256=rGR35wpSW-n5d3cMiQUzQQ_Dy5II5DPjlXAoPw2QhmM,31
 sleap_nn/tracking/track_instance.py,sha256=9k0uVy9VmpleaLcJh7sVWSeFUPXiw7yj95EYNdXJcks,1373
-sleap_nn/tracking/tracker.py,sha256=
+sleap_nn/tracking/tracker.py,sha256=_WT-HFruzyOsvcq3AtLm3vnI9MYSwyBmq-HlQvj1vmU,41955
 sleap_nn/tracking/utils.py,sha256=uHVd_mzzZjviVDdLSKXJJ1T96n5ObKvkqIuGsl9Yy8U,11276
 sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j4j1TvO5scSE,49
 sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
 sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
 sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
-sleap_nn/training/callbacks.py,sha256=
-sleap_nn/training/lightning_modules.py,sha256=
+sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
+sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
 sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
-sleap_nn/training/model_trainer.py,sha256=
-sleap_nn/training/utils.py,sha256=
-sleap_nn-0.
-sleap_nn-0.
-sleap_nn-0.
-sleap_nn-0.
-sleap_nn-0.
-sleap_nn-0.
+sleap_nn/training/model_trainer.py,sha256=InDKHrQxBwbltZKutW4yrBR9NThLdRpWNUGhmB0xAi4,57863
+sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
+sleap_nn-0.1.0a0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a0.dist-info/METADATA,sha256=lxSmGNTUg9eetqHCvhw8Tv5zJtia-dIM5RzOeoDccj8,5637
+sleap_nn-0.1.0a0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a0.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a0.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a0.dist-info/RECORD,,
{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL
File without changes

{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt
File without changes

{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE
File without changes

{sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt
File without changes