sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/system_info.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""System diagnostics and compatibility checking for sleap-nn."""
|
|
2
|
+
|
|
3
|
+
import importlib.metadata
|
|
4
|
+
import json
|
|
5
|
+
import platform
|
|
6
|
+
import shutil
|
|
7
|
+
import subprocess
|
|
8
|
+
import sys
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
# Key packages to check versions for
# (queried via importlib.metadata by get_package_info / get_system_info_dict)
PACKAGES = [
    "sleap-nn",
    "sleap-io",
    "torch",
    "pytorch-lightning",
    "kornia",
    "wandb",
    "numpy",
    "h5py",
]

# CUDA version -> (min_driver_linux, min_driver_windows)
# Source: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/
# Keys are "major.minor" CUDA versions; values are the minimum NVIDIA driver
# versions required on (Linux, Windows) respectively for that CUDA release.
CUDA_DRIVER_REQUIREMENTS = {
    "12.6": ("560.28.03", "560.76"),
    "12.8": ("570.26", "570.65"),
    "13.0": ("580.65.06", "580.00"),
}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def parse_driver_version(version: str) -> tuple[int, ...]:
    """Turn a dotted driver version string into a tuple of ints.

    Any non-numeric component (including an empty string) yields the
    sentinel ``(0,)``, which compares as older than every real version.
    """
    try:
        pieces = [int(piece) for piece in version.split(".")]
    except ValueError:
        return (0,)
    return tuple(pieces)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_min_driver_for_cuda(cuda_version: str) -> Optional[tuple[str, str]]:
    """Look up the minimum driver versions for a CUDA release.

    Args:
        cuda_version: CUDA version string (e.g., "12.6" or "12.6.1").

    Returns:
        Tuple of (min_linux, min_windows), or None when the version is
        empty, lacks a minor component, or is not in the lookup table.
    """
    if not cuda_version:
        return None
    # Reduce e.g. "12.6.1" to its "major.minor" key ("12.6").
    pieces = cuda_version.split(".")
    if len(pieces) < 2:
        return None
    key = ".".join(pieces[:2])
    return CUDA_DRIVER_REQUIREMENTS.get(key)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def check_driver_compatibility(
    driver_version: str, cuda_version: str
) -> tuple[bool, Optional[str]]:
    """Determine whether the installed NVIDIA driver meets the CUDA minimum.

    Args:
        driver_version: Installed driver version string.
        cuda_version: CUDA version reported by PyTorch.

    Returns:
        Tuple of (is_compatible, min_required_version). An unknown CUDA
        version yields (True, None) — nothing to compare against.
    """
    requirements = get_min_driver_for_cuda(cuda_version)
    if not requirements:
        return True, None  # Unknown CUDA version, skip check

    # Requirements tuple is (Linux minimum, Windows minimum).
    min_version = requirements[1] if sys.platform == "win32" else requirements[0]

    have = parse_driver_version(driver_version)
    need = parse_driver_version(min_version)

    # Zero-pad the shorter tuple so lexicographic comparison is fair.
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))

    return have >= need, min_version
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_nvidia_driver_version() -> Optional[str]:
    """Query the NVIDIA driver version via ``nvidia-smi``.

    Returns:
        Driver version string of the first GPU, or None when nvidia-smi
        is absent, fails, or times out.
    """
    if shutil.which("nvidia-smi") is None:
        return None
    cmd = ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
    except Exception:
        # Best effort: timeouts, OSErrors, etc. all mean "unknown".
        return None
    if proc.returncode != 0:
        return None
    return proc.stdout.strip().split("\n")[0]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _get_package_location(name: str, dist) -> str:
    """Get the actual installed location of a package.

    For editable installs, returns the source directory.
    For regular installs, returns the site-packages path.

    Args:
        name: Distribution name (e.g., "sleap-nn").
        dist: ``importlib.metadata.Distribution`` for the package; used as
            a fallback when the module itself cannot be imported.

    Returns:
        Path string (absolute where determinable), or "" when the
        location cannot be resolved.
    """
    # Single local import (the original imported Path twice, once per branch).
    from pathlib import Path

    # Preferred: import the package and derive the location from __file__.
    try:
        # Convert package name to module name (e.g., "sleap-nn" -> "sleap_nn")
        module = __import__(name.replace("-", "_"))
        if getattr(module, "__file__", None):
            # Parent of the directory containing the package's __init__.py,
            # i.e. the site-packages (or editable source) root.
            return str(Path(module.__file__).parent.parent)
    except (ImportError, AttributeError):
        pass

    # Fallback: derive from the distribution's metadata directory.
    # NOTE: ``_path`` is a private attribute of PathDistribution; guard with
    # getattr so other Distribution implementations don't raise here.
    dist_path = getattr(dist, "_path", None)
    if dist_path:
        path = dist_path.parent
        # Make absolute if relative
        if not path.is_absolute():
            path = Path.cwd() / path
        return str(path)

    return ""
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_package_info(name: str) -> dict:
    """Collect version, install location, and install source for a package.

    Args:
        name: Package name (e.g., "sleap-nn").

    Returns:
        Dict with "version", "location", "source", and "editable" keys.
        Uninstalled packages report version "not installed" with empty
        location/source and editable False.
    """
    try:
        dist = importlib.metadata.distribution(name)
    except importlib.metadata.PackageNotFoundError:
        return {
            "version": "not installed",
            "location": "",
            "source": "",
            "editable": False,
        }

    is_editable = False
    source = "pip"  # Default assumption

    # PEP 610 direct_url.json records editable / VCS / local-file installs.
    try:
        raw = dist.read_text("direct_url.json")
    except FileNotFoundError:
        raw = None
    if raw:
        direct_url = json.loads(raw)
        is_editable = direct_url.get("dir_info", {}).get("editable", False)
        if is_editable:
            source = "editable"
        elif "vcs_info" in direct_url:
            source = "git"
        elif direct_url.get("url", "").startswith("file://"):
            source = "local"

    # Old-style editable installs: .egg-info living outside site-packages.
    if not is_editable and dist._path:
        dist_path_str = str(dist._path)
        if ".egg-info" in dist_path_str and "site-packages" not in dist_path_str:
            is_editable = True
            source = "editable"

    # The INSTALLER file distinguishes conda-managed packages (only worth
    # checking if no more specific source was already detected).
    if source == "pip":
        try:
            installer = dist.read_text("INSTALLER")
        except FileNotFoundError:
            installer = None
        if installer and installer.strip() == "conda":
            source = "conda"

    # Resolve location last so the editable determination above can apply.
    return {
        "version": dist.version,
        "location": _get_package_location(name, dist),
        "source": source,
        "editable": is_editable,
    }
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def get_system_info_dict() -> dict:
    """Assemble system diagnostics into a dictionary.

    Returns:
        Dictionary with Python version, platform, PyTorch version, CUDA
        availability and driver details, GPU inventory, MPS availability,
        and per-package install info for PACKAGES.
    """
    import torch

    info = {
        "python_version": sys.version.split()[0],
        "platform": platform.platform(),
        "pytorch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_version": None,
        "cudnn_version": None,
        "driver_version": None,
        "driver_compatible": None,
        "driver_min_required": None,
        "gpu_count": 0,
        "gpus": [],
        "mps_available": False,
        "accelerator": "cpu",  # cpu, cuda, or mps
        "packages": {},
    }

    # Probe the driver even when CUDA is unavailable: an outdated driver is
    # a common reason CUDA itself fails to initialize.
    driver = get_nvidia_driver_version()
    if driver:
        info["driver_version"] = driver

    if torch.cuda.is_available():
        info["accelerator"] = "cuda"
        info["cuda_version"] = torch.version.cuda
        info["cudnn_version"] = str(torch.backends.cudnn.version())
        info["gpu_count"] = torch.cuda.device_count()

        # Driver-vs-CUDA compatibility check.
        if driver and info["cuda_version"]:
            compatible, minimum = check_driver_compatibility(
                driver, info["cuda_version"]
            )
            info["driver_compatible"] = compatible
            info["driver_min_required"] = minimum

        # Per-GPU details.
        for idx in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(idx)
            info["gpus"].append(
                {
                    "id": idx,
                    "name": props.name,
                    "compute_capability": f"{props.major}.{props.minor}",
                    "memory_gb": round(props.total_memory / (1024**3), 1),
                }
            )
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Apple Silicon accelerator.
        info["mps_available"] = True
        info["accelerator"] = "mps"
        info["gpu_count"] = 1

    # Package versions for all tracked packages.
    info["packages"] = {pkg: get_package_info(pkg) for pkg in PACKAGES}

    return info
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def test_gpu_operations() -> tuple[bool, Optional[str]]:
    """Run a small matrix multiply on the available accelerator.

    Returns:
        Tuple of (success, error_message). ``error_message`` is None on
        success, and describes the failure (or "No GPU available") otherwise.
    """
    import torch

    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        return False, "No GPU available"

    try:
        mat = torch.randn(100, 100, device=device)
        _ = torch.mm(mat, mat)
    except Exception as exc:
        return False, str(exc)
    return True, None
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _shorten_path(path: str, max_len: int = 40) -> str:
    """Abbreviate *path* to at most ``max_len`` chars, keeping the tail."""
    if len(path) > max_len:
        keep = max_len - 3
        return f"...{path[-keep:]}"
    return path
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def print_system_info() -> None:
    """Print comprehensive system diagnostics to console.

    Renders a system-information table, a per-package versions table, and
    a set of actionable diagnostics (driver compatibility, GPU smoke test)
    using ``rich``.
    """
    from rich.console import Console
    from rich.table import Table

    console = Console()
    info = get_system_info_dict()

    # System info table (with GPU details integrated)
    table = Table(title="System Information", show_header=False)
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="white")

    table.add_row("Python", info["python_version"])
    table.add_row("Platform", info["platform"])
    table.add_row("PyTorch", info["pytorch_version"])

    # GPU/Accelerator info
    if info["accelerator"] == "cuda":
        table.add_row("Accelerator", "CUDA")
        table.add_row("CUDA version", info["cuda_version"] or "N/A")
        table.add_row("cuDNN version", info["cudnn_version"] or "N/A")
        table.add_row("Driver version", info["driver_version"] or "N/A")
        table.add_row("GPU count", str(info["gpu_count"]))
        # GPU details inline
        for gpu in info["gpus"]:
            gpu_str = f"{gpu['name']} ({gpu['memory_gb']} GB, compute {gpu['compute_capability']})"
            table.add_row(f"GPU {gpu['id']}", gpu_str)
    elif info["accelerator"] == "mps":
        table.add_row("Accelerator", "MPS (Apple Silicon)")
    else:
        table.add_row("Accelerator", "CPU only")
        # Show driver if present but CUDA unavailable (helps diagnose issues)
        if info["driver_version"]:
            table.add_row("Driver version", info["driver_version"])

    console.print(table)

    # Package versions table
    console.print()
    pkg_table = Table(title="Package Versions")
    pkg_table.add_column("Package", style="cyan")
    pkg_table.add_column("Version", style="white")
    pkg_table.add_column("Source", style="yellow")
    pkg_table.add_column("Location", style="dim")

    for pkg, pkg_info in info["packages"].items():
        if pkg_info["version"] == "not installed":
            version_display = f"[dim]{pkg_info['version']}[/dim]"
            pkg_table.add_row(pkg, version_display, "", "")
        else:
            location_display = _shorten_path(pkg_info["location"])
            pkg_table.add_row(
                pkg, pkg_info["version"], pkg_info["source"], location_display
            )

    console.print(pkg_table)

    # Actionable diagnostics
    console.print()

    # Driver compatibility check (CUDA only)
    if info["accelerator"] == "cuda" and info["driver_version"]:
        if info["driver_min_required"]:
            if info["driver_compatible"]:
                console.print(
                    f"[green]OK[/green] Driver is compatible: "
                    f"{info['driver_version']} >= {info['driver_min_required']} "
                    f"(required for CUDA {info['cuda_version']})"
                )
            else:
                console.print(
                    f"[red]FAIL[/red] Driver is too old: "
                    f"{info['driver_version']} < {info['driver_min_required']} "
                    f"(required for CUDA {info['cuda_version']})"
                )
                console.print(
                    " [yellow]Update your driver: https://www.nvidia.com/drivers[/yellow]"
                )
        else:
            # Unknown CUDA version, can't check compatibility
            console.print(
                f"[yellow]![/yellow] Driver version: {info['driver_version']} "
                f"(CUDA {info['cuda_version']} compatibility unknown)"
            )
    elif info["accelerator"] == "cpu" and info["driver_version"]:
        # Has driver but no CUDA - might be a problem
        console.print(
            f"[yellow]![/yellow] NVIDIA driver found ({info['driver_version']}) "
            "but CUDA is not available"
        )
        console.print(
            " [dim]This may indicate a driver/PyTorch version mismatch[/dim]"
        )

    # GPU connection test. The cuda and mps reporting was duplicated
    # branch-for-branch in the original; merged into one accelerator check.
    success, error = test_gpu_operations()
    if info["accelerator"] in ("cuda", "mps"):
        if success:
            console.print("[green]OK[/green] PyTorch can use GPU")
        else:
            console.print(f"[red]FAIL[/red] PyTorch cannot use GPU: {error}")
    else:
        console.print("[dim]--[/dim] No GPU available (using CPU)")
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def get_startup_info_string() -> str:
    """Get a concise system info string for startup logging.

    Returns:
        Single-line string with key system info.
    """
    import torch

    from sleap_nn import __version__

    # Accelerator summary: CUDA with GPU count, MPS, or CPU fallback.
    if torch.cuda.is_available():
        accel = [f"CUDA {torch.version.cuda}", f"{torch.cuda.device_count()} GPU(s)"]
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        accel = ["MPS"]
    else:
        accel = ["CPU only"]

    pieces = [
        f"sleap-nn {__version__}",
        f"Python {sys.version.split()[0]}",
        f"PyTorch {torch.__version__}",
        *accel,
    ]
    return " | ".join(pieces)
|
sleap_nn/tracking/tracker.py
CHANGED
|
@@ -898,7 +898,14 @@ def run_tracker(
|
|
|
898
898
|
tracking_target_instance_count is None
|
|
899
899
|
or tracking_target_instance_count == 0
|
|
900
900
|
):
|
|
901
|
-
|
|
901
|
+
if max_tracks is not None:
|
|
902
|
+
suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
|
|
903
|
+
else:
|
|
904
|
+
suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
|
|
905
|
+
message = (
|
|
906
|
+
f"--post_connect_single_breaks requires --tracking_target_instance_count to be set. "
|
|
907
|
+
f"{suggestion}"
|
|
908
|
+
)
|
|
902
909
|
logger.error(message)
|
|
903
910
|
raise ValueError(message)
|
|
904
911
|
start_final_pass_time = time()
|