sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,10 @@
1
1
  """Custom `torch.utils.data.Dataset`s for different model types."""
2
2
 
3
- from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
3
+ from kornia.geometry.transform import crop_and_resize
4
4
 
5
- import os
6
- import threading
7
- from concurrent.futures import ThreadPoolExecutor, as_completed
8
- from copy import deepcopy
5
+ # from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
6
+ # import concurrent.futures
7
+ # import os
9
8
  from itertools import cycle
10
9
  from pathlib import Path
11
10
  import torch.distributed as dist
@@ -14,14 +13,6 @@ from omegaconf import DictConfig, OmegaConf
14
13
  import numpy as np
15
14
  from PIL import Image
16
15
  from loguru import logger
17
- from rich.progress import (
18
- Progress,
19
- SpinnerColumn,
20
- TextColumn,
21
- BarColumn,
22
- TimeElapsedColumn,
23
- )
24
- from rich.console import Console
25
16
  import torch
26
17
  import torchvision.transforms as T
27
18
  from torch.utils.data import Dataset, DataLoader, DistributedSampler
@@ -31,6 +22,7 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
31
22
  from sleap_nn.data.instance_centroids import generate_centroids
32
23
  from sleap_nn.data.instance_cropping import generate_crops
33
24
  from sleap_nn.data.normalization import (
25
+ apply_normalization,
34
26
  convert_to_grayscale,
35
27
  convert_to_rgb,
36
28
  )
@@ -46,182 +38,6 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
46
38
  from sleap_nn.training.utils import is_distributed_initialized
47
39
  from sleap_nn.config.get_config import get_aug_config
48
40
 
49
- # Minimum number of samples to use parallel caching (overhead not worth it for smaller)
50
- MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
51
-
52
-
53
- class ParallelCacheFiller:
54
- """Parallel implementation of image caching using thread-local video copies.
55
-
56
- This class uses ThreadPoolExecutor to parallelize I/O-bound operations when
57
- caching images to disk or memory. Each worker thread gets its own copy of
58
- video objects to ensure thread safety.
59
-
60
- Attributes:
61
- labels: List of sio.Labels objects containing the data.
62
- lf_idx_list: List of dictionaries with labeled frame indices.
63
- cache_type: Either "disk" or "memory".
64
- cache_path: Path to save cached images (for disk caching).
65
- num_workers: Number of worker threads.
66
- """
67
-
68
- def __init__(
69
- self,
70
- labels: List[sio.Labels],
71
- lf_idx_list: List[Dict],
72
- cache_type: str,
73
- cache_path: Optional[Path] = None,
74
- num_workers: int = 4,
75
- ):
76
- """Initialize the parallel cache filler.
77
-
78
- Args:
79
- labels: List of sio.Labels objects.
80
- lf_idx_list: List of sample dictionaries with frame indices.
81
- cache_type: Either "disk" or "memory".
82
- cache_path: Path for disk caching.
83
- num_workers: Number of worker threads.
84
- """
85
- self.labels = labels
86
- self.lf_idx_list = lf_idx_list
87
- self.cache_type = cache_type
88
- self.cache_path = cache_path
89
- self.num_workers = num_workers
90
-
91
- self.cache: Dict = {}
92
- self._cache_lock = threading.Lock()
93
- self._local = threading.local()
94
- self._video_info: Dict = {}
95
-
96
- # Prepare video copies for thread-local access
97
- self._prepare_video_copies()
98
-
99
- def _prepare_video_copies(self):
100
- """Close original videos and prepare for thread-local copies."""
101
- for label in self.labels:
102
- for video in label.videos:
103
- vid_id = id(video)
104
- if vid_id not in self._video_info:
105
- # Store original state
106
- original_open_backend = video.open_backend
107
-
108
- # Close the video backend
109
- video.close()
110
- video.open_backend = False
111
-
112
- self._video_info[vid_id] = {
113
- "video": video,
114
- "original_open_backend": original_open_backend,
115
- }
116
-
117
- def _get_thread_local_video(self, video: sio.Video) -> sio.Video:
118
- """Get or create a thread-local video copy.
119
-
120
- Args:
121
- video: The original video object.
122
-
123
- Returns:
124
- A thread-local copy of the video that is safe to use.
125
- """
126
- vid_id = id(video)
127
-
128
- if not hasattr(self._local, "videos"):
129
- self._local.videos = {}
130
-
131
- if vid_id not in self._local.videos:
132
- # Create a thread-local copy
133
- video_copy = deepcopy(video)
134
- video_copy.open_backend = True
135
- self._local.videos[vid_id] = video_copy
136
-
137
- return self._local.videos[vid_id]
138
-
139
- def _process_sample(
140
- self, sample: Dict
141
- ) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
142
- """Process a single sample (read image, optionally save/cache).
143
-
144
- Args:
145
- sample: Dictionary with labels_idx, lf_idx, etc.
146
-
147
- Returns:
148
- Tuple of (labels_idx, lf_idx, image_or_none, error_or_none).
149
- """
150
- labels_idx = sample["labels_idx"]
151
- lf_idx = sample["lf_idx"]
152
-
153
- try:
154
- # Get the labeled frame
155
- lf = self.labels[labels_idx][lf_idx]
156
-
157
- # Get thread-local video
158
- video = self._get_thread_local_video(lf.video)
159
-
160
- # Read the image
161
- img = video[lf.frame_idx]
162
-
163
- if img.shape[-1] == 1:
164
- img = np.squeeze(img)
165
-
166
- if self.cache_type == "disk":
167
- f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
168
- Image.fromarray(img).save(str(f_name), format="JPEG")
169
- return labels_idx, lf_idx, None, None
170
- elif self.cache_type == "memory":
171
- return labels_idx, lf_idx, img, None
172
-
173
- except Exception as e:
174
- return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"
175
-
176
- def fill_cache(
177
- self, progress_callback=None
178
- ) -> Tuple[Dict, List[Tuple[int, int, str]]]:
179
- """Fill the cache in parallel.
180
-
181
- Args:
182
- progress_callback: Optional callback(completed_count) for progress updates.
183
-
184
- Returns:
185
- Tuple of (cache_dict, list_of_errors).
186
- """
187
- errors = []
188
- completed = 0
189
-
190
- with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
191
- futures = {
192
- executor.submit(self._process_sample, sample): sample
193
- for sample in self.lf_idx_list
194
- }
195
-
196
- for future in as_completed(futures):
197
- labels_idx, lf_idx, img, error = future.result()
198
-
199
- if error:
200
- errors.append((labels_idx, lf_idx, error))
201
- elif self.cache_type == "memory" and img is not None:
202
- with self._cache_lock:
203
- self.cache[(labels_idx, lf_idx)] = img
204
-
205
- completed += 1
206
- if progress_callback:
207
- progress_callback(completed)
208
-
209
- # Restore original video states
210
- self._restore_videos()
211
-
212
- return self.cache, errors
213
-
214
- def _restore_videos(self):
215
- """Restore original video states after caching is complete."""
216
- for vid_info in self._video_info.values():
217
- video = vid_info["video"]
218
- video.open_backend = vid_info["original_open_backend"]
219
- if video.open_backend:
220
- try:
221
- video.open()
222
- except Exception:
223
- pass
224
-
225
41
 
226
42
  class BaseDataset(Dataset):
227
43
  """Base class for custom torch Datasets.
@@ -260,8 +76,6 @@ class BaseDataset(Dataset):
260
76
  use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
261
77
  rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
262
78
  disk occurs only once across all workers.
263
- parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
264
- cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
265
79
  labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
266
80
  """
267
81
 
@@ -281,8 +95,6 @@ class BaseDataset(Dataset):
281
95
  cache_img_path: Optional[str] = None,
282
96
  use_existing_imgs: bool = False,
283
97
  rank: Optional[int] = None,
284
- parallel_caching: bool = True,
285
- cache_workers: int = 0,
286
98
  ) -> None:
287
99
  """Initialize class attributes."""
288
100
  super().__init__()
@@ -323,8 +135,6 @@ class BaseDataset(Dataset):
323
135
  self.cache_img = cache_img
324
136
  self.cache_img_path = cache_img_path
325
137
  self.use_existing_imgs = use_existing_imgs
326
- self.parallel_caching = parallel_caching
327
- self.cache_workers = cache_workers
328
138
  if self.cache_img is not None and "disk" in self.cache_img:
329
139
  if self.cache_img_path is None:
330
140
  self.cache_img_path = "."
@@ -350,18 +160,10 @@ class BaseDataset(Dataset):
350
160
 
351
161
  if self.cache_img is not None:
352
162
  if self.cache_img == "memory":
353
- self._fill_cache(
354
- labels,
355
- parallel=self.parallel_caching,
356
- num_workers=self.cache_workers,
357
- )
163
+ self._fill_cache(labels)
358
164
  elif self.cache_img == "disk" and not self.use_existing_imgs:
359
165
  if self.rank is None or self.rank == -1 or self.rank == 0:
360
- self._fill_cache(
361
- labels,
362
- parallel=self.parallel_caching,
363
- num_workers=self.cache_workers,
364
- )
166
+ self._fill_cache(labels)
365
167
  # Synchronize all ranks after cache creation
366
168
  if is_distributed_initialized():
367
169
  dist.barrier()
@@ -410,160 +212,20 @@ class BaseDataset(Dataset):
410
212
  """Returns an iterator."""
411
213
  return self
412
214
 
413
- def _fill_cache(
414
- self,
415
- labels: List[sio.Labels],
416
- parallel: bool = True,
417
- num_workers: int = 0,
418
- ):
419
- """Load all samples to cache.
420
-
421
- Args:
422
- labels: List of sio.Labels objects containing the data.
423
- parallel: If True, use parallel processing for caching (faster for large
424
- datasets). Default: True.
425
- num_workers: Number of worker threads for parallel caching. If 0, uses
426
- min(4, cpu_count). Default: 0.
427
- """
428
- total_samples = len(self.lf_idx_list)
429
- cache_type = "disk" if self.cache_img == "disk" else "memory"
430
-
431
- # Check for NO_COLOR env var to disable progress bar
432
- no_color = (
433
- os.environ.get("NO_COLOR") is not None
434
- or os.environ.get("FORCE_COLOR") == "0"
435
- )
436
- use_progress = not no_color
437
-
438
- # Use parallel caching for larger datasets
439
- use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
440
-
441
- logger.info(f"Caching {total_samples} images to {cache_type}...")
442
-
443
- if use_parallel:
444
- self._fill_cache_parallel(
445
- labels, total_samples, cache_type, use_progress, num_workers
446
- )
447
- else:
448
- self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
449
-
450
- logger.info(f"Caching complete.")
451
-
452
- def _fill_cache_sequential(
453
- self,
454
- labels: List[sio.Labels],
455
- total_samples: int,
456
- cache_type: str,
457
- use_progress: bool,
458
- ):
459
- """Sequential implementation of cache filling.
460
-
461
- Args:
462
- labels: List of sio.Labels objects.
463
- total_samples: Total number of samples to cache.
464
- cache_type: Either "disk" or "memory".
465
- use_progress: Whether to show a progress bar.
466
- """
467
-
468
- def process_samples(progress=None, task=None):
469
- for sample in self.lf_idx_list:
470
- labels_idx = sample["labels_idx"]
471
- lf_idx = sample["lf_idx"]
472
- img = labels[labels_idx][lf_idx].image
473
- if img.shape[-1] == 1:
474
- img = np.squeeze(img)
475
- if self.cache_img == "disk":
476
- f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
477
- Image.fromarray(img).save(f_name, format="JPEG")
478
- if self.cache_img == "memory":
479
- self.cache[(labels_idx, lf_idx)] = img
480
- if progress is not None:
481
- progress.update(task, advance=1)
482
-
483
- if use_progress:
484
- with Progress(
485
- SpinnerColumn(),
486
- TextColumn("[progress.description]{task.description}"),
487
- BarColumn(),
488
- TextColumn("{task.completed}/{task.total}"),
489
- TimeElapsedColumn(),
490
- console=Console(force_terminal=True),
491
- transient=True,
492
- ) as progress:
493
- task = progress.add_task(
494
- f"Caching images to {cache_type}", total=total_samples
495
- )
496
- process_samples(progress, task)
497
- else:
498
- process_samples()
499
-
500
- def _fill_cache_parallel(
501
- self,
502
- labels: List[sio.Labels],
503
- total_samples: int,
504
- cache_type: str,
505
- use_progress: bool,
506
- num_workers: int = 0,
507
- ):
508
- """Parallel implementation of cache filling using thread-local video copies.
509
-
510
- Args:
511
- labels: List of sio.Labels objects.
512
- total_samples: Total number of samples to cache.
513
- cache_type: Either "disk" or "memory".
514
- use_progress: Whether to show a progress bar.
515
- num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
516
- """
517
- # Determine number of workers
518
- if num_workers <= 0:
519
- num_workers = min(4, os.cpu_count() or 1)
520
-
521
- cache_path = Path(self.cache_img_path) if self.cache_img_path else None
522
-
523
- filler = ParallelCacheFiller(
524
- labels=labels,
525
- lf_idx_list=self.lf_idx_list,
526
- cache_type=cache_type,
527
- cache_path=cache_path,
528
- num_workers=num_workers,
529
- )
530
-
531
- if use_progress:
532
- with Progress(
533
- SpinnerColumn(),
534
- TextColumn("[progress.description]{task.description}"),
535
- BarColumn(),
536
- TextColumn("{task.completed}/{task.total}"),
537
- TimeElapsedColumn(),
538
- console=Console(force_terminal=True),
539
- transient=True,
540
- ) as progress:
541
- task = progress.add_task(
542
- f"Caching images to {cache_type} (parallel, {num_workers} workers)",
543
- total=total_samples,
544
- )
545
-
546
- def progress_callback(completed):
547
- progress.update(task, completed=completed)
548
-
549
- cache, errors = filler.fill_cache(progress_callback)
550
- else:
551
- logger.info(
552
- f"Caching {total_samples} images to {cache_type} "
553
- f"(parallel, {num_workers} workers)..."
554
- )
555
- cache, errors = filler.fill_cache()
556
-
557
- # Update instance cache
558
- if cache_type == "memory":
559
- self.cache.update(cache)
560
-
561
- # Log any errors
562
- if errors:
563
- logger.warning(
564
- f"Parallel caching completed with {len(errors)} errors. "
565
- f"First error: {errors[0]}"
566
- )
215
+ def _fill_cache(self, labels: List[sio.Labels]):
216
+ """Load all samples to cache."""
217
+ # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
218
+ for sample in self.lf_idx_list:
219
+ labels_idx = sample["labels_idx"]
220
+ lf_idx = sample["lf_idx"]
221
+ img = labels[labels_idx][lf_idx].image
222
+ if img.shape[-1] == 1:
223
+ img = np.squeeze(img)
224
+ if self.cache_img == "disk":
225
+ f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
226
+ Image.fromarray(img).save(f_name, format="JPEG")
227
+ if self.cache_img == "memory":
228
+ self.cache[(labels_idx, lf_idx)] = img
567
229
 
568
230
  def __len__(self) -> int:
569
231
  """Return the number of samples in the dataset."""
@@ -639,8 +301,6 @@ class BottomUpDataset(BaseDataset):
639
301
  cache_img_path: Optional[str] = None,
640
302
  use_existing_imgs: bool = False,
641
303
  rank: Optional[int] = None,
642
- parallel_caching: bool = True,
643
- cache_workers: int = 0,
644
304
  ) -> None:
645
305
  """Initialize class attributes."""
646
306
  super().__init__(
@@ -658,8 +318,6 @@ class BottomUpDataset(BaseDataset):
658
318
  cache_img_path=cache_img_path,
659
319
  use_existing_imgs=use_existing_imgs,
660
320
  rank=rank,
661
- parallel_caching=parallel_caching,
662
- cache_workers=cache_workers,
663
321
  )
664
322
  self.confmap_head_config = confmap_head_config
665
323
  self.pafs_head_config = pafs_head_config
@@ -702,6 +360,9 @@ class BottomUpDataset(BaseDataset):
702
360
  user_instances_only=self.user_instances_only,
703
361
  )
704
362
 
363
+ # apply normalization
364
+ sample["image"] = apply_normalization(sample["image"])
365
+
705
366
  if self.ensure_rgb:
706
367
  sample["image"] = convert_to_rgb(sample["image"])
707
368
  elif self.ensure_grayscale:
@@ -839,8 +500,6 @@ class BottomUpMultiClassDataset(BaseDataset):
839
500
  cache_img_path: Optional[str] = None,
840
501
  use_existing_imgs: bool = False,
841
502
  rank: Optional[int] = None,
842
- parallel_caching: bool = True,
843
- cache_workers: int = 0,
844
503
  ) -> None:
845
504
  """Initialize class attributes."""
846
505
  super().__init__(
@@ -858,8 +517,6 @@ class BottomUpMultiClassDataset(BaseDataset):
858
517
  cache_img_path=cache_img_path,
859
518
  use_existing_imgs=use_existing_imgs,
860
519
  rank=rank,
861
- parallel_caching=parallel_caching,
862
- cache_workers=cache_workers,
863
520
  )
864
521
  self.confmap_head_config = confmap_head_config
865
522
  self.class_maps_head_config = class_maps_head_config
@@ -916,6 +573,9 @@ class BottomUpMultiClassDataset(BaseDataset):
916
573
 
917
574
  sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
918
575
 
576
+ # apply normalization
577
+ sample["image"] = apply_normalization(sample["image"])
578
+
919
579
  if self.ensure_rgb:
920
580
  sample["image"] = convert_to_rgb(sample["image"])
921
581
  elif self.ensure_grayscale:
@@ -1054,8 +714,6 @@ class CenteredInstanceDataset(BaseDataset):
1054
714
  cache_img_path: Optional[str] = None,
1055
715
  use_existing_imgs: bool = False,
1056
716
  rank: Optional[int] = None,
1057
- parallel_caching: bool = True,
1058
- cache_workers: int = 0,
1059
717
  ) -> None:
1060
718
  """Initialize class attributes."""
1061
719
  super().__init__(
@@ -1073,8 +731,6 @@ class CenteredInstanceDataset(BaseDataset):
1073
731
  cache_img_path=cache_img_path,
1074
732
  use_existing_imgs=use_existing_imgs,
1075
733
  rank=rank,
1076
- parallel_caching=parallel_caching,
1077
- cache_workers=cache_workers,
1078
734
  )
1079
735
  self.labels = None
1080
736
  self.crop_size = crop_size
@@ -1165,6 +821,9 @@ class CenteredInstanceDataset(BaseDataset):
1165
821
 
1166
822
  instances = instances[:, inst_idx]
1167
823
 
824
+ # apply normalization
825
+ image = apply_normalization(image)
826
+
1168
827
  if self.ensure_rgb:
1169
828
  image = convert_to_rgb(image)
1170
829
  elif self.ensure_grayscale:
@@ -1333,8 +992,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1333
992
  cache_img_path: Optional[str] = None,
1334
993
  use_existing_imgs: bool = False,
1335
994
  rank: Optional[int] = None,
1336
- parallel_caching: bool = True,
1337
- cache_workers: int = 0,
1338
995
  ) -> None:
1339
996
  """Initialize class attributes."""
1340
997
  super().__init__(
@@ -1355,8 +1012,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1355
1012
  cache_img_path=cache_img_path,
1356
1013
  use_existing_imgs=use_existing_imgs,
1357
1014
  rank=rank,
1358
- parallel_caching=parallel_caching,
1359
- cache_workers=cache_workers,
1360
1015
  )
1361
1016
  self.class_vectors_head_config = class_vectors_head_config
1362
1017
  self.class_names = self.class_vectors_head_config.classes
@@ -1411,6 +1066,9 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1411
1066
 
1412
1067
  instances = instances[:, inst_idx]
1413
1068
 
1069
+ # apply normalization
1070
+ image = apply_normalization(image)
1071
+
1414
1072
  if self.ensure_rgb:
1415
1073
  image = convert_to_rgb(image)
1416
1074
  elif self.ensure_grayscale:
@@ -1592,8 +1250,6 @@ class CentroidDataset(BaseDataset):
1592
1250
  cache_img_path: Optional[str] = None,
1593
1251
  use_existing_imgs: bool = False,
1594
1252
  rank: Optional[int] = None,
1595
- parallel_caching: bool = True,
1596
- cache_workers: int = 0,
1597
1253
  ) -> None:
1598
1254
  """Initialize class attributes."""
1599
1255
  super().__init__(
@@ -1611,8 +1267,6 @@ class CentroidDataset(BaseDataset):
1611
1267
  cache_img_path=cache_img_path,
1612
1268
  use_existing_imgs=use_existing_imgs,
1613
1269
  rank=rank,
1614
- parallel_caching=parallel_caching,
1615
- cache_workers=cache_workers,
1616
1270
  )
1617
1271
  self.anchor_ind = anchor_ind
1618
1272
  self.confmap_head_config = confmap_head_config
@@ -1652,6 +1306,9 @@ class CentroidDataset(BaseDataset):
1652
1306
  user_instances_only=self.user_instances_only,
1653
1307
  )
1654
1308
 
1309
+ # apply normalization
1310
+ sample["image"] = apply_normalization(sample["image"])
1311
+
1655
1312
  if self.ensure_rgb:
1656
1313
  sample["image"] = convert_to_rgb(sample["image"])
1657
1314
  elif self.ensure_grayscale:
@@ -1776,8 +1433,6 @@ class SingleInstanceDataset(BaseDataset):
1776
1433
  cache_img_path: Optional[str] = None,
1777
1434
  use_existing_imgs: bool = False,
1778
1435
  rank: Optional[int] = None,
1779
- parallel_caching: bool = True,
1780
- cache_workers: int = 0,
1781
1436
  ) -> None:
1782
1437
  """Initialize class attributes."""
1783
1438
  super().__init__(
@@ -1795,8 +1450,6 @@ class SingleInstanceDataset(BaseDataset):
1795
1450
  cache_img_path=cache_img_path,
1796
1451
  use_existing_imgs=use_existing_imgs,
1797
1452
  rank=rank,
1798
- parallel_caching=parallel_caching,
1799
- cache_workers=cache_workers,
1800
1453
  )
1801
1454
  self.confmap_head_config = confmap_head_config
1802
1455
 
@@ -1835,6 +1488,9 @@ class SingleInstanceDataset(BaseDataset):
1835
1488
  user_instances_only=self.user_instances_only,
1836
1489
  )
1837
1490
 
1491
+ # apply normalization
1492
+ sample["image"] = apply_normalization(sample["image"])
1493
+
1838
1494
  if self.ensure_rgb:
1839
1495
  sample["image"] = convert_to_rgb(sample["image"])
1840
1496
  elif self.ensure_grayscale:
@@ -2015,10 +1671,6 @@ def get_train_val_datasets(
2015
1671
  val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
2016
1672
  use_existing_imgs = config.data_config.use_existing_imgs
2017
1673
 
2018
- # Parallel caching configuration
2019
- parallel_caching = getattr(config.data_config, "parallel_caching", True)
2020
- cache_workers = getattr(config.data_config, "cache_workers", 0)
2021
-
2022
1674
  model_type = get_model_type_from_cfg(config=config)
2023
1675
  backbone_type = get_backbone_type_from_cfg(config=config)
2024
1676
 
@@ -2072,8 +1724,6 @@ def get_train_val_datasets(
2072
1724
  cache_img_path=train_cache_img_path,
2073
1725
  use_existing_imgs=use_existing_imgs,
2074
1726
  rank=rank,
2075
- parallel_caching=parallel_caching,
2076
- cache_workers=cache_workers,
2077
1727
  )
2078
1728
  val_dataset = BottomUpDataset(
2079
1729
  labels=val_labels,
@@ -2097,8 +1747,6 @@ def get_train_val_datasets(
2097
1747
  cache_img_path=val_cache_img_path,
2098
1748
  use_existing_imgs=use_existing_imgs,
2099
1749
  rank=rank,
2100
- parallel_caching=parallel_caching,
2101
- cache_workers=cache_workers,
2102
1750
  )
2103
1751
 
2104
1752
  elif model_type == "multi_class_bottomup":
@@ -2132,8 +1780,6 @@ def get_train_val_datasets(
2132
1780
  cache_img_path=train_cache_img_path,
2133
1781
  use_existing_imgs=use_existing_imgs,
2134
1782
  rank=rank,
2135
- parallel_caching=parallel_caching,
2136
- cache_workers=cache_workers,
2137
1783
  )
2138
1784
  val_dataset = BottomUpMultiClassDataset(
2139
1785
  labels=val_labels,
@@ -2157,8 +1803,6 @@ def get_train_val_datasets(
2157
1803
  cache_img_path=val_cache_img_path,
2158
1804
  use_existing_imgs=use_existing_imgs,
2159
1805
  rank=rank,
2160
- parallel_caching=parallel_caching,
2161
- cache_workers=cache_workers,
2162
1806
  )
2163
1807
 
2164
1808
  elif model_type == "centered_instance":
@@ -2198,8 +1842,6 @@ def get_train_val_datasets(
2198
1842
  cache_img_path=train_cache_img_path,
2199
1843
  use_existing_imgs=use_existing_imgs,
2200
1844
  rank=rank,
2201
- parallel_caching=parallel_caching,
2202
- cache_workers=cache_workers,
2203
1845
  )
2204
1846
  val_dataset = CenteredInstanceDataset(
2205
1847
  labels=val_labels,
@@ -2224,8 +1866,6 @@ def get_train_val_datasets(
2224
1866
  cache_img_path=val_cache_img_path,
2225
1867
  use_existing_imgs=use_existing_imgs,
2226
1868
  rank=rank,
2227
- parallel_caching=parallel_caching,
2228
- cache_workers=cache_workers,
2229
1869
  )
2230
1870
 
2231
1871
  elif model_type == "multi_class_topdown":
@@ -2266,8 +1906,6 @@ def get_train_val_datasets(
2266
1906
  cache_img_path=train_cache_img_path,
2267
1907
  use_existing_imgs=use_existing_imgs,
2268
1908
  rank=rank,
2269
- parallel_caching=parallel_caching,
2270
- cache_workers=cache_workers,
2271
1909
  )
2272
1910
  val_dataset = TopDownCenteredInstanceMultiClassDataset(
2273
1911
  labels=val_labels,
@@ -2293,8 +1931,6 @@ def get_train_val_datasets(
2293
1931
  cache_img_path=val_cache_img_path,
2294
1932
  use_existing_imgs=use_existing_imgs,
2295
1933
  rank=rank,
2296
- parallel_caching=parallel_caching,
2297
- cache_workers=cache_workers,
2298
1934
  )
2299
1935
 
2300
1936
  elif model_type == "centroid":
@@ -2331,8 +1967,6 @@ def get_train_val_datasets(
2331
1967
  cache_img_path=train_cache_img_path,
2332
1968
  use_existing_imgs=use_existing_imgs,
2333
1969
  rank=rank,
2334
- parallel_caching=parallel_caching,
2335
- cache_workers=cache_workers,
2336
1970
  )
2337
1971
  val_dataset = CentroidDataset(
2338
1972
  labels=val_labels,
@@ -2356,8 +1990,6 @@ def get_train_val_datasets(
2356
1990
  cache_img_path=val_cache_img_path,
2357
1991
  use_existing_imgs=use_existing_imgs,
2358
1992
  rank=rank,
2359
- parallel_caching=parallel_caching,
2360
- cache_workers=cache_workers,
2361
1993
  )
2362
1994
 
2363
1995
  else:
@@ -2390,8 +2022,6 @@ def get_train_val_datasets(
2390
2022
  cache_img_path=train_cache_img_path,
2391
2023
  use_existing_imgs=use_existing_imgs,
2392
2024
  rank=rank,
2393
- parallel_caching=parallel_caching,
2394
- cache_workers=cache_workers,
2395
2025
  )
2396
2026
  val_dataset = SingleInstanceDataset(
2397
2027
  labels=val_labels,
@@ -2414,8 +2044,6 @@ def get_train_val_datasets(
2414
2044
  cache_img_path=val_cache_img_path,
2415
2045
  use_existing_imgs=use_existing_imgs,
2416
2046
  rank=rank,
2417
- parallel_caching=parallel_caching,
2418
- cache_workers=cache_workers,
2419
2047
  )
2420
2048
 
2421
2049
  # If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.