sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,11 @@
1
1
  """Custom `torch.utils.data.Dataset`s for different model types."""
2
2
 
3
- from kornia.geometry.transform import crop_and_resize
3
+ from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
4
4
 
5
- # from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
6
- # import concurrent.futures
7
- # import os
5
+ import os
6
+ import threading
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from copy import deepcopy
8
9
  from itertools import cycle
9
10
  from pathlib import Path
10
11
  import torch.distributed as dist
@@ -13,6 +14,14 @@ from omegaconf import DictConfig, OmegaConf
13
14
  import numpy as np
14
15
  from PIL import Image
15
16
  from loguru import logger
17
+ from rich.progress import (
18
+ Progress,
19
+ SpinnerColumn,
20
+ TextColumn,
21
+ BarColumn,
22
+ TimeElapsedColumn,
23
+ )
24
+ from rich.console import Console
16
25
  import torch
17
26
  import torchvision.transforms as T
18
27
  from torch.utils.data import Dataset, DataLoader, DistributedSampler
@@ -22,7 +31,6 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
22
31
  from sleap_nn.data.instance_centroids import generate_centroids
23
32
  from sleap_nn.data.instance_cropping import generate_crops
24
33
  from sleap_nn.data.normalization import (
25
- apply_normalization,
26
34
  convert_to_grayscale,
27
35
  convert_to_rgb,
28
36
  )
@@ -38,6 +46,182 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
38
46
  from sleap_nn.training.utils import is_distributed_initialized
39
47
  from sleap_nn.config.get_config import get_aug_config
40
48
 
49
+ # Minimum number of samples to use parallel caching (overhead not worth it for smaller)
50
+ MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
51
+
52
+
53
class ParallelCacheFiller:
    """Fill an image cache from labeled frames using a pool of worker threads.

    Samples are read through per-thread deep copies of the video objects, so no
    two worker threads share a video instance. Images are either written to
    disk as JPEG files or collected into an in-memory dictionary.

    Attributes:
        labels: List of sio.Labels objects containing the data.
        lf_idx_list: List of dictionaries with `labels_idx` and `lf_idx` keys.
        cache_type: Either "disk" or "memory".
        cache_path: Directory in which JPEG files are saved (disk caching).
        num_workers: Number of worker threads.
        cache: Dictionary mapping `(labels_idx, lf_idx)` to the image array
            (only populated for memory caching).
    """

    def __init__(
        self,
        labels: "List[sio.Labels]",
        lf_idx_list: List[Dict],
        cache_type: str,
        cache_path: Optional[Path] = None,
        num_workers: int = 4,
    ):
        """Initialize the parallel cache filler.

        Note that construction has a side effect: every video referenced by
        `labels` is closed and its `open_backend` flag is set to `False` until
        `fill_cache` finishes.

        Args:
            labels: List of sio.Labels objects.
            lf_idx_list: List of sample dictionaries with frame indices.
            cache_type: Either "disk" or "memory".
            cache_path: Path for disk caching.
            num_workers: Number of worker threads.
        """
        self.labels = labels
        self.lf_idx_list = lf_idx_list
        self.cache_type = cache_type
        self.cache_path = cache_path
        self.num_workers = num_workers

        self.cache: Dict = {}
        self._cache_lock = threading.Lock()
        self._local = threading.local()
        self._video_info: Dict = {}

        # Every per-thread copy is registered here so that it can be closed
        # once caching is done (thread-local storage is not reachable from the
        # calling thread).
        self._thread_videos: List = []
        self._thread_videos_lock = threading.Lock()

        # Prepare video copies for thread-local access
        self._prepare_video_copies()

    def _prepare_video_copies(self):
        """Close the original videos and remember their `open_backend` state."""
        for label in self.labels:
            for video in label.videos:
                vid_id = id(video)
                if vid_id in self._video_info:
                    continue

                # Store original state so `_restore_videos` can put it back.
                original_open_backend = video.open_backend

                # Close the backend before the video gets deep-copied by the
                # worker threads.
                video.close()
                video.open_backend = False

                self._video_info[vid_id] = {
                    "video": video,
                    "original_open_backend": original_open_backend,
                }

    def _get_thread_local_video(self, video: "sio.Video") -> "sio.Video":
        """Get or create the calling thread's private copy of a video.

        Args:
            video: The original video object.

        Returns:
            A deep copy of `video` owned by the calling thread, with
            `open_backend` set to `True`.
        """
        vid_id = id(video)

        if not hasattr(self._local, "videos"):
            self._local.videos = {}

        if vid_id not in self._local.videos:
            video_copy = deepcopy(video)
            video_copy.open_backend = True
            self._local.videos[vid_id] = video_copy
            with self._thread_videos_lock:
                self._thread_videos.append(video_copy)

        return self._local.videos[vid_id]

    def _process_sample(
        self, sample: Dict
    ) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
        """Process a single sample (read image, optionally save/cache).

        Args:
            sample: Dictionary with `labels_idx` and `lf_idx` keys.

        Returns:
            Tuple of (labels_idx, lf_idx, image_or_none, error_or_none). The
            image is only returned for memory caching. Failures are reported
            through the error string rather than raised.
        """
        labels_idx = sample["labels_idx"]
        lf_idx = sample["lf_idx"]

        # Always return a 4-tuple: falling through without a return value
        # would break the tuple unpacking in `fill_cache`.
        if self.cache_type not in ("disk", "memory"):
            return (
                labels_idx,
                lf_idx,
                None,
                f"ValueError: unsupported cache_type {self.cache_type!r}",
            )

        try:
            # Get the labeled frame
            lf = self.labels[labels_idx][lf_idx]

            # Read the image through this thread's own video copy
            video = self._get_thread_local_video(lf.video)
            img = video[lf.frame_idx]

            # Drop the trailing singleton dimension
            if img.shape[-1] == 1:
                img = np.squeeze(img)

            if self.cache_type == "disk":
                f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
                Image.fromarray(img).save(str(f_name), format="JPEG")
                return labels_idx, lf_idx, None, None

            return labels_idx, lf_idx, img, None

        except Exception as e:
            return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"

    def fill_cache(
        self, progress_callback=None
    ) -> Tuple[Dict, List[Tuple[int, int, str]]]:
        """Fill the cache in parallel.

        The per-thread video copies are closed and the original videos are
        restored even if an exception (e.g. from `progress_callback`) escapes.

        Args:
            progress_callback: Optional callback(completed_count) for progress
                updates. It is invoked from the calling thread.

        Returns:
            Tuple of (cache_dict, list_of_errors), where each error is a
            (labels_idx, lf_idx, message) tuple.
        """
        errors = []
        completed = 0

        try:
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = {
                    executor.submit(self._process_sample, sample): sample
                    for sample in self.lf_idx_list
                }

                for future in as_completed(futures):
                    labels_idx, lf_idx, img, error = future.result()

                    if error:
                        errors.append((labels_idx, lf_idx, error))
                    elif self.cache_type == "memory" and img is not None:
                        with self._cache_lock:
                            self.cache[(labels_idx, lf_idx)] = img

                    completed += 1
                    if progress_callback:
                        progress_callback(completed)
        finally:
            # Release the worker copies first, then hand the originals back.
            self._close_thread_local_videos()
            self._restore_videos()

        return self.cache, errors

    def _close_thread_local_videos(self):
        """Close every per-thread video copy created while caching."""
        with self._thread_videos_lock:
            copies, self._thread_videos = self._thread_videos, []

        for video_copy in copies:
            try:
                video_copy.close()
            except Exception as e:
                # Best effort: a copy that fails to close must not abort caching.
                logger.warning(
                    f"Could not close thread-local video copy: "
                    f"{type(e).__name__}: {e}"
                )

    def _restore_videos(self):
        """Restore original video states after caching is complete."""
        for vid_info in self._video_info.values():
            video = vid_info["video"]
            video.open_backend = vid_info["original_open_backend"]
            if video.open_backend:
                try:
                    video.open()
                except Exception as e:
                    # Best effort: reopening is not required for the cache to
                    # be valid, but the failure should not go unnoticed.
                    logger.warning(
                        f"Could not reopen video after caching: "
                        f"{type(e).__name__}: {e}"
                    )
224
+
41
225
 
42
226
  class BaseDataset(Dataset):
43
227
  """Base class for custom torch Datasets.
@@ -76,6 +260,8 @@ class BaseDataset(Dataset):
76
260
  use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
77
261
  rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
78
262
  disk occurs only once across all workers.
263
+ parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
264
+ cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
79
265
  labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
80
266
  """
81
267
 
@@ -95,6 +281,8 @@ class BaseDataset(Dataset):
95
281
  cache_img_path: Optional[str] = None,
96
282
  use_existing_imgs: bool = False,
97
283
  rank: Optional[int] = None,
284
+ parallel_caching: bool = True,
285
+ cache_workers: int = 0,
98
286
  ) -> None:
99
287
  """Initialize class attributes."""
100
288
  super().__init__()
@@ -135,6 +323,8 @@ class BaseDataset(Dataset):
135
323
  self.cache_img = cache_img
136
324
  self.cache_img_path = cache_img_path
137
325
  self.use_existing_imgs = use_existing_imgs
326
+ self.parallel_caching = parallel_caching
327
+ self.cache_workers = cache_workers
138
328
  if self.cache_img is not None and "disk" in self.cache_img:
139
329
  if self.cache_img_path is None:
140
330
  self.cache_img_path = "."
@@ -160,10 +350,18 @@ class BaseDataset(Dataset):
160
350
 
161
351
  if self.cache_img is not None:
162
352
  if self.cache_img == "memory":
163
- self._fill_cache(labels)
353
+ self._fill_cache(
354
+ labels,
355
+ parallel=self.parallel_caching,
356
+ num_workers=self.cache_workers,
357
+ )
164
358
  elif self.cache_img == "disk" and not self.use_existing_imgs:
165
359
  if self.rank is None or self.rank == -1 or self.rank == 0:
166
- self._fill_cache(labels)
360
+ self._fill_cache(
361
+ labels,
362
+ parallel=self.parallel_caching,
363
+ num_workers=self.cache_workers,
364
+ )
167
365
  # Synchronize all ranks after cache creation
168
366
  if is_distributed_initialized():
169
367
  dist.barrier()
@@ -177,6 +375,9 @@ class BaseDataset(Dataset):
177
375
  if self.user_instances_only:
178
376
  if lf.user_instances is not None and len(lf.user_instances) > 0:
179
377
  lf.instances = lf.user_instances
378
+ else:
379
+ # Skip frames without user instances
380
+ continue
180
381
  is_empty = True
181
382
  for _, inst in enumerate(lf.instances):
182
383
  if not inst.is_empty: # filter all NaN instances.
@@ -209,20 +410,160 @@ class BaseDataset(Dataset):
209
410
  """Returns an iterator."""
210
411
  return self
211
412
 
212
- def _fill_cache(self, labels: List[sio.Labels]):
213
- """Load all samples to cache."""
214
- # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
215
- for sample in self.lf_idx_list:
216
- labels_idx = sample["labels_idx"]
217
- lf_idx = sample["lf_idx"]
218
- img = labels[labels_idx][lf_idx].image
219
- if img.shape[-1] == 1:
220
- img = np.squeeze(img)
221
- if self.cache_img == "disk":
222
- f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
223
- Image.fromarray(img).save(f_name, format="JPEG")
224
- if self.cache_img == "memory":
225
- self.cache[(labels_idx, lf_idx)] = img
413
+ def _fill_cache(
414
+ self,
415
+ labels: List[sio.Labels],
416
+ parallel: bool = True,
417
+ num_workers: int = 0,
418
+ ):
419
+ """Load all samples to cache.
420
+
421
+ Args:
422
+ labels: List of sio.Labels objects containing the data.
423
+ parallel: If True, use parallel processing for caching (faster for large
424
+ datasets). Default: True.
425
+ num_workers: Number of worker threads for parallel caching. If 0, uses
426
+ min(4, cpu_count). Default: 0.
427
+ """
428
+ total_samples = len(self.lf_idx_list)
429
+ cache_type = "disk" if self.cache_img == "disk" else "memory"
430
+
431
+ # Check for NO_COLOR env var to disable progress bar
432
+ no_color = (
433
+ os.environ.get("NO_COLOR") is not None
434
+ or os.environ.get("FORCE_COLOR") == "0"
435
+ )
436
+ use_progress = not no_color
437
+
438
+ # Use parallel caching for larger datasets
439
+ use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
440
+
441
+ logger.info(f"Caching {total_samples} images to {cache_type}...")
442
+
443
+ if use_parallel:
444
+ self._fill_cache_parallel(
445
+ labels, total_samples, cache_type, use_progress, num_workers
446
+ )
447
+ else:
448
+ self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
449
+
450
+ logger.info(f"Caching complete.")
451
+
452
+ def _fill_cache_sequential(
453
+ self,
454
+ labels: List[sio.Labels],
455
+ total_samples: int,
456
+ cache_type: str,
457
+ use_progress: bool,
458
+ ):
459
+ """Sequential implementation of cache filling.
460
+
461
+ Args:
462
+ labels: List of sio.Labels objects.
463
+ total_samples: Total number of samples to cache.
464
+ cache_type: Either "disk" or "memory".
465
+ use_progress: Whether to show a progress bar.
466
+ """
467
+
468
+ def process_samples(progress=None, task=None):
469
+ for sample in self.lf_idx_list:
470
+ labels_idx = sample["labels_idx"]
471
+ lf_idx = sample["lf_idx"]
472
+ img = labels[labels_idx][lf_idx].image
473
+ if img.shape[-1] == 1:
474
+ img = np.squeeze(img)
475
+ if self.cache_img == "disk":
476
+ f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
477
+ Image.fromarray(img).save(f_name, format="JPEG")
478
+ if self.cache_img == "memory":
479
+ self.cache[(labels_idx, lf_idx)] = img
480
+ if progress is not None:
481
+ progress.update(task, advance=1)
482
+
483
+ if use_progress:
484
+ with Progress(
485
+ SpinnerColumn(),
486
+ TextColumn("[progress.description]{task.description}"),
487
+ BarColumn(),
488
+ TextColumn("{task.completed}/{task.total}"),
489
+ TimeElapsedColumn(),
490
+ console=Console(force_terminal=True),
491
+ transient=True,
492
+ ) as progress:
493
+ task = progress.add_task(
494
+ f"Caching images to {cache_type}", total=total_samples
495
+ )
496
+ process_samples(progress, task)
497
+ else:
498
+ process_samples()
499
+
500
+ def _fill_cache_parallel(
501
+ self,
502
+ labels: List[sio.Labels],
503
+ total_samples: int,
504
+ cache_type: str,
505
+ use_progress: bool,
506
+ num_workers: int = 0,
507
+ ):
508
+ """Parallel implementation of cache filling using thread-local video copies.
509
+
510
+ Args:
511
+ labels: List of sio.Labels objects.
512
+ total_samples: Total number of samples to cache.
513
+ cache_type: Either "disk" or "memory".
514
+ use_progress: Whether to show a progress bar.
515
+ num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
516
+ """
517
+ # Determine number of workers
518
+ if num_workers <= 0:
519
+ num_workers = min(4, os.cpu_count() or 1)
520
+
521
+ cache_path = Path(self.cache_img_path) if self.cache_img_path else None
522
+
523
+ filler = ParallelCacheFiller(
524
+ labels=labels,
525
+ lf_idx_list=self.lf_idx_list,
526
+ cache_type=cache_type,
527
+ cache_path=cache_path,
528
+ num_workers=num_workers,
529
+ )
530
+
531
+ if use_progress:
532
+ with Progress(
533
+ SpinnerColumn(),
534
+ TextColumn("[progress.description]{task.description}"),
535
+ BarColumn(),
536
+ TextColumn("{task.completed}/{task.total}"),
537
+ TimeElapsedColumn(),
538
+ console=Console(force_terminal=True),
539
+ transient=True,
540
+ ) as progress:
541
+ task = progress.add_task(
542
+ f"Caching images to {cache_type} (parallel, {num_workers} workers)",
543
+ total=total_samples,
544
+ )
545
+
546
+ def progress_callback(completed):
547
+ progress.update(task, completed=completed)
548
+
549
+ cache, errors = filler.fill_cache(progress_callback)
550
+ else:
551
+ logger.info(
552
+ f"Caching {total_samples} images to {cache_type} "
553
+ f"(parallel, {num_workers} workers)..."
554
+ )
555
+ cache, errors = filler.fill_cache()
556
+
557
+ # Update instance cache
558
+ if cache_type == "memory":
559
+ self.cache.update(cache)
560
+
561
+ # Log any errors
562
+ if errors:
563
+ logger.warning(
564
+ f"Parallel caching completed with {len(errors)} errors. "
565
+ f"First error: {errors[0]}"
566
+ )
226
567
 
227
568
  def __len__(self) -> int:
228
569
  """Return the number of samples in the dataset."""
@@ -298,6 +639,8 @@ class BottomUpDataset(BaseDataset):
298
639
  cache_img_path: Optional[str] = None,
299
640
  use_existing_imgs: bool = False,
300
641
  rank: Optional[int] = None,
642
+ parallel_caching: bool = True,
643
+ cache_workers: int = 0,
301
644
  ) -> None:
302
645
  """Initialize class attributes."""
303
646
  super().__init__(
@@ -315,6 +658,8 @@ class BottomUpDataset(BaseDataset):
315
658
  cache_img_path=cache_img_path,
316
659
  use_existing_imgs=use_existing_imgs,
317
660
  rank=rank,
661
+ parallel_caching=parallel_caching,
662
+ cache_workers=cache_workers,
318
663
  )
319
664
  self.confmap_head_config = confmap_head_config
320
665
  self.pafs_head_config = pafs_head_config
@@ -357,9 +702,6 @@ class BottomUpDataset(BaseDataset):
357
702
  user_instances_only=self.user_instances_only,
358
703
  )
359
704
 
360
- # apply normalization
361
- sample["image"] = apply_normalization(sample["image"])
362
-
363
705
  if self.ensure_rgb:
364
706
  sample["image"] = convert_to_rgb(sample["image"])
365
707
  elif self.ensure_grayscale:
@@ -497,6 +839,8 @@ class BottomUpMultiClassDataset(BaseDataset):
497
839
  cache_img_path: Optional[str] = None,
498
840
  use_existing_imgs: bool = False,
499
841
  rank: Optional[int] = None,
842
+ parallel_caching: bool = True,
843
+ cache_workers: int = 0,
500
844
  ) -> None:
501
845
  """Initialize class attributes."""
502
846
  super().__init__(
@@ -514,6 +858,8 @@ class BottomUpMultiClassDataset(BaseDataset):
514
858
  cache_img_path=cache_img_path,
515
859
  use_existing_imgs=use_existing_imgs,
516
860
  rank=rank,
861
+ parallel_caching=parallel_caching,
862
+ cache_workers=cache_workers,
517
863
  )
518
864
  self.confmap_head_config = confmap_head_config
519
865
  self.class_maps_head_config = class_maps_head_config
@@ -570,9 +916,6 @@ class BottomUpMultiClassDataset(BaseDataset):
570
916
 
571
917
  sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
572
918
 
573
- # apply normalization
574
- sample["image"] = apply_normalization(sample["image"])
575
-
576
919
  if self.ensure_rgb:
577
920
  sample["image"] = convert_to_rgb(sample["image"])
578
921
  elif self.ensure_grayscale:
@@ -684,15 +1027,12 @@ class CenteredInstanceDataset(BaseDataset):
684
1027
  the images aren't cached and loaded from the `.slp` file on each access.
685
1028
  cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
686
1029
  use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
687
- crop_size: Crop size of each instance for centered-instance model.
1030
+ crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
688
1031
  rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
689
1032
  disk occurs only once across all workers.
690
1033
  confmap_head_config: DictConfig object with all the keys in the `head_config` section.
691
1034
  (required keys: `sigma`, `output_stride`, `part_names` and `anchor_part` depending on the model type ).
692
1035
  labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
693
-
694
- Note: If scale is provided for centered-instance model, the images are cropped out
695
- from the scaled image with the given crop size.
696
1036
  """
697
1037
 
698
1038
  def __init__(
@@ -714,6 +1054,8 @@ class CenteredInstanceDataset(BaseDataset):
714
1054
  cache_img_path: Optional[str] = None,
715
1055
  use_existing_imgs: bool = False,
716
1056
  rank: Optional[int] = None,
1057
+ parallel_caching: bool = True,
1058
+ cache_workers: int = 0,
717
1059
  ) -> None:
718
1060
  """Initialize class attributes."""
719
1061
  super().__init__(
@@ -731,6 +1073,8 @@ class CenteredInstanceDataset(BaseDataset):
731
1073
  cache_img_path=cache_img_path,
732
1074
  use_existing_imgs=use_existing_imgs,
733
1075
  rank=rank,
1076
+ parallel_caching=parallel_caching,
1077
+ cache_workers=cache_workers,
734
1078
  )
735
1079
  self.labels = None
736
1080
  self.crop_size = crop_size
@@ -748,6 +1092,9 @@ class CenteredInstanceDataset(BaseDataset):
748
1092
  if self.user_instances_only:
749
1093
  if lf.user_instances is not None and len(lf.user_instances) > 0:
750
1094
  lf.instances = lf.user_instances
1095
+ else:
1096
+ # Skip frames without user instances
1097
+ continue
751
1098
  for inst_idx, inst in enumerate(lf.instances):
752
1099
  if not inst.is_empty: # filter all NaN instances.
753
1100
  video_idx = labels[labels_idx].videos.index(lf.video)
@@ -818,9 +1165,6 @@ class CenteredInstanceDataset(BaseDataset):
818
1165
 
819
1166
  instances = instances[:, inst_idx]
820
1167
 
821
- # apply normalization
822
- image = apply_normalization(image)
823
-
824
1168
  if self.ensure_rgb:
825
1169
  image = convert_to_rgb(image)
826
1170
  elif self.ensure_grayscale:
@@ -834,13 +1178,6 @@ class CenteredInstanceDataset(BaseDataset):
834
1178
  )
835
1179
  instances = instances * eff_scale
836
1180
 
837
- # resize image
838
- image, instances = apply_resizer(
839
- image,
840
- instances,
841
- scale=self.scale,
842
- )
843
-
844
1181
  # get the centroids based on the anchor idx
845
1182
  centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
846
1183
 
@@ -901,6 +1238,13 @@ class CenteredInstanceDataset(BaseDataset):
901
1238
  sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
902
1239
  sample["centroid"] = centered_centroid # (n_samples=1, 2)
903
1240
 
1241
+ # resize the cropped image
1242
+ sample["instance_image"], sample["instance"] = apply_resizer(
1243
+ sample["instance_image"],
1244
+ sample["instance"],
1245
+ scale=self.scale,
1246
+ )
1247
+
904
1248
  # Pad the image (if needed) according max stride
905
1249
  sample["instance_image"] = apply_pad_to_stride(
906
1250
  sample["instance_image"], max_stride=self.max_stride
@@ -959,7 +1303,7 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
959
1303
  the images aren't cached and loaded from the `.slp` file on each access.
960
1304
  cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
961
1305
  use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
962
- crop_size: Crop size of each instance for centered-instance model.
1306
+ crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
963
1307
  rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
964
1308
  disk occurs only once across all workers.
965
1309
  confmap_head_config: DictConfig object with all the keys in the `head_config` section.
@@ -967,9 +1311,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
967
1311
  class_vectors_head_config: DictConfig object with all the keys in the `head_config` section.
968
1312
  (required keys: `classes`, `num_fc_layers`, `num_fc_units`, `output_stride`, `loss_weight`).
969
1313
  labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
970
-
971
- Note: If scale is provided for centered-instance model, the images are cropped out
972
- from the scaled image with the given crop size.
973
1314
  """
974
1315
 
975
1316
  def __init__(
@@ -992,6 +1333,8 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
992
1333
  cache_img_path: Optional[str] = None,
993
1334
  use_existing_imgs: bool = False,
994
1335
  rank: Optional[int] = None,
1336
+ parallel_caching: bool = True,
1337
+ cache_workers: int = 0,
995
1338
  ) -> None:
996
1339
  """Initialize class attributes."""
997
1340
  super().__init__(
@@ -1012,6 +1355,8 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1012
1355
  cache_img_path=cache_img_path,
1013
1356
  use_existing_imgs=use_existing_imgs,
1014
1357
  rank=rank,
1358
+ parallel_caching=parallel_caching,
1359
+ cache_workers=cache_workers,
1015
1360
  )
1016
1361
  self.class_vectors_head_config = class_vectors_head_config
1017
1362
  self.class_names = self.class_vectors_head_config.classes
@@ -1066,9 +1411,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1066
1411
 
1067
1412
  instances = instances[:, inst_idx]
1068
1413
 
1069
- # apply normalization
1070
- image = apply_normalization(image)
1071
-
1072
1414
  if self.ensure_rgb:
1073
1415
  image = convert_to_rgb(image)
1074
1416
  elif self.ensure_grayscale:
@@ -1082,13 +1424,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1082
1424
  )
1083
1425
  instances = instances * eff_scale
1084
1426
 
1085
- # resize image
1086
- image, instances = apply_resizer(
1087
- image,
1088
- instances,
1089
- scale=self.scale,
1090
- )
1091
-
1092
1427
  # get class vectors
1093
1428
  track_ids = torch.Tensor(
1094
1429
  [
@@ -1165,6 +1500,13 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1165
1500
  sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
1166
1501
  sample["centroid"] = centered_centroid # (n_samples=1, 2)
1167
1502
 
1503
+ # resize image
1504
+ sample["instance_image"], sample["instance"] = apply_resizer(
1505
+ sample["instance_image"],
1506
+ sample["instance"],
1507
+ scale=self.scale,
1508
+ )
1509
+
1168
1510
  # Pad the image (if needed) according max stride
1169
1511
  sample["instance_image"] = apply_pad_to_stride(
1170
1512
  sample["instance_image"], max_stride=self.max_stride
@@ -1250,6 +1592,8 @@ class CentroidDataset(BaseDataset):
1250
1592
  cache_img_path: Optional[str] = None,
1251
1593
  use_existing_imgs: bool = False,
1252
1594
  rank: Optional[int] = None,
1595
+ parallel_caching: bool = True,
1596
+ cache_workers: int = 0,
1253
1597
  ) -> None:
1254
1598
  """Initialize class attributes."""
1255
1599
  super().__init__(
@@ -1267,6 +1611,8 @@ class CentroidDataset(BaseDataset):
1267
1611
  cache_img_path=cache_img_path,
1268
1612
  use_existing_imgs=use_existing_imgs,
1269
1613
  rank=rank,
1614
+ parallel_caching=parallel_caching,
1615
+ cache_workers=cache_workers,
1270
1616
  )
1271
1617
  self.anchor_ind = anchor_ind
1272
1618
  self.confmap_head_config = confmap_head_config
@@ -1306,9 +1652,6 @@ class CentroidDataset(BaseDataset):
1306
1652
  user_instances_only=self.user_instances_only,
1307
1653
  )
1308
1654
 
1309
- # apply normalization
1310
- sample["image"] = apply_normalization(sample["image"])
1311
-
1312
1655
  if self.ensure_rgb:
1313
1656
  sample["image"] = convert_to_rgb(sample["image"])
1314
1657
  elif self.ensure_grayscale:
@@ -1433,6 +1776,8 @@ class SingleInstanceDataset(BaseDataset):
1433
1776
  cache_img_path: Optional[str] = None,
1434
1777
  use_existing_imgs: bool = False,
1435
1778
  rank: Optional[int] = None,
1779
+ parallel_caching: bool = True,
1780
+ cache_workers: int = 0,
1436
1781
  ) -> None:
1437
1782
  """Initialize class attributes."""
1438
1783
  super().__init__(
@@ -1450,6 +1795,8 @@ class SingleInstanceDataset(BaseDataset):
1450
1795
  cache_img_path=cache_img_path,
1451
1796
  use_existing_imgs=use_existing_imgs,
1452
1797
  rank=rank,
1798
+ parallel_caching=parallel_caching,
1799
+ cache_workers=cache_workers,
1453
1800
  )
1454
1801
  self.confmap_head_config = confmap_head_config
1455
1802
 
@@ -1488,9 +1835,6 @@ class SingleInstanceDataset(BaseDataset):
1488
1835
  user_instances_only=self.user_instances_only,
1489
1836
  )
1490
1837
 
1491
- # apply normalization
1492
- sample["image"] = apply_normalization(sample["image"])
1493
-
1494
1838
  if self.ensure_rgb:
1495
1839
  sample["image"] = convert_to_rgb(sample["image"])
1496
1840
  elif self.ensure_grayscale:
@@ -1671,6 +2015,10 @@ def get_train_val_datasets(
1671
2015
  val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
1672
2016
  use_existing_imgs = config.data_config.use_existing_imgs
1673
2017
 
2018
+ # Parallel caching configuration
2019
+ parallel_caching = getattr(config.data_config, "parallel_caching", True)
2020
+ cache_workers = getattr(config.data_config, "cache_workers", 0)
2021
+
1674
2022
  model_type = get_model_type_from_cfg(config=config)
1675
2023
  backbone_type = get_backbone_type_from_cfg(config=config)
1676
2024
 
@@ -1724,6 +2072,8 @@ def get_train_val_datasets(
1724
2072
  cache_img_path=train_cache_img_path,
1725
2073
  use_existing_imgs=use_existing_imgs,
1726
2074
  rank=rank,
2075
+ parallel_caching=parallel_caching,
2076
+ cache_workers=cache_workers,
1727
2077
  )
1728
2078
  val_dataset = BottomUpDataset(
1729
2079
  labels=val_labels,
@@ -1747,6 +2097,8 @@ def get_train_val_datasets(
1747
2097
  cache_img_path=val_cache_img_path,
1748
2098
  use_existing_imgs=use_existing_imgs,
1749
2099
  rank=rank,
2100
+ parallel_caching=parallel_caching,
2101
+ cache_workers=cache_workers,
1750
2102
  )
1751
2103
 
1752
2104
  elif model_type == "multi_class_bottomup":
@@ -1780,6 +2132,8 @@ def get_train_val_datasets(
1780
2132
  cache_img_path=train_cache_img_path,
1781
2133
  use_existing_imgs=use_existing_imgs,
1782
2134
  rank=rank,
2135
+ parallel_caching=parallel_caching,
2136
+ cache_workers=cache_workers,
1783
2137
  )
1784
2138
  val_dataset = BottomUpMultiClassDataset(
1785
2139
  labels=val_labels,
@@ -1803,6 +2157,8 @@ def get_train_val_datasets(
1803
2157
  cache_img_path=val_cache_img_path,
1804
2158
  use_existing_imgs=use_existing_imgs,
1805
2159
  rank=rank,
2160
+ parallel_caching=parallel_caching,
2161
+ cache_workers=cache_workers,
1806
2162
  )
1807
2163
 
1808
2164
  elif model_type == "centered_instance":
@@ -1842,6 +2198,8 @@ def get_train_val_datasets(
1842
2198
  cache_img_path=train_cache_img_path,
1843
2199
  use_existing_imgs=use_existing_imgs,
1844
2200
  rank=rank,
2201
+ parallel_caching=parallel_caching,
2202
+ cache_workers=cache_workers,
1845
2203
  )
1846
2204
  val_dataset = CenteredInstanceDataset(
1847
2205
  labels=val_labels,
@@ -1866,6 +2224,8 @@ def get_train_val_datasets(
1866
2224
  cache_img_path=val_cache_img_path,
1867
2225
  use_existing_imgs=use_existing_imgs,
1868
2226
  rank=rank,
2227
+ parallel_caching=parallel_caching,
2228
+ cache_workers=cache_workers,
1869
2229
  )
1870
2230
 
1871
2231
  elif model_type == "multi_class_topdown":
@@ -1906,6 +2266,8 @@ def get_train_val_datasets(
1906
2266
  cache_img_path=train_cache_img_path,
1907
2267
  use_existing_imgs=use_existing_imgs,
1908
2268
  rank=rank,
2269
+ parallel_caching=parallel_caching,
2270
+ cache_workers=cache_workers,
1909
2271
  )
1910
2272
  val_dataset = TopDownCenteredInstanceMultiClassDataset(
1911
2273
  labels=val_labels,
@@ -1931,6 +2293,8 @@ def get_train_val_datasets(
1931
2293
  cache_img_path=val_cache_img_path,
1932
2294
  use_existing_imgs=use_existing_imgs,
1933
2295
  rank=rank,
2296
+ parallel_caching=parallel_caching,
2297
+ cache_workers=cache_workers,
1934
2298
  )
1935
2299
 
1936
2300
  elif model_type == "centroid":
@@ -1967,6 +2331,8 @@ def get_train_val_datasets(
1967
2331
  cache_img_path=train_cache_img_path,
1968
2332
  use_existing_imgs=use_existing_imgs,
1969
2333
  rank=rank,
2334
+ parallel_caching=parallel_caching,
2335
+ cache_workers=cache_workers,
1970
2336
  )
1971
2337
  val_dataset = CentroidDataset(
1972
2338
  labels=val_labels,
@@ -1990,6 +2356,8 @@ def get_train_val_datasets(
1990
2356
  cache_img_path=val_cache_img_path,
1991
2357
  use_existing_imgs=use_existing_imgs,
1992
2358
  rank=rank,
2359
+ parallel_caching=parallel_caching,
2360
+ cache_workers=cache_workers,
1993
2361
  )
1994
2362
 
1995
2363
  else:
@@ -2022,6 +2390,8 @@ def get_train_val_datasets(
2022
2390
  cache_img_path=train_cache_img_path,
2023
2391
  use_existing_imgs=use_existing_imgs,
2024
2392
  rank=rank,
2393
+ parallel_caching=parallel_caching,
2394
+ cache_workers=cache_workers,
2025
2395
  )
2026
2396
  val_dataset = SingleInstanceDataset(
2027
2397
  labels=val_labels,
@@ -2044,6 +2414,8 @@ def get_train_val_datasets(
2044
2414
  cache_img_path=val_cache_img_path,
2045
2415
  use_existing_imgs=use_existing_imgs,
2046
2416
  rank=rank,
2417
+ parallel_caching=parallel_caching,
2418
+ cache_workers=cache_workers,
2047
2419
  )
2048
2420
 
2049
2421
  # If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.