sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,443 @@
1
+ """System diagnostics and compatibility checking for sleap-nn."""
2
+
3
+ import importlib.metadata
4
+ import json
5
+ import platform
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ from typing import Optional
10
+
11
# Distribution names (as used by pip / importlib.metadata) whose version,
# install source, and location are reported by the diagnostics, in display
# order. Written one per line and split on whitespace into a list.
PACKAGES = """
    sleap-nn
    sleap-io
    torch
    pytorch-lightning
    kornia
    wandb
    numpy
    h5py
""".split()
22
+
23
# Minimum NVIDIA driver per CUDA release, keyed by "major.minor":
#   CUDA version -> (min_driver_linux, min_driver_windows)
# Source: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/
# NOTE(review): the Windows minimum for 13.0 ("580.00") looks like a
# placeholder; confirm against the release notes table.
CUDA_DRIVER_REQUIREMENTS = {
    cuda: (linux_min, windows_min)
    for cuda, linux_min, windows_min in (
        # CUDA    Linux          Windows
        ("12.6", "560.28.03", "560.76"),
        ("12.8", "570.26", "570.65"),
        ("13.0", "580.65.06", "580.00"),
    )
}
30
+
31
+
32
def parse_driver_version(version: str) -> tuple[int, ...]:
    """Convert a dotted driver version into a tuple of ints for comparison.

    Example: ``"560.28.03"`` -> ``(560, 28, 3)``. If any dot-separated
    component is not an integer (including an empty string), the version is
    treated as unparseable and ``(0,)`` is returned, which sorts below every
    real version.
    """
    components = version.split(".")
    try:
        return tuple(map(int, components))
    except ValueError:
        return (0,)
38
+
39
+
40
def get_min_driver_for_cuda(cuda_version: str) -> Optional[tuple[str, str]]:
    """Look up the minimum NVIDIA driver versions for a CUDA release.

    Only the ``major.minor`` prefix of ``cuda_version`` is used, so a patch
    release such as ``"12.6.1"`` resolves to the ``"12.6"`` entry of
    ``CUDA_DRIVER_REQUIREMENTS``.

    Args:
        cuda_version: CUDA version string (e.g., "12.6" or "12.6.1"). An empty
            or ``None`` value yields ``None``.

    Returns:
        ``(min_linux, min_windows)`` driver version strings, or ``None`` when
        the version has no minor component or is not in the table.
    """
    if not cuda_version:
        return None

    major, dot, remainder = cuda_version.partition(".")
    if not dot:
        # A bare major version ("12") is too coarse to look up.
        return None

    minor = remainder.split(".", 1)[0]
    return CUDA_DRIVER_REQUIREMENTS.get(f"{major}.{minor}")
57
+
58
+
59
def check_driver_compatibility(
    driver_version: str, cuda_version: str
) -> tuple[bool, Optional[str]]:
    """Decide whether the installed driver meets the CUDA release's minimum.

    The minimum is taken from ``get_min_driver_for_cuda``. The Windows column
    is used when ``sys.platform == "win32"``; otherwise the Linux column is
    used.

    Args:
        driver_version: Installed driver version string (e.g., "570.86.10").
        cuda_version: CUDA version PyTorch was built with (e.g., "12.8").

    Returns:
        ``(is_compatible, min_required_version)``. When the CUDA release is not
        in the requirements table, returns ``(True, None)`` (check skipped).
    """
    requirements = get_min_driver_for_cuda(cuda_version)
    if not requirements:
        # Unknown CUDA release: nothing to compare against.
        return True, None

    column = 1 if sys.platform == "win32" else 0
    min_version = requirements[column]

    installed = parse_driver_version(driver_version)
    needed = parse_driver_version(min_version)

    # Zero-pad the shorter tuple so e.g. "570.65" and "570.65.0" compare equal.
    width = max(len(installed), len(needed))
    installed += (0,) * (width - len(installed))
    needed += (0,) * (width - len(needed))

    return installed >= needed, min_version
90
+
91
+
92
def get_nvidia_driver_version() -> Optional[str]:
    """Return the NVIDIA driver version reported by ``nvidia-smi``.

    This is a best-effort probe. It returns ``None`` in any of these cases:
    ``nvidia-smi`` is not on PATH; it fails to run or exceeds the 5 second
    timeout; it exits non-zero; or it prints no version. With several GPUs,
    ``nvidia-smi`` prints one line per GPU, and the first non-empty line is
    returned.
    """
    if not shutil.which("nvidia-smi"):
        return None

    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: cannot execute; SubprocessError: includes TimeoutExpired;
        # ValueError: includes UnicodeDecodeError from undecodable output.
        # Diagnostics must keep going, so treat all of these as "unknown".
        return None

    if result.returncode != 0:
        return None

    # Return None (not "") when the output is empty, matching Optional[str].
    for line in result.stdout.splitlines():
        stripped = line.strip()
        if stripped:
            return stripped
    return None
108
+
109
+
110
+ def _get_package_location(name: str, dist) -> str:
111
+ """Get the actual installed location of a package.
112
+
113
+ For editable installs, returns the source directory.
114
+ For regular installs, returns the site-packages path.
115
+ """
116
+ # Try to import the package and get its __file__
117
+ try:
118
+ # Convert package name to module name (e.g., "sleap-nn" -> "sleap_nn")
119
+ module_name = name.replace("-", "_")
120
+ module = __import__(module_name)
121
+ if hasattr(module, "__file__") and module.__file__:
122
+ # Return parent of the package's __init__.py
123
+ from pathlib import Path
124
+
125
+ return str(Path(module.__file__).parent.parent)
126
+ except (ImportError, AttributeError):
127
+ pass
128
+
129
+ # Fallback to dist._path
130
+ if dist._path:
131
+ path = dist._path.parent
132
+ # Make absolute if relative
133
+ if not path.is_absolute():
134
+ from pathlib import Path
135
+
136
+ path = Path.cwd() / path
137
+ return str(path)
138
+
139
+ return ""
140
+
141
+
142
def get_package_info(name: str) -> dict:
    """Get package version, location, and install source.

    The install source is determined in order:

    1. PEP 610 ``direct_url.json``: ``"editable"``, ``"git"`` (VCS URL) or
       ``"local"`` (``file://`` URL).
    2. Old-style editable installs: an ``.egg-info`` metadata directory that is
       not under ``site-packages`` or ``dist-packages``.
    3. ``INSTALLER`` metadata equal to ``"conda"``.
    4. Otherwise ``"pip"``.

    Missing or malformed metadata files are ignored rather than raised.

    Args:
        name: Package (distribution) name, e.g., "sleap-nn".

    Returns:
        Dict with ``version``, ``location``, ``source``, and ``editable`` keys.
        For a package that is not installed, ``version`` is
        ``"not installed"``, ``location`` and ``source`` are ``""``, and
        ``editable`` is ``False``.
    """
    try:
        dist = importlib.metadata.distribution(name)
    except importlib.metadata.PackageNotFoundError:
        return {
            "version": "not installed",
            "location": "",
            "source": "",
            "editable": False,
        }

    version = dist.version
    is_editable = False
    source = "pip"  # Default assumption

    # PEP 610 direct_url.json records how the package was installed.
    direct_url = None
    try:
        direct_url_text = dist.read_text("direct_url.json")
        if direct_url_text:
            direct_url = json.loads(direct_url_text)
    except (FileNotFoundError, ValueError):
        # ValueError covers json.JSONDecodeError; fall back to heuristics.
        direct_url = None

    if isinstance(direct_url, dict):
        dir_info = direct_url.get("dir_info", {})
        if isinstance(dir_info, dict):
            is_editable = dir_info.get("editable", False)
        if is_editable:
            source = "editable"
        elif "vcs_info" in direct_url:
            source = "git"
        elif str(direct_url.get("url", "")).startswith("file://"):
            source = "local"

    # Fallback: old-style editable installs keep .egg-info in the source tree.
    # Debian/Ubuntu system Python installs into "dist-packages", so treat that
    # like site-packages to avoid reporting regular installs as editable.
    # ``_path`` is private to PathDistribution, so tolerate its absence.
    dist_path = getattr(dist, "_path", None)
    if not is_editable and dist_path is not None:
        path_str = str(dist_path)
        in_install_dir = "site-packages" in path_str or "dist-packages" in path_str
        if ".egg-info" in path_str and not in_install_dir:
            is_editable = True
            source = "editable"

    # Conda records itself in the INSTALLER file (only if source still unknown).
    if source == "pip":
        try:
            installer = dist.read_text("INSTALLER")
            if installer and installer.strip() == "conda":
                source = "conda"
        except FileNotFoundError:
            pass

    # Location is resolved last; it may point at a source checkout for editables.
    location = _get_package_location(name, dist)

    return {
        "version": version,
        "location": location,
        "source": source,
        "editable": is_editable,
    }
205
+
206
+
207
def get_system_info_dict() -> dict:
    """Collect system, accelerator, and package information into a dict.

    Returns:
        Dictionary with keys:

        - ``python_version``, ``platform``, ``pytorch_version``: strings.
        - ``cuda_available``: result of ``torch.cuda.is_available()``.
        - ``cuda_version``: ``torch.version.cuda`` when CUDA is available,
          else ``None``.
        - ``cudnn_version``: the cuDNN version as a string when CUDA is
          available and PyTorch reports one, else ``None``.
        - ``driver_version``: driver version from ``nvidia-smi``. It is
          collected even when CUDA is unavailable; ``None`` if not found.
        - ``driver_compatible``, ``driver_min_required``: the result of
          ``check_driver_compatibility`` when both the driver and CUDA
          versions are known, else ``None``.
        - ``gpu_count`` and ``gpus``: for CUDA, one entry per device with
          ``id``, ``name``, ``compute_capability`` and ``memory_gb``. MPS
          reports a count of 1 and an empty list.
        - ``mps_available``: whether MPS is available.
        - ``accelerator``: ``"cpu"``, ``"cuda"`` or ``"mps"``.
        - ``packages``: the ``get_package_info`` result for each entry in
          ``PACKAGES``.
    """
    import torch

    cuda_available = torch.cuda.is_available()
    info = {
        "python_version": sys.version.split()[0],
        "platform": platform.platform(),
        "pytorch_version": torch.__version__,
        "cuda_available": cuda_available,
        "cuda_version": None,
        "cudnn_version": None,
        "driver_version": None,
        "driver_compatible": None,
        "driver_min_required": None,
        "gpu_count": 0,
        "gpus": [],
        "mps_available": False,
        "accelerator": "cpu",  # cpu, cuda, or mps
        "packages": {},
    }

    # Query the driver even if CUDA is unavailable: an outdated driver is a
    # common reason CUDA is unavailable in the first place.
    driver = get_nvidia_driver_version()
    if driver:
        info["driver_version"] = driver

    if cuda_available:
        info["cuda_version"] = torch.version.cuda
        # cudnn.version() returns None when cuDNN is not available. Store None
        # rather than the truthy string "None" so displays fall back to "N/A".
        cudnn_version = torch.backends.cudnn.version()
        info["cudnn_version"] = None if cudnn_version is None else str(cudnn_version)
        device_count = torch.cuda.device_count()
        info["gpu_count"] = device_count
        info["accelerator"] = "cuda"

        # torch.version.cuda is None on non-CUDA builds (e.g. ROCm); skip then.
        if driver and info["cuda_version"]:
            is_compatible, min_required = check_driver_compatibility(
                driver, info["cuda_version"]
            )
            info["driver_compatible"] = is_compatible
            info["driver_min_required"] = min_required

        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            info["gpus"].append(
                {
                    "id": i,
                    "name": props.name,
                    "compute_capability": f"{props.major}.{props.minor}",
                    "memory_gb": round(props.total_memory / (1024**3), 1),
                }
            )

    # MPS (Apple Silicon)
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        info["mps_available"] = True
        info["accelerator"] = "mps"
        info["gpu_count"] = 1

    for pkg in PACKAGES:
        info["packages"][pkg] = get_package_info(pkg)

    return info
276
+
277
+
278
def test_gpu_operations() -> tuple[bool, Optional[str]]:
    """Check that a small tensor operation actually runs on the accelerator.

    CUDA is preferred, then MPS. A 100x100 matrix product is computed and a
    reduction is copied back to the host. GPU kernels launch asynchronously,
    so the host copy forces completion and surfaces errors that would
    otherwise be deferred.

    Returns:
        ``(success, error_message)``. On success this is ``(True, None)``. If
        the operation fails, it is ``(False, <exception text>)``. If no GPU
        is available, it is ``(False, "No GPU available")``.
    """
    import torch

    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        return False, "No GPU available"

    try:
        x = torch.randn(100, 100, device=device)
        # .item() copies to the host, which synchronizes with the device.
        torch.mm(x, x).sum().item()
    except Exception as e:  # any device/runtime failure is a diagnostic result
        return False, str(e)
    return True, None


# This is a diagnostic helper, not a unit test. The "test_" prefix would
# otherwise make pytest collect it wherever it is imported into a test module.
test_gpu_operations.__test__ = False
301
+
302
+
303
+ def _shorten_path(path: str, max_len: int = 40) -> str:
304
+ """Shorten a path for display, keeping the end."""
305
+ if len(path) <= max_len:
306
+ return path
307
+ return "..." + path[-(max_len - 3) :]
308
+
309
+
310
def print_system_info() -> None:
    """Print system, package, and GPU diagnostics to the console.

    Output has three parts, printed with Rich:

    1. A system table: Python, platform, PyTorch, and accelerator details.
    2. A package table: version, install source, and shortened location.
    3. Actionable checks: driver compatibility with the CUDA version PyTorch
       was built with, a warning when a driver is present without CUDA, and
       a small GPU tensor operation.

    Text that comes from the environment is escaped before printing, so
    square brackets in it are not parsed as Rich markup. This covers the
    platform string, driver version, GPU names, install locations, and error
    messages.
    """
    from rich.console import Console
    from rich.markup import escape
    from rich.table import Table

    console = Console()
    info = get_system_info_dict()
    accelerator = info["accelerator"]
    driver_version = info["driver_version"]

    # System info table (with GPU details integrated)
    table = Table(title="System Information", show_header=False)
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="white")

    table.add_row("Python", info["python_version"])
    table.add_row("Platform", escape(info["platform"]))
    table.add_row("PyTorch", info["pytorch_version"])

    if accelerator == "cuda":
        table.add_row("Accelerator", "CUDA")
        table.add_row("CUDA version", info["cuda_version"] or "N/A")
        table.add_row("cuDNN version", info["cudnn_version"] or "N/A")
        table.add_row("Driver version", escape(driver_version or "N/A"))
        table.add_row("GPU count", str(info["gpu_count"]))
        for gpu in info["gpus"]:
            gpu_str = (
                f"{escape(gpu['name'])} ({gpu['memory_gb']} GB, "
                f"compute {gpu['compute_capability']})"
            )
            table.add_row(f"GPU {gpu['id']}", gpu_str)
    elif accelerator == "mps":
        table.add_row("Accelerator", "MPS (Apple Silicon)")
    else:
        table.add_row("Accelerator", "CPU only")
        # A driver without CUDA availability helps diagnose install issues.
        if driver_version:
            table.add_row("Driver version", escape(driver_version))

    console.print(table)

    # Package versions table
    console.print()
    pkg_table = Table(title="Package Versions")
    pkg_table.add_column("Package", style="cyan")
    pkg_table.add_column("Version", style="white")
    pkg_table.add_column("Source", style="yellow")
    pkg_table.add_column("Location", style="dim")

    for pkg, pkg_info in info["packages"].items():
        if pkg_info["version"] == "not installed":
            version_display = f"[dim]{pkg_info['version']}[/dim]"
            pkg_table.add_row(pkg, version_display, "", "")
        else:
            # Escape after shortening so no escape sequence gets cut in half.
            location_display = escape(_shorten_path(pkg_info["location"]))
            pkg_table.add_row(
                pkg, pkg_info["version"], pkg_info["source"], location_display
            )

    console.print(pkg_table)

    # Actionable diagnostics
    console.print()

    # Driver compatibility check (CUDA only)
    if accelerator == "cuda" and driver_version:
        driver_display = escape(driver_version)
        if info["driver_min_required"]:
            if info["driver_compatible"]:
                console.print(
                    f"[green]OK[/green] Driver is compatible: "
                    f"{driver_display} >= {info['driver_min_required']} "
                    f"(required for CUDA {info['cuda_version']})"
                )
            else:
                console.print(
                    f"[red]FAIL[/red] Driver is too old: "
                    f"{driver_display} < {info['driver_min_required']} "
                    f"(required for CUDA {info['cuda_version']})"
                )
                console.print(
                    "  [yellow]Update your driver: https://www.nvidia.com/drivers[/yellow]"
                )
        else:
            # Unknown CUDA version, can't check compatibility
            console.print(
                f"[yellow]![/yellow] Driver version: {driver_display} "
                f"(CUDA {info['cuda_version']} compatibility unknown)"
            )
    elif accelerator == "cpu" and driver_version:
        # Has driver but no CUDA - might be a problem
        console.print(
            f"[yellow]![/yellow] NVIDIA driver found ({escape(driver_version)}) "
            "but CUDA is not available"
        )
        console.print(
            "  [dim]This may indicate a driver/PyTorch version mismatch[/dim]"
        )

    # GPU connection test (CUDA and MPS report identically)
    if accelerator in ("cuda", "mps"):
        success, error = test_gpu_operations()
        if success:
            console.print("[green]OK[/green] PyTorch can use GPU")
        else:
            console.print(
                f"[red]FAIL[/red] PyTorch cannot use GPU: {escape(str(error))}"
            )
    else:
        console.print("[dim]--[/dim] No GPU available (using CPU)")
419
+
420
+
421
def get_startup_info_string() -> str:
    """Build a one-line, ``" | "``-separated summary for startup logs.

    The fields are, in order: the sleap-nn version, the Python version, and
    the PyTorch version. These are followed by the compute device:

    - CUDA version and GPU count, when CUDA is available.
    - ``"MPS"``, when Apple MPS is available.
    - ``"CPU only"``, otherwise.
    """
    import torch

    from sleap_nn import __version__

    python_version = sys.version.split()[0]
    header = [
        f"sleap-nn {__version__}",
        f"Python {python_version}",
        f"PyTorch {torch.__version__}",
    ]

    if torch.cuda.is_available():
        device_fields = [
            f"CUDA {torch.version.cuda}",
            f"{torch.cuda.device_count()} GPU(s)",
        ]
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_fields = ["MPS"]
    else:
        device_fields = ["CPU only"]

    return " | ".join(header + device_fields)
@@ -898,7 +898,14 @@ def run_tracker(
898
898
  tracking_target_instance_count is None
899
899
  or tracking_target_instance_count == 0
900
900
  ):
901
- message = "tracking_target_instance_count is None or 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
901
+ if max_tracks is not None:
902
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
903
+ else:
904
+ suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
905
+ message = (
906
+ f"--post_connect_single_breaks requires --tracking_target_instance_count to be set. "
907
+ f"{suggestion}"
908
+ )
902
909
  logger.error(message)
903
910
  raise ValueError(message)
904
911
  start_final_pass_time = time()