sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/system_info.py
ADDED
@@ -0,0 +1,443 @@
+"""System diagnostics and compatibility checking for sleap-nn."""
+
+import importlib.metadata
+import json
+import platform
+import shutil
+import subprocess
+import sys
+from typing import Optional
+
+# Key packages to check versions for
+PACKAGES = [
+    "sleap-nn",
+    "sleap-io",
+    "torch",
+    "pytorch-lightning",
+    "kornia",
+    "wandb",
+    "numpy",
+    "h5py",
+]
+
+# CUDA version -> (min_driver_linux, min_driver_windows)
+# Source: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/
+CUDA_DRIVER_REQUIREMENTS = {
+    "12.6": ("560.28.03", "560.76"),
+    "12.8": ("570.26", "570.65"),
+    "13.0": ("580.65.06", "580.00"),
+}
+
+
+def parse_driver_version(version: str) -> tuple[int, ...]:
+    """Parse driver version string into comparable tuple."""
+    try:
+        return tuple(int(x) for x in version.split("."))
+    except ValueError:
+        return (0,)
+
+
+def get_min_driver_for_cuda(cuda_version: str) -> Optional[tuple[str, str]]:
+    """Get minimum driver versions for a CUDA version.
+
+    Args:
+        cuda_version: CUDA version string (e.g., "12.6" or "12.6.1")
+
+    Returns:
+        Tuple of (min_linux, min_windows) or None if unknown version.
+    """
+    if not cuda_version:
+        return None
+    # Match major.minor (e.g., "12.6" from "12.6.1")
+    parts = cuda_version.split(".")
+    if len(parts) >= 2:
+        major_minor = f"{parts[0]}.{parts[1]}"
+        return CUDA_DRIVER_REQUIREMENTS.get(major_minor)
+    return None
+
+
+def check_driver_compatibility(
+    driver_version: str, cuda_version: str
+) -> tuple[bool, Optional[str]]:
+    """Check if driver version is compatible with CUDA version.
+
+    Args:
+        driver_version: Installed driver version string
+        cuda_version: CUDA version from PyTorch
+
+    Returns:
+        Tuple of (is_compatible, min_required_version).
+        If CUDA version is unknown, returns (True, None).
+    """
+    min_versions = get_min_driver_for_cuda(cuda_version)
+    if not min_versions:
+        return True, None  # Unknown CUDA version, skip check
+
+    if sys.platform == "win32":
+        min_version = min_versions[1]
+    else:
+        min_version = min_versions[0]
+
+    current = parse_driver_version(driver_version)
+    required = parse_driver_version(min_version)
+
+    # Pad tuples to same length for comparison
+    max_len = max(len(current), len(required))
+    current = current + (0,) * (max_len - len(current))
+    required = required + (0,) * (max_len - len(required))
+
+    return current >= required, min_version
+
+
+def get_nvidia_driver_version() -> Optional[str]:
+    """Get NVIDIA driver version from nvidia-smi."""
+    if not shutil.which("nvidia-smi"):
+        return None
+    try:
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            timeout=5,
+        )
+        if result.returncode == 0:
+            return result.stdout.strip().split("\n")[0]
+    except Exception:
+        pass
+    return None
+
+
+def _get_package_location(name: str, dist) -> str:
+    """Get the actual installed location of a package.
+
+    For editable installs, returns the source directory.
+    For regular installs, returns the site-packages path.
+    """
+    # Try to import the package and get its __file__
+    try:
+        # Convert package name to module name (e.g., "sleap-nn" -> "sleap_nn")
+        module_name = name.replace("-", "_")
+        module = __import__(module_name)
+        if hasattr(module, "__file__") and module.__file__:
+            # Return parent of the package's __init__.py
+            from pathlib import Path
+
+            return str(Path(module.__file__).parent.parent)
+    except (ImportError, AttributeError):
+        pass
+
+    # Fallback to dist._path
+    if dist._path:
+        path = dist._path.parent
+        # Make absolute if relative
+        if not path.is_absolute():
+            from pathlib import Path
+
+            path = Path.cwd() / path
+        return str(path)
+
+    return ""
+
+
+def get_package_info(name: str) -> dict:
+    """Get package version, location, and install source.
+
+    Args:
+        name: Package name (e.g., "sleap-nn")
+
+    Returns:
+        Dict with version, location, source, and editable fields.
+    """
+    try:
+        dist = importlib.metadata.distribution(name)
+        version = dist.version
+
+        # Check for editable install and source via direct_url.json
+        is_editable = False
+        source = "pip"  # Default assumption
+        try:
+            direct_url_text = dist.read_text("direct_url.json")
+            if direct_url_text:
+                direct_url = json.loads(direct_url_text)
+                is_editable = direct_url.get("dir_info", {}).get("editable", False)
+                if is_editable:
+                    source = "editable"
+                elif "vcs_info" in direct_url:
+                    source = "git"
+                elif direct_url.get("url", "").startswith("file://"):
+                    source = "local"
+        except FileNotFoundError:
+            pass
+
+        # Fallback: detect old-style editable installs (.egg-info not in site-packages)
+        if not is_editable and dist._path:
+            path_str = str(dist._path)
+            # Old-style editable: .egg-info in source dir, not site-packages
+            if ".egg-info" in path_str and "site-packages" not in path_str:
+                is_editable = True
+                source = "editable"
+
+        # Check for conda install via INSTALLER file (only if not already known)
+        if source == "pip":
+            try:
+                installer = dist.read_text("INSTALLER")
+                if installer and installer.strip() == "conda":
+                    source = "conda"
+            except FileNotFoundError:
+                pass
+
+        # Get location (after determining if editable, so we can use the right method)
+        location = _get_package_location(name, dist)
+
+        return {
+            "version": version,
+            "location": location,
+            "source": source,
+            "editable": is_editable,
+        }
+    except importlib.metadata.PackageNotFoundError:
+        return {
+            "version": "not installed",
+            "location": "",
+            "source": "",
+            "editable": False,
+        }
+
+
+def get_system_info_dict() -> dict:
+    """Get system information as a dictionary.
+
+    Returns:
+        Dictionary with system info including Python version, platform,
+        PyTorch version, CUDA availability, GPU details, and package versions.
+    """
+    import torch
+
+    info = {
+        "python_version": sys.version.split()[0],
+        "platform": platform.platform(),
+        "pytorch_version": torch.__version__,
+        "cuda_available": torch.cuda.is_available(),
+        "cuda_version": None,
+        "cudnn_version": None,
+        "driver_version": None,
+        "driver_compatible": None,
+        "driver_min_required": None,
+        "gpu_count": 0,
+        "gpus": [],
+        "mps_available": False,
+        "accelerator": "cpu",  # cpu, cuda, or mps
+        "packages": {},
+    }
+
+    # Driver version (check even if CUDA unavailable - old driver can cause this)
+    driver = get_nvidia_driver_version()
+    if driver:
+        info["driver_version"] = driver
+
+    # CUDA details
+    if torch.cuda.is_available():
+        info["cuda_version"] = torch.version.cuda
+        info["cudnn_version"] = str(torch.backends.cudnn.version())
+        info["gpu_count"] = torch.cuda.device_count()
+        info["accelerator"] = "cuda"
+
+        # Check driver compatibility
+        if driver and info["cuda_version"]:
+            is_compatible, min_required = check_driver_compatibility(
+                driver, info["cuda_version"]
+            )
+            info["driver_compatible"] = is_compatible
+            info["driver_min_required"] = min_required
+
+        # GPU details
+        for i in range(torch.cuda.device_count()):
+            props = torch.cuda.get_device_properties(i)
+            info["gpus"].append(
+                {
+                    "id": i,
+                    "name": props.name,
+                    "compute_capability": f"{props.major}.{props.minor}",
+                    "memory_gb": round(props.total_memory / (1024**3), 1),
+                }
+            )
+
+    # MPS (Apple Silicon)
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        info["mps_available"] = True
+        info["accelerator"] = "mps"
+        info["gpu_count"] = 1
+
+    # Package versions
+    for pkg in PACKAGES:
+        info["packages"][pkg] = get_package_info(pkg)
+
+    return info
+
+
+def test_gpu_operations() -> tuple[bool, Optional[str]]:
+    """Test that GPU tensor operations work.
+
+    Returns:
+        Tuple of (success, error_message).
+    """
+    import torch
+
+    if torch.cuda.is_available():
+        try:
+            x = torch.randn(100, 100, device="cuda")
+            _ = torch.mm(x, x)
+            return True, None
+        except Exception as e:
+            return False, str(e)
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        try:
+            x = torch.randn(100, 100, device="mps")
+            _ = torch.mm(x, x)
+            return True, None
+        except Exception as e:
+            return False, str(e)
+    return False, "No GPU available"
+
+
+def _shorten_path(path: str, max_len: int = 40) -> str:
+    """Shorten a path for display, keeping the end."""
+    if len(path) <= max_len:
+        return path
+    return "..." + path[-(max_len - 3) :]
+
+
+def print_system_info() -> None:
+    """Print comprehensive system diagnostics to console."""
+    from rich.console import Console
+    from rich.table import Table
+
+    console = Console()
+    info = get_system_info_dict()
+
+    # System info table (with GPU details integrated)
+    table = Table(title="System Information", show_header=False)
+    table.add_column("Property", style="cyan")
+    table.add_column("Value", style="white")
+
+    table.add_row("Python", info["python_version"])
+    table.add_row("Platform", info["platform"])
+    table.add_row("PyTorch", info["pytorch_version"])
+
+    # GPU/Accelerator info
+    if info["accelerator"] == "cuda":
+        table.add_row("Accelerator", "CUDA")
+        table.add_row("CUDA version", info["cuda_version"] or "N/A")
+        table.add_row("cuDNN version", info["cudnn_version"] or "N/A")
+        table.add_row("Driver version", info["driver_version"] or "N/A")
+        table.add_row("GPU count", str(info["gpu_count"]))
+        # GPU details inline
+        for gpu in info["gpus"]:
+            gpu_str = f"{gpu['name']} ({gpu['memory_gb']} GB, compute {gpu['compute_capability']})"
+            table.add_row(f"GPU {gpu['id']}", gpu_str)
+    elif info["accelerator"] == "mps":
+        table.add_row("Accelerator", "MPS (Apple Silicon)")
+    else:
+        table.add_row("Accelerator", "CPU only")
+        # Show driver if present but CUDA unavailable (helps diagnose issues)
+        if info["driver_version"]:
+            table.add_row("Driver version", info["driver_version"])
+
+    console.print(table)
+
+    # Package versions table
+    console.print()
+    pkg_table = Table(title="Package Versions")
+    pkg_table.add_column("Package", style="cyan")
+    pkg_table.add_column("Version", style="white")
+    pkg_table.add_column("Source", style="yellow")
+    pkg_table.add_column("Location", style="dim")
+
+    for pkg, pkg_info in info["packages"].items():
+        if pkg_info["version"] == "not installed":
+            version_display = f"[dim]{pkg_info['version']}[/dim]"
+            pkg_table.add_row(pkg, version_display, "", "")
+        else:
+            location_display = _shorten_path(pkg_info["location"])
+            pkg_table.add_row(
+                pkg, pkg_info["version"], pkg_info["source"], location_display
+            )
+
+    console.print(pkg_table)
+
+    # Actionable diagnostics
+    console.print()
+
+    # Driver compatibility check (CUDA only)
+    if info["accelerator"] == "cuda" and info["driver_version"]:
+        if info["driver_min_required"]:
+            if info["driver_compatible"]:
+                console.print(
+                    f"[green]OK[/green] Driver is compatible: "
+                    f"{info['driver_version']} >= {info['driver_min_required']} "
+                    f"(required for CUDA {info['cuda_version']})"
+                )
+            else:
+                console.print(
+                    f"[red]FAIL[/red] Driver is too old: "
+                    f"{info['driver_version']} < {info['driver_min_required']} "
+                    f"(required for CUDA {info['cuda_version']})"
+                )
+                console.print(
+                    " [yellow]Update your driver: https://www.nvidia.com/drivers[/yellow]"
+                )
+        else:
+            # Unknown CUDA version, can't check compatibility
+            console.print(
+                f"[yellow]![/yellow] Driver version: {info['driver_version']} "
+                f"(CUDA {info['cuda_version']} compatibility unknown)"
+            )
+    elif info["accelerator"] == "cpu" and info["driver_version"]:
+        # Has driver but no CUDA - might be a problem
+        console.print(
+            f"[yellow]![/yellow] NVIDIA driver found ({info['driver_version']}) "
+            "but CUDA is not available"
+        )
+        console.print(
+            " [dim]This may indicate a driver/PyTorch version mismatch[/dim]"
+        )
+
+    # GPU connection test
+    success, error = test_gpu_operations()
+    if info["accelerator"] == "cuda":
+        if success:
+            console.print("[green]OK[/green] PyTorch can use GPU")
+        else:
+            console.print(f"[red]FAIL[/red] PyTorch cannot use GPU: {error}")
+    elif info["accelerator"] == "mps":
+        if success:
+            console.print("[green]OK[/green] PyTorch can use GPU")
+        else:
+            console.print(f"[red]FAIL[/red] PyTorch cannot use GPU: {error}")
+    else:
+        console.print("[dim]--[/dim] No GPU available (using CPU)")
+
+
+def get_startup_info_string() -> str:
+    """Get a concise system info string for startup logging.
+
+    Returns:
+        Single-line string with key system info.
+    """
+    import torch
+
+    from sleap_nn import __version__
+
+    parts = [f"sleap-nn {__version__}"]
+    parts.append(f"Python {sys.version.split()[0]}")
+    parts.append(f"PyTorch {torch.__version__}")
+
+    if torch.cuda.is_available():
+        parts.append(f"CUDA {torch.version.cuda}")
+        parts.append(f"{torch.cuda.device_count()} GPU(s)")
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        parts.append("MPS")
+    else:
+        parts.append("CPU only")
+
+    return " | ".join(parts)
sleap_nn/tracking/tracker.py
CHANGED
@@ -898,7 +898,14 @@ def run_tracker(
         tracking_target_instance_count is None
         or tracking_target_instance_count == 0
     ):
-
+        if max_tracks is not None:
+            suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
+        else:
+            suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
+        message = (
+            f"--post_connect_single_breaks requires --tracking_target_instance_count to be set. "
+            f"{suggestion}"
+        )
         logger.error(message)
         raise ValueError(message)
     start_final_pass_time = time()
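
For reference, a runnable sketch of the message this check now produces (string templates copied from the hunk above; the `max_tracks` value is hypothetical):

```python
# Reconstruct the error text emitted when --post_connect_single_breaks is used
# without --tracking_target_instance_count (templates from the hunk above).
max_tracks = 2  # hypothetical value
if max_tracks is not None:
    suggestion = (
        f"Add --tracking_target_instance_count {max_tracks} to your command "
        "(using your --max_tracks value)."
    )
else:
    suggestion = (
        "Add --tracking_target_instance_count N where N is the expected "
        "number of instances per frame."
    )
print(
    "--post_connect_single_breaks requires --tracking_target_instance_count "
    f"to be set. {suggestion}"
)
```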
sleap_nn/train.py
CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
 from time import time
 from omegaconf import DictConfig, OmegaConf
 from typing import Any, Dict, Optional, List, Tuple, Union
+import sleap_io as sio
 from sleap_nn.config.training_job_config import TrainingJobConfig
 from sleap_nn.training.model_trainer import ModelTrainer
 from sleap_nn.predict import run_inference as predict
@@ -15,15 +16,31 @@ from sleap_nn.config.get_config import (
     get_model_config,
     get_data_config,
 )
+from sleap_nn.system_info import get_startup_info_string
 
 
-def run_training(
-
+def run_training(
+    config: DictConfig,
+    train_labels: Optional[List[sio.Labels]] = None,
+    val_labels: Optional[List[sio.Labels]] = None,
+):
+    """Create ModelTrainer instance and start training.
+
+    Args:
+        config: Training configuration as a DictConfig.
+        train_labels: List of Labels objects for training.
+        val_labels: List of Labels objects for validation.
+            If not provided, the labels will be loaded from paths in the config.
+    """
     start_train_time = time()
     start_timestamp = str(datetime.now())
     logger.info(f"Started training at: {start_timestamp}")
+    logger.info(get_startup_info_string())
 
-
+    # provide the labels as the train labels, val labels will be split from the train labels
+    trainer = ModelTrainer.get_model_trainer_from_config(
+        config, train_labels=train_labels, val_labels=val_labels
+    )
     trainer.train()
 
     finish_timestamp = str(datetime.now())
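
With the new signature, in-memory `sio.Labels` objects can be passed to `run_training` directly instead of relying on paths in the config. A hedged sketch (the config and labels file paths are illustrative):

```python
import sleap_io as sio
from omegaconf import OmegaConf

from sleap_nn.train import run_training

config = OmegaConf.load("training_config.yaml")  # illustrative path
labels = sio.load_slp("labels.v001.slp")         # illustrative path

# Only train_labels is given here; the validation split is derived from it.
run_training(config, train_labels=[labels])
```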
@@ -39,48 +56,44 @@ def run_training(config: DictConfig):
     # run inference on val dataset
     if trainer.config.trainer_config.save_ckpt:
         data_paths = {}
-
-
-
-
-
-
-
-            / f"
+        run_path = (
+            Path(trainer.config.trainer_config.ckpt_dir)
+            / trainer.config.trainer_config.run_name
+        )
+        for index, _ in enumerate(trainer.train_labels):
+            logger.info(f"Run path for index {index}: {run_path.as_posix()}")
+            data_paths[f"train.{index}"] = (
+                run_path / f"labels_gt.train.{index}.slp"
             ).as_posix()
-            data_paths[f"
-
-                / trainer.config.trainer_config.run_name
-                / f"labels_val_gt_{index}.slp"
+            data_paths[f"val.{index}"] = (
+                run_path / f"labels_gt.val.{index}.slp"
             ).as_posix()
 
-
-
-
-        )
-
+        # Handle test_file_path as either a string or list of strings
+        test_file_path = OmegaConf.select(
+            config, "data_config.test_file_path", default=None
+        )
+        if test_file_path is not None:
+            # Normalize to list of strings
+            if isinstance(test_file_path, str):
+                test_paths = [test_file_path]
+            else:
+                test_paths = list(test_file_path)
+            # Add each test path to data_paths (always use index for consistency)
+            for idx, test_path in enumerate(test_paths):
+                data_paths[f"test.{idx}"] = test_path
 
         for d_name, path in data_paths.items():
-
-
-
-                / f"pred_{d_name}.slp"
-            )
-            metrics_path = (
-                Path(trainer.config.trainer_config.ckpt_dir)
-                / trainer.config.trainer_config.run_name
-                / f"{d_name}_pred_metrics.npz"
-            )
+            # d_name is now in format: "train.0", "val.0", "test.0", etc.
+            pred_path = run_path / f"labels_pr.{d_name}.slp"
+            metrics_path = run_path / f"metrics.{d_name}.npz"
 
             pred_labels = predict(
                 data_path=path,
-                model_paths=[
-                    Path(trainer.config.trainer_config.ckpt_dir)
-                    / trainer.config.trainer_config.run_name
-                ],
+                model_paths=[run_path],
                 peak_threshold=0.2,
                 make_labels=True,
-                device=trainer.trainer.strategy.root_device,
+                device=str(trainer.trainer.strategy.root_device),
                 output_path=pred_path,
                 ensure_rgb=config.data_config.preprocessing.ensure_rgb,
                 ensure_grayscale=config.data_config.preprocessing.ensure_grayscale,
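
The restructuring above also renames per-split artifacts to a consistent `<kind>.<split>.<index>` scheme. A small sketch of the resulting paths (name patterns taken from the hunk; `"models/my_run"` stands in for `ckpt_dir / run_name` and is illustrative):

```python
from pathlib import Path

run_path = Path("models") / "my_run"  # ckpt_dir / run_name, values illustrative
for d_name in ("train.0", "val.0", "test.0"):
    print(run_path / f"labels_pr.{d_name}.slp")  # predictions written by predict()
    print(run_path / f"metrics.{d_name}.npz")    # evaluation metrics per split
for d_name in ("train.0", "val.0"):
    print(run_path / f"labels_gt.{d_name}.slp")  # ground-truth splits saved by the trainer
```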
@@ -110,7 +123,8 @@ def train(
     train_labels_path: Optional[List[str]] = None,
     val_labels_path: Optional[List[str]] = None,
     validation_fraction: float = 0.1,
-
+    use_same_data_for_val: bool = False,
+    test_file_path: Optional[Union[str, List[str]]] = None,
     provider: str = "LabelsReader",
     user_instances_only: bool = True,
     data_pipeline_fw: str = "torch_dataset",
@@ -124,6 +138,7 @@ def train(
     max_width: Optional[int] = None,
     crop_size: Optional[int] = None,
     min_crop_size: Optional[int] = 100,
+    crop_padding: Optional[int] = None,
     use_augmentations_train: bool = False,
     intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
     geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
@@ -188,7 +203,11 @@ def train(
             training set to sample for generating the validation set. The remaining
             labeled frames will be left in the training set. If the `validation_labels`
             are already specified, this has no effect. Default: 0.1.
-
+        use_same_data_for_val: If `True`, use the same data for both training and
+            validation (train = val). Useful for intentional overfitting on small
+            datasets. When enabled, `val_labels_path` and `validation_fraction` are
+            ignored. Default: False.
+        test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
             Note: This is used to get evaluation on test set after training is completed.
         provider: Provider class to read the input sleap files. Only "LabelsReader"
             supported for the training pipeline. Default: "LabelsReader".
@@ -210,14 +229,17 @@ def train(
             is set to True, then we convert the image to grayscale (single-channel)
             image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
         scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
-        max_height: Maximum height the image should be padded to. If not provided, the
+        max_height: Maximum height the original image should be resized and padded to. If not provided, the
             original image size will be retained. Default: None.
-        max_width: Maximum width the image should be padded to. If not provided, the
+        max_width: Maximum width the original image should be resized and padded to. If not provided, the
             original image size will be retained. Default: None.
         crop_size: Crop size of each instance for centered-instance model.
             If `None`, this would be automatically computed based on the largest instance
-            in the `sio.Labels` file. Default: None.
+            in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
         min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
+        crop_padding: Padding in pixels to add around instance bounding box when computing
+            crop size. If `None`, padding is auto-computed based on augmentation settings.
+            Only used when `crop_size` is `None`. Default: None.
         use_augmentations_train: True if the data augmentation should be applied to the
             training data, else False. Default: False.
         intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
@@ -376,6 +398,7 @@ def train(
         train_labels_path=train_labels_path,
         val_labels_path=val_labels_path,
         validation_fraction=validation_fraction,
+        use_same_data_for_val=use_same_data_for_val,
         test_file_path=test_file_path,
         provider=provider,
         user_instances_only=user_instances_only,
@@ -390,6 +413,7 @@ def train(
         max_width=max_width,
         crop_size=crop_size,
         min_crop_size=min_crop_size,
+        crop_padding=crop_padding,
         use_augmentations_train=use_augmentations_train,
         intensity_aug=intensity_aug,
         geometry_aug=geometry_aug,
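
Finally, a hedged example of the two new `train()` knobs (parameter names and defaults come from the hunks above; all other parameters keep their defaults here, and the labels path is illustrative):

```python
from sleap_nn.train import train

train(
    train_labels_path=["labels.v001.slp"],  # illustrative path
    use_same_data_for_val=True,  # overfitting check: validation set == training set
    crop_padding=16,             # fixed crop padding instead of augmentation-based auto value
)
```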
|