sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,443 @@
1
+ """System diagnostics and compatibility checking for sleap-nn."""
2
+
3
+ import importlib.metadata
4
+ import json
5
+ import platform
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ from typing import Optional
10
+
11
# Distributions whose versions are reported in the diagnostics output.
PACKAGES = [
    "sleap-nn",
    "sleap-io",
    "torch",
    "pytorch-lightning",
    "kornia",
    "wandb",
    "numpy",
    "h5py",
]

# Maps CUDA "major.minor" -> (minimum Linux driver, minimum Windows driver).
# Values from https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/
CUDA_DRIVER_REQUIREMENTS = {
    "12.6": ("560.28.03", "560.76"),
    "12.8": ("570.26", "570.65"),
    "13.0": ("580.65.06", "580.00"),
}
30
+
31
+
32
def parse_driver_version(version: str) -> tuple[int, ...]:
    """Convert a dotted driver version string into a tuple of ints.

    Non-numeric input yields the sentinel ``(0,)`` so callers can still
    compare it against any real version tuple.
    """
    pieces = version.split(".")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError:
        return (0,)
38
+
39
+
40
def get_min_driver_for_cuda(cuda_version: str) -> Optional[tuple[str, str]]:
    """Look up the minimum driver versions for a CUDA release.

    Args:
        cuda_version: CUDA version string (e.g., "12.6" or "12.6.1")

    Returns:
        Tuple of (min_linux, min_windows) or None if unknown version.
    """
    if not cuda_version:
        return None
    parts = cuda_version.split(".")
    if len(parts) < 2:
        # Need at least major and minor components to do the lookup.
        return None
    # Only the "major.minor" prefix matters (e.g., "12.6" from "12.6.1").
    return CUDA_DRIVER_REQUIREMENTS.get(".".join(parts[:2]))
57
+
58
+
59
def check_driver_compatibility(
    driver_version: str, cuda_version: str
) -> tuple[bool, Optional[str]]:
    """Check whether an installed driver satisfies a CUDA release.

    Args:
        driver_version: Installed driver version string
        cuda_version: CUDA version from PyTorch

    Returns:
        Tuple of (is_compatible, min_required_version).
        If CUDA version is unknown, returns (True, None).
    """
    requirements = get_min_driver_for_cuda(cuda_version)
    if not requirements:
        # Unknown CUDA release: nothing to compare against, assume OK.
        return True, None

    # Requirements differ per OS: index 0 is Linux, index 1 is Windows.
    min_version = requirements[1] if sys.platform == "win32" else requirements[0]

    have = parse_driver_version(driver_version)
    need = parse_driver_version(min_version)

    # Right-pad the shorter tuple with zeros so the lexicographic
    # comparison treats "560.28" like "560.28.0".
    width = max(len(have), len(need))
    have = have + (0,) * (width - len(have))
    need = need + (0,) * (width - len(need))

    return have >= need, min_version
90
+
91
+
92
def get_nvidia_driver_version() -> Optional[str]:
    """Return the NVIDIA driver version reported by nvidia-smi, if any."""
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        proc = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except Exception:
        # Best-effort probe: a missing/hung/odd nvidia-smi just means
        # we report no driver version.
        return None
    if proc.returncode != 0:
        return None
    # nvidia-smi prints one line per GPU; all GPUs share a driver, so
    # the first line suffices.
    return proc.stdout.strip().split("\n")[0]
108
+
109
+
110
+ def _get_package_location(name: str, dist) -> str:
111
+ """Get the actual installed location of a package.
112
+
113
+ For editable installs, returns the source directory.
114
+ For regular installs, returns the site-packages path.
115
+ """
116
+ # Try to import the package and get its __file__
117
+ try:
118
+ # Convert package name to module name (e.g., "sleap-nn" -> "sleap_nn")
119
+ module_name = name.replace("-", "_")
120
+ module = __import__(module_name)
121
+ if hasattr(module, "__file__") and module.__file__:
122
+ # Return parent of the package's __init__.py
123
+ from pathlib import Path
124
+
125
+ return str(Path(module.__file__).parent.parent)
126
+ except (ImportError, AttributeError):
127
+ pass
128
+
129
+ # Fallback to dist._path
130
+ if dist._path:
131
+ path = dist._path.parent
132
+ # Make absolute if relative
133
+ if not path.is_absolute():
134
+ from pathlib import Path
135
+
136
+ path = Path.cwd() / path
137
+ return str(path)
138
+
139
+ return ""
140
+
141
+
142
def get_package_info(name: str) -> dict:
    """Get package version, location, and install source.

    Args:
        name: Package name (e.g., "sleap-nn")

    Returns:
        Dict with version, location, source, and editable fields.
    """
    # Keep the try minimal: only distribution() raises PackageNotFoundError.
    try:
        dist = importlib.metadata.distribution(name)
    except importlib.metadata.PackageNotFoundError:
        return {
            "version": "not installed",
            "location": "",
            "source": "",
            "editable": False,
        }

    version = dist.version

    # Determine editability and install source from PEP 610 direct_url.json.
    is_editable = False
    source = "pip"  # Default assumption
    try:
        direct_url_text = dist.read_text("direct_url.json")
        if direct_url_text:
            direct_url = json.loads(direct_url_text)
            is_editable = direct_url.get("dir_info", {}).get("editable", False)
            if is_editable:
                source = "editable"
            elif "vcs_info" in direct_url:
                source = "git"
            elif direct_url.get("url", "").startswith("file://"):
                source = "local"
    except FileNotFoundError:
        pass

    # Fallback: old-style editable installs leave a .egg-info directory in
    # the source tree rather than in site-packages. `_path` is private to
    # PathDistribution, so use getattr in case another subclass is passed.
    dist_path = getattr(dist, "_path", None)
    if not is_editable and dist_path:
        path_str = str(dist_path)
        if ".egg-info" in path_str and "site-packages" not in path_str:
            is_editable = True
            source = "editable"

    # Conda writes "conda" into the INSTALLER metadata file; only check when
    # a more specific source hasn't already been identified.
    if source == "pip":
        try:
            installer = dist.read_text("INSTALLER")
            if installer and installer.strip() == "conda":
                source = "conda"
        except FileNotFoundError:
            pass

    # Resolve location last, after the editable determination above.
    location = _get_package_location(name, dist)

    return {
        "version": version,
        "location": location,
        "source": source,
        "editable": is_editable,
    }
205
+
206
+
207
def get_system_info_dict() -> dict:
    """Get system information as a dictionary.

    Returns:
        Dictionary with system info including Python version, platform,
        PyTorch version, CUDA availability, GPU details, and package versions.
    """
    import torch

    info = {
        "python_version": sys.version.split()[0],
        "platform": platform.platform(),
        "pytorch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_version": None,
        "cudnn_version": None,
        "driver_version": None,
        "driver_compatible": None,
        "driver_min_required": None,
        "gpu_count": 0,
        "gpus": [],
        "mps_available": False,
        "accelerator": "cpu",  # cpu, cuda, or mps
        "packages": {},
    }

    # Query the driver even when CUDA is unavailable: an outdated driver is
    # itself a common cause of CUDA being unavailable.
    driver = get_nvidia_driver_version()
    if driver:
        info["driver_version"] = driver

    if torch.cuda.is_available():
        # CUDA accelerator details.
        info["cuda_version"] = torch.version.cuda
        info["cudnn_version"] = str(torch.backends.cudnn.version())
        info["gpu_count"] = torch.cuda.device_count()
        info["accelerator"] = "cuda"

        # Compare the driver against the minimum required for this CUDA release.
        if driver and info["cuda_version"]:
            compatible, minimum = check_driver_compatibility(
                driver, info["cuda_version"]
            )
            info["driver_compatible"] = compatible
            info["driver_min_required"] = minimum

        # Per-device details.
        for device_id in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(device_id)
            info["gpus"].append(
                {
                    "id": device_id,
                    "name": props.name,
                    "compute_capability": f"{props.major}.{props.minor}",
                    "memory_gb": round(props.total_memory / (1024**3), 1),
                }
            )

    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # Apple Silicon (Metal Performance Shaders).
        info["mps_available"] = True
        info["accelerator"] = "mps"
        info["gpu_count"] = 1

    # Versions and install details for the key packages.
    for pkg in PACKAGES:
        info["packages"][pkg] = get_package_info(pkg)

    return info
276
+
277
+
278
def test_gpu_operations() -> tuple[bool, Optional[str]]:
    """Test that GPU tensor operations work.

    Returns:
        Tuple of (success, error_message).
    """
    import torch

    # Pick whichever accelerator backend is present, preferring CUDA.
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        return False, "No GPU available"

    try:
        # A small matmul exercises both allocation and kernel launch
        # on the selected device.
        mat = torch.randn(100, 100, device=device)
        _ = torch.mm(mat, mat)
        return True, None
    except Exception as exc:
        return False, str(exc)
301
+
302
+
303
+ def _shorten_path(path: str, max_len: int = 40) -> str:
304
+ """Shorten a path for display, keeping the end."""
305
+ if len(path) <= max_len:
306
+ return path
307
+ return "..." + path[-(max_len - 3) :]
308
+
309
+
310
def print_system_info() -> None:
    """Print comprehensive system diagnostics to console."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    info = get_system_info_dict()
    accelerator = info["accelerator"]

    # --- System table (accelerator/GPU rows folded in) ---
    sys_table = Table(title="System Information", show_header=False)
    sys_table.add_column("Property", style="cyan")
    sys_table.add_column("Value", style="white")

    sys_table.add_row("Python", info["python_version"])
    sys_table.add_row("Platform", info["platform"])
    sys_table.add_row("PyTorch", info["pytorch_version"])

    if accelerator == "cuda":
        sys_table.add_row("Accelerator", "CUDA")
        sys_table.add_row("CUDA version", info["cuda_version"] or "N/A")
        sys_table.add_row("cuDNN version", info["cudnn_version"] or "N/A")
        sys_table.add_row("Driver version", info["driver_version"] or "N/A")
        sys_table.add_row("GPU count", str(info["gpu_count"]))
        # One row per device.
        for gpu in info["gpus"]:
            sys_table.add_row(
                f"GPU {gpu['id']}",
                f"{gpu['name']} ({gpu['memory_gb']} GB, compute {gpu['compute_capability']})",
            )
    elif accelerator == "mps":
        sys_table.add_row("Accelerator", "MPS (Apple Silicon)")
    else:
        sys_table.add_row("Accelerator", "CPU only")
        # A visible driver with no CUDA helps diagnose setup issues.
        if info["driver_version"]:
            sys_table.add_row("Driver version", info["driver_version"])

    console.print(sys_table)

    # --- Package versions table ---
    console.print()
    pkg_table = Table(title="Package Versions")
    pkg_table.add_column("Package", style="cyan")
    pkg_table.add_column("Version", style="white")
    pkg_table.add_column("Source", style="yellow")
    pkg_table.add_column("Location", style="dim")

    for pkg, meta in info["packages"].items():
        if meta["version"] == "not installed":
            pkg_table.add_row(pkg, f"[dim]{meta['version']}[/dim]", "", "")
        else:
            pkg_table.add_row(
                pkg, meta["version"], meta["source"], _shorten_path(meta["location"])
            )

    console.print(pkg_table)

    # --- Actionable diagnostics ---
    console.print()

    # Driver compatibility check (CUDA only).
    if accelerator == "cuda" and info["driver_version"]:
        if not info["driver_min_required"]:
            # CUDA release not in our table; compatibility can't be judged.
            console.print(
                f"[yellow]![/yellow] Driver version: {info['driver_version']} "
                f"(CUDA {info['cuda_version']} compatibility unknown)"
            )
        elif info["driver_compatible"]:
            console.print(
                f"[green]OK[/green] Driver is compatible: "
                f"{info['driver_version']} >= {info['driver_min_required']} "
                f"(required for CUDA {info['cuda_version']})"
            )
        else:
            console.print(
                f"[red]FAIL[/red] Driver is too old: "
                f"{info['driver_version']} < {info['driver_min_required']} "
                f"(required for CUDA {info['cuda_version']})"
            )
            console.print(
                " [yellow]Update your driver: https://www.nvidia.com/drivers[/yellow]"
            )
    elif accelerator == "cpu" and info["driver_version"]:
        # Driver present but no CUDA: likely a driver/PyTorch mismatch.
        console.print(
            f"[yellow]![/yellow] NVIDIA driver found ({info['driver_version']}) "
            "but CUDA is not available"
        )
        console.print(
            " [dim]This may indicate a driver/PyTorch version mismatch[/dim]"
        )

    # --- GPU smoke test ---
    success, error = test_gpu_operations()
    if accelerator in ("cuda", "mps"):
        # Same message for either accelerator; branches merged.
        if success:
            console.print("[green]OK[/green] PyTorch can use GPU")
        else:
            console.print(f"[red]FAIL[/red] PyTorch cannot use GPU: {error}")
    else:
        console.print("[dim]--[/dim] No GPU available (using CPU)")
419
+
420
+
421
def get_startup_info_string() -> str:
    """Get a concise system info string for startup logging.

    Returns:
        Single-line string with key system info.
    """
    import torch

    from sleap_nn import __version__

    segments = [
        f"sleap-nn {__version__}",
        f"Python {sys.version.split()[0]}",
        f"PyTorch {torch.__version__}",
    ]

    # Append whichever accelerator summary applies.
    if torch.cuda.is_available():
        segments.append(f"CUDA {torch.version.cuda}")
        segments.append(f"{torch.cuda.device_count()} GPU(s)")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        segments.append("MPS")
    else:
        segments.append("CPU only")

    return " | ".join(segments)
@@ -898,7 +898,14 @@ def run_tracker(
898
898
  tracking_target_instance_count is None
899
899
  or tracking_target_instance_count == 0
900
900
  ):
901
- message = "tracking_target_instance_count is None or 0. To connect single breaks, tracking_target_instance_count should be set to an integer."
901
+ if max_tracks is not None:
902
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
903
+ else:
904
+ suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
905
+ message = (
906
+ f"--post_connect_single_breaks requires --tracking_target_instance_count to be set. "
907
+ f"{suggestion}"
908
+ )
902
909
  logger.error(message)
903
910
  raise ValueError(message)
904
911
  start_final_pass_time = time()
sleap_nn/train.py CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
6
6
  from time import time
7
7
  from omegaconf import DictConfig, OmegaConf
8
8
  from typing import Any, Dict, Optional, List, Tuple, Union
9
+ import sleap_io as sio
9
10
  from sleap_nn.config.training_job_config import TrainingJobConfig
10
11
  from sleap_nn.training.model_trainer import ModelTrainer
11
12
  from sleap_nn.predict import run_inference as predict
@@ -15,15 +16,31 @@ from sleap_nn.config.get_config import (
15
16
  get_model_config,
16
17
  get_data_config,
17
18
  )
19
+ from sleap_nn.system_info import get_startup_info_string
18
20
 
19
21
 
20
- def run_training(config: DictConfig):
21
- """Create ModelTrainer instance and start training."""
22
+ def run_training(
23
+ config: DictConfig,
24
+ train_labels: Optional[List[sio.Labels]] = None,
25
+ val_labels: Optional[List[sio.Labels]] = None,
26
+ ):
27
+ """Create ModelTrainer instance and start training.
28
+
29
+ Args:
30
+ config: Training configuration as a DictConfig.
31
+ train_labels: List of Labels objects for training.
32
+ val_labels: List of Labels objects for validation.
33
+ If not provided, the labels will be loaded from paths in the config.
34
+ """
22
35
  start_train_time = time()
23
36
  start_timestamp = str(datetime.now())
24
37
  logger.info(f"Started training at: {start_timestamp}")
38
+ logger.info(get_startup_info_string())
25
39
 
26
- trainer = ModelTrainer.get_model_trainer_from_config(config)
40
+ # provide the labels as the train labels, val labels will be split from the train labels
41
+ trainer = ModelTrainer.get_model_trainer_from_config(
42
+ config, train_labels=train_labels, val_labels=val_labels
43
+ )
27
44
  trainer.train()
28
45
 
29
46
  finish_timestamp = str(datetime.now())
@@ -39,48 +56,44 @@ def run_training(config: DictConfig):
39
56
  # run inference on val dataset
40
57
  if trainer.config.trainer_config.save_ckpt:
41
58
  data_paths = {}
42
- for index, path in enumerate(trainer.config.data_config.train_labels_path):
43
- logger.info(
44
- f"Training labels path for index {index}: {(Path(trainer.config.trainer_config.ckpt_dir) / trainer.config.trainer_config.run_name).as_posix()}"
45
- )
46
- data_paths[f"train_{index}"] = (
47
- Path(trainer.config.trainer_config.ckpt_dir)
48
- / trainer.config.trainer_config.run_name
49
- / f"labels_train_gt_{index}.slp"
59
+ run_path = (
60
+ Path(trainer.config.trainer_config.ckpt_dir)
61
+ / trainer.config.trainer_config.run_name
62
+ )
63
+ for index, _ in enumerate(trainer.train_labels):
64
+ logger.info(f"Run path for index {index}: {run_path.as_posix()}")
65
+ data_paths[f"train.{index}"] = (
66
+ run_path / f"labels_gt.train.{index}.slp"
50
67
  ).as_posix()
51
- data_paths[f"val_{index}"] = (
52
- Path(trainer.config.trainer_config.ckpt_dir)
53
- / trainer.config.trainer_config.run_name
54
- / f"labels_val_gt_{index}.slp"
68
+ data_paths[f"val.{index}"] = (
69
+ run_path / f"labels_gt.val.{index}.slp"
55
70
  ).as_posix()
56
71
 
57
- if (
58
- OmegaConf.select(config, "data_config.test_file_path", default=None)
59
- is not None
60
- ):
61
- data_paths["test"] = config.data_config.test_file_path
72
+ # Handle test_file_path as either a string or list of strings
73
+ test_file_path = OmegaConf.select(
74
+ config, "data_config.test_file_path", default=None
75
+ )
76
+ if test_file_path is not None:
77
+ # Normalize to list of strings
78
+ if isinstance(test_file_path, str):
79
+ test_paths = [test_file_path]
80
+ else:
81
+ test_paths = list(test_file_path)
82
+ # Add each test path to data_paths (always use index for consistency)
83
+ for idx, test_path in enumerate(test_paths):
84
+ data_paths[f"test.{idx}"] = test_path
62
85
 
63
86
  for d_name, path in data_paths.items():
64
- pred_path = (
65
- Path(trainer.config.trainer_config.ckpt_dir)
66
- / trainer.config.trainer_config.run_name
67
- / f"pred_{d_name}.slp"
68
- )
69
- metrics_path = (
70
- Path(trainer.config.trainer_config.ckpt_dir)
71
- / trainer.config.trainer_config.run_name
72
- / f"{d_name}_pred_metrics.npz"
73
- )
87
+ # d_name is now in format: "train.0", "val.0", "test.0", etc.
88
+ pred_path = run_path / f"labels_pr.{d_name}.slp"
89
+ metrics_path = run_path / f"metrics.{d_name}.npz"
74
90
 
75
91
  pred_labels = predict(
76
92
  data_path=path,
77
- model_paths=[
78
- Path(trainer.config.trainer_config.ckpt_dir)
79
- / trainer.config.trainer_config.run_name
80
- ],
93
+ model_paths=[run_path],
81
94
  peak_threshold=0.2,
82
95
  make_labels=True,
83
- device=trainer.trainer.strategy.root_device,
96
+ device=str(trainer.trainer.strategy.root_device),
84
97
  output_path=pred_path,
85
98
  ensure_rgb=config.data_config.preprocessing.ensure_rgb,
86
99
  ensure_grayscale=config.data_config.preprocessing.ensure_grayscale,
@@ -110,7 +123,8 @@ def train(
110
123
  train_labels_path: Optional[List[str]] = None,
111
124
  val_labels_path: Optional[List[str]] = None,
112
125
  validation_fraction: float = 0.1,
113
- test_file_path: Optional[str] = None,
126
+ use_same_data_for_val: bool = False,
127
+ test_file_path: Optional[Union[str, List[str]]] = None,
114
128
  provider: str = "LabelsReader",
115
129
  user_instances_only: bool = True,
116
130
  data_pipeline_fw: str = "torch_dataset",
@@ -124,6 +138,7 @@ def train(
124
138
  max_width: Optional[int] = None,
125
139
  crop_size: Optional[int] = None,
126
140
  min_crop_size: Optional[int] = 100,
141
+ crop_padding: Optional[int] = None,
127
142
  use_augmentations_train: bool = False,
128
143
  intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
129
144
  geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
@@ -188,7 +203,11 @@ def train(
188
203
  training set to sample for generating the validation set. The remaining
189
204
  labeled frames will be left in the training set. If the `validation_labels`
190
205
  are already specified, this has no effect. Default: 0.1.
191
- test_file_path: Path to test dataset (`.slp` file or `.mp4` file).
206
+ use_same_data_for_val: If `True`, use the same data for both training and
207
+ validation (train = val). Useful for intentional overfitting on small
208
+ datasets. When enabled, `val_labels_path` and `validation_fraction` are
209
+ ignored. Default: False.
210
+ test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
192
211
  Note: This is used to get evaluation on test set after training is completed.
193
212
  provider: Provider class to read the input sleap files. Only "LabelsReader"
194
213
  supported for the training pipeline. Default: "LabelsReader".
@@ -210,14 +229,17 @@ def train(
210
229
  is set to True, then we convert the image to grayscale (single-channel)
211
230
  image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
212
231
  scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
213
- max_height: Maximum height the image should be padded to. If not provided, the
232
+ max_height: Maximum height the original image should be resized and padded to. If not provided, the
214
233
  original image size will be retained. Default: None.
215
- max_width: Maximum width the image should be padded to. If not provided, the
234
+ max_width: Maximum width the original image should be resized and padded to. If not provided, the
216
235
  original image size will be retained. Default: None.
217
236
  crop_size: Crop size of each instance for centered-instance model.
218
237
  If `None`, this would be automatically computed based on the largest instance
219
- in the `sio.Labels` file. Default: None.
238
+ in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
220
239
  min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
240
+ crop_padding: Padding in pixels to add around instance bounding box when computing
241
+ crop size. If `None`, padding is auto-computed based on augmentation settings.
242
+ Only used when `crop_size` is `None`. Default: None.
221
243
  use_augmentations_train: True if the data augmentation should be applied to the
222
244
  training data, else False. Default: False.
223
245
  intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
@@ -376,6 +398,7 @@ def train(
376
398
  train_labels_path=train_labels_path,
377
399
  val_labels_path=val_labels_path,
378
400
  validation_fraction=validation_fraction,
401
+ use_same_data_for_val=use_same_data_for_val,
379
402
  test_file_path=test_file_path,
380
403
  provider=provider,
381
404
  user_instances_only=user_instances_only,
@@ -390,6 +413,7 @@ def train(
390
413
  max_width=max_width,
391
414
  crop_size=crop_size,
392
415
  min_crop_size=min_crop_size,
416
+ crop_padding=crop_padding,
393
417
  use_augmentations_train=use_augmentations_train,
394
418
  intensity_aug=intensity_aug,
395
419
  geometry_aug=geometry_aug,