sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,10 @@
1
1
  """Custom `torch.utils.data.Dataset`s for different model types."""
2
2
 
3
- from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
3
+ from kornia.geometry.transform import crop_and_resize
4
4
 
5
- import os
6
- import threading
7
- from concurrent.futures import ThreadPoolExecutor, as_completed
8
- from copy import deepcopy
5
+ # from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
6
+ # import concurrent.futures
7
+ # import os
9
8
  from itertools import cycle
10
9
  from pathlib import Path
11
10
  import torch.distributed as dist
@@ -31,6 +30,7 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
31
30
  from sleap_nn.data.instance_centroids import generate_centroids
32
31
  from sleap_nn.data.instance_cropping import generate_crops
33
32
  from sleap_nn.data.normalization import (
33
+ apply_normalization,
34
34
  convert_to_grayscale,
35
35
  convert_to_rgb,
36
36
  )
@@ -46,182 +46,6 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
46
46
  from sleap_nn.training.utils import is_distributed_initialized
47
47
  from sleap_nn.config.get_config import get_aug_config
48
48
 
49
- # Minimum number of samples to use parallel caching (overhead not worth it for smaller)
50
- MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
51
-
52
-
53
- class ParallelCacheFiller:
54
- """Parallel implementation of image caching using thread-local video copies.
55
-
56
- This class uses ThreadPoolExecutor to parallelize I/O-bound operations when
57
- caching images to disk or memory. Each worker thread gets its own copy of
58
- video objects to ensure thread safety.
59
-
60
- Attributes:
61
- labels: List of sio.Labels objects containing the data.
62
- lf_idx_list: List of dictionaries with labeled frame indices.
63
- cache_type: Either "disk" or "memory".
64
- cache_path: Path to save cached images (for disk caching).
65
- num_workers: Number of worker threads.
66
- """
67
-
68
- def __init__(
69
- self,
70
- labels: List[sio.Labels],
71
- lf_idx_list: List[Dict],
72
- cache_type: str,
73
- cache_path: Optional[Path] = None,
74
- num_workers: int = 4,
75
- ):
76
- """Initialize the parallel cache filler.
77
-
78
- Args:
79
- labels: List of sio.Labels objects.
80
- lf_idx_list: List of sample dictionaries with frame indices.
81
- cache_type: Either "disk" or "memory".
82
- cache_path: Path for disk caching.
83
- num_workers: Number of worker threads.
84
- """
85
- self.labels = labels
86
- self.lf_idx_list = lf_idx_list
87
- self.cache_type = cache_type
88
- self.cache_path = cache_path
89
- self.num_workers = num_workers
90
-
91
- self.cache: Dict = {}
92
- self._cache_lock = threading.Lock()
93
- self._local = threading.local()
94
- self._video_info: Dict = {}
95
-
96
- # Prepare video copies for thread-local access
97
- self._prepare_video_copies()
98
-
99
- def _prepare_video_copies(self):
100
- """Close original videos and prepare for thread-local copies."""
101
- for label in self.labels:
102
- for video in label.videos:
103
- vid_id = id(video)
104
- if vid_id not in self._video_info:
105
- # Store original state
106
- original_open_backend = video.open_backend
107
-
108
- # Close the video backend
109
- video.close()
110
- video.open_backend = False
111
-
112
- self._video_info[vid_id] = {
113
- "video": video,
114
- "original_open_backend": original_open_backend,
115
- }
116
-
117
- def _get_thread_local_video(self, video: sio.Video) -> sio.Video:
118
- """Get or create a thread-local video copy.
119
-
120
- Args:
121
- video: The original video object.
122
-
123
- Returns:
124
- A thread-local copy of the video that is safe to use.
125
- """
126
- vid_id = id(video)
127
-
128
- if not hasattr(self._local, "videos"):
129
- self._local.videos = {}
130
-
131
- if vid_id not in self._local.videos:
132
- # Create a thread-local copy
133
- video_copy = deepcopy(video)
134
- video_copy.open_backend = True
135
- self._local.videos[vid_id] = video_copy
136
-
137
- return self._local.videos[vid_id]
138
-
139
- def _process_sample(
140
- self, sample: Dict
141
- ) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
142
- """Process a single sample (read image, optionally save/cache).
143
-
144
- Args:
145
- sample: Dictionary with labels_idx, lf_idx, etc.
146
-
147
- Returns:
148
- Tuple of (labels_idx, lf_idx, image_or_none, error_or_none).
149
- """
150
- labels_idx = sample["labels_idx"]
151
- lf_idx = sample["lf_idx"]
152
-
153
- try:
154
- # Get the labeled frame
155
- lf = self.labels[labels_idx][lf_idx]
156
-
157
- # Get thread-local video
158
- video = self._get_thread_local_video(lf.video)
159
-
160
- # Read the image
161
- img = video[lf.frame_idx]
162
-
163
- if img.shape[-1] == 1:
164
- img = np.squeeze(img)
165
-
166
- if self.cache_type == "disk":
167
- f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
168
- Image.fromarray(img).save(str(f_name), format="JPEG")
169
- return labels_idx, lf_idx, None, None
170
- elif self.cache_type == "memory":
171
- return labels_idx, lf_idx, img, None
172
-
173
- except Exception as e:
174
- return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"
175
-
176
- def fill_cache(
177
- self, progress_callback=None
178
- ) -> Tuple[Dict, List[Tuple[int, int, str]]]:
179
- """Fill the cache in parallel.
180
-
181
- Args:
182
- progress_callback: Optional callback(completed_count) for progress updates.
183
-
184
- Returns:
185
- Tuple of (cache_dict, list_of_errors).
186
- """
187
- errors = []
188
- completed = 0
189
-
190
- with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
191
- futures = {
192
- executor.submit(self._process_sample, sample): sample
193
- for sample in self.lf_idx_list
194
- }
195
-
196
- for future in as_completed(futures):
197
- labels_idx, lf_idx, img, error = future.result()
198
-
199
- if error:
200
- errors.append((labels_idx, lf_idx, error))
201
- elif self.cache_type == "memory" and img is not None:
202
- with self._cache_lock:
203
- self.cache[(labels_idx, lf_idx)] = img
204
-
205
- completed += 1
206
- if progress_callback:
207
- progress_callback(completed)
208
-
209
- # Restore original video states
210
- self._restore_videos()
211
-
212
- return self.cache, errors
213
-
214
- def _restore_videos(self):
215
- """Restore original video states after caching is complete."""
216
- for vid_info in self._video_info.values():
217
- video = vid_info["video"]
218
- video.open_backend = vid_info["original_open_backend"]
219
- if video.open_backend:
220
- try:
221
- video.open()
222
- except Exception:
223
- pass
224
-
225
49
 
226
50
  class BaseDataset(Dataset):
227
51
  """Base class for custom torch Datasets.
@@ -260,8 +84,6 @@ class BaseDataset(Dataset):
260
84
  use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
261
85
  rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
262
86
  disk occurs only once across all workers.
263
- parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
264
- cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
265
87
  labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
266
88
  """
267
89
 
@@ -281,8 +103,6 @@ class BaseDataset(Dataset):
281
103
  cache_img_path: Optional[str] = None,
282
104
  use_existing_imgs: bool = False,
283
105
  rank: Optional[int] = None,
284
- parallel_caching: bool = True,
285
- cache_workers: int = 0,
286
106
  ) -> None:
287
107
  """Initialize class attributes."""
288
108
  super().__init__()
@@ -323,8 +143,6 @@ class BaseDataset(Dataset):
323
143
  self.cache_img = cache_img
324
144
  self.cache_img_path = cache_img_path
325
145
  self.use_existing_imgs = use_existing_imgs
326
- self.parallel_caching = parallel_caching
327
- self.cache_workers = cache_workers
328
146
  if self.cache_img is not None and "disk" in self.cache_img:
329
147
  if self.cache_img_path is None:
330
148
  self.cache_img_path = "."
@@ -350,18 +168,10 @@ class BaseDataset(Dataset):
350
168
 
351
169
  if self.cache_img is not None:
352
170
  if self.cache_img == "memory":
353
- self._fill_cache(
354
- labels,
355
- parallel=self.parallel_caching,
356
- num_workers=self.cache_workers,
357
- )
171
+ self._fill_cache(labels)
358
172
  elif self.cache_img == "disk" and not self.use_existing_imgs:
359
173
  if self.rank is None or self.rank == -1 or self.rank == 0:
360
- self._fill_cache(
361
- labels,
362
- parallel=self.parallel_caching,
363
- num_workers=self.cache_workers,
364
- )
174
+ self._fill_cache(labels)
365
175
  # Synchronize all ranks after cache creation
366
176
  if is_distributed_initialized():
367
177
  dist.barrier()
@@ -410,60 +220,21 @@ class BaseDataset(Dataset):
410
220
  """Returns an iterator."""
411
221
  return self
412
222
 
413
- def _fill_cache(
414
- self,
415
- labels: List[sio.Labels],
416
- parallel: bool = True,
417
- num_workers: int = 0,
418
- ):
419
- """Load all samples to cache.
420
-
421
- Args:
422
- labels: List of sio.Labels objects containing the data.
423
- parallel: If True, use parallel processing for caching (faster for large
424
- datasets). Default: True.
425
- num_workers: Number of worker threads for parallel caching. If 0, uses
426
- min(4, cpu_count). Default: 0.
427
- """
223
+ def _fill_cache(self, labels: List[sio.Labels]):
224
+ """Load all samples to cache."""
225
+ # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
226
+ import os
227
+ import sys
228
+
428
229
  total_samples = len(self.lf_idx_list)
429
230
  cache_type = "disk" if self.cache_img == "disk" else "memory"
430
231
 
431
- # Check for NO_COLOR env var to disable progress bar
232
+ # Check for NO_COLOR env var or non-interactive terminal
432
233
  no_color = (
433
234
  os.environ.get("NO_COLOR") is not None
434
235
  or os.environ.get("FORCE_COLOR") == "0"
435
236
  )
436
- use_progress = not no_color
437
-
438
- # Use parallel caching for larger datasets
439
- use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
440
-
441
- logger.info(f"Caching {total_samples} images to {cache_type}...")
442
-
443
- if use_parallel:
444
- self._fill_cache_parallel(
445
- labels, total_samples, cache_type, use_progress, num_workers
446
- )
447
- else:
448
- self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
449
-
450
- logger.info(f"Caching complete.")
451
-
452
- def _fill_cache_sequential(
453
- self,
454
- labels: List[sio.Labels],
455
- total_samples: int,
456
- cache_type: str,
457
- use_progress: bool,
458
- ):
459
- """Sequential implementation of cache filling.
460
-
461
- Args:
462
- labels: List of sio.Labels objects.
463
- total_samples: Total number of samples to cache.
464
- cache_type: Either "disk" or "memory".
465
- use_progress: Whether to show a progress bar.
466
- """
237
+ use_progress = sys.stdout.isatty() and not no_color
467
238
 
468
239
  def process_samples(progress=None, task=None):
469
240
  for sample in self.lf_idx_list:
@@ -495,76 +266,9 @@ class BaseDataset(Dataset):
495
266
  )
496
267
  process_samples(progress, task)
497
268
  else:
269
+ logger.info(f"Caching {total_samples} images to {cache_type}...")
498
270
  process_samples()
499
271
 
500
- def _fill_cache_parallel(
501
- self,
502
- labels: List[sio.Labels],
503
- total_samples: int,
504
- cache_type: str,
505
- use_progress: bool,
506
- num_workers: int = 0,
507
- ):
508
- """Parallel implementation of cache filling using thread-local video copies.
509
-
510
- Args:
511
- labels: List of sio.Labels objects.
512
- total_samples: Total number of samples to cache.
513
- cache_type: Either "disk" or "memory".
514
- use_progress: Whether to show a progress bar.
515
- num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
516
- """
517
- # Determine number of workers
518
- if num_workers <= 0:
519
- num_workers = min(4, os.cpu_count() or 1)
520
-
521
- cache_path = Path(self.cache_img_path) if self.cache_img_path else None
522
-
523
- filler = ParallelCacheFiller(
524
- labels=labels,
525
- lf_idx_list=self.lf_idx_list,
526
- cache_type=cache_type,
527
- cache_path=cache_path,
528
- num_workers=num_workers,
529
- )
530
-
531
- if use_progress:
532
- with Progress(
533
- SpinnerColumn(),
534
- TextColumn("[progress.description]{task.description}"),
535
- BarColumn(),
536
- TextColumn("{task.completed}/{task.total}"),
537
- TimeElapsedColumn(),
538
- console=Console(force_terminal=True),
539
- transient=True,
540
- ) as progress:
541
- task = progress.add_task(
542
- f"Caching images to {cache_type} (parallel, {num_workers} workers)",
543
- total=total_samples,
544
- )
545
-
546
- def progress_callback(completed):
547
- progress.update(task, completed=completed)
548
-
549
- cache, errors = filler.fill_cache(progress_callback)
550
- else:
551
- logger.info(
552
- f"Caching {total_samples} images to {cache_type} "
553
- f"(parallel, {num_workers} workers)..."
554
- )
555
- cache, errors = filler.fill_cache()
556
-
557
- # Update instance cache
558
- if cache_type == "memory":
559
- self.cache.update(cache)
560
-
561
- # Log any errors
562
- if errors:
563
- logger.warning(
564
- f"Parallel caching completed with {len(errors)} errors. "
565
- f"First error: {errors[0]}"
566
- )
567
-
568
272
  def __len__(self) -> int:
569
273
  """Return the number of samples in the dataset."""
570
274
  return len(self.lf_idx_list)
@@ -639,8 +343,6 @@ class BottomUpDataset(BaseDataset):
639
343
  cache_img_path: Optional[str] = None,
640
344
  use_existing_imgs: bool = False,
641
345
  rank: Optional[int] = None,
642
- parallel_caching: bool = True,
643
- cache_workers: int = 0,
644
346
  ) -> None:
645
347
  """Initialize class attributes."""
646
348
  super().__init__(
@@ -658,8 +360,6 @@ class BottomUpDataset(BaseDataset):
658
360
  cache_img_path=cache_img_path,
659
361
  use_existing_imgs=use_existing_imgs,
660
362
  rank=rank,
661
- parallel_caching=parallel_caching,
662
- cache_workers=cache_workers,
663
363
  )
664
364
  self.confmap_head_config = confmap_head_config
665
365
  self.pafs_head_config = pafs_head_config
@@ -702,6 +402,9 @@ class BottomUpDataset(BaseDataset):
702
402
  user_instances_only=self.user_instances_only,
703
403
  )
704
404
 
405
+ # apply normalization
406
+ sample["image"] = apply_normalization(sample["image"])
407
+
705
408
  if self.ensure_rgb:
706
409
  sample["image"] = convert_to_rgb(sample["image"])
707
410
  elif self.ensure_grayscale:
@@ -839,8 +542,6 @@ class BottomUpMultiClassDataset(BaseDataset):
839
542
  cache_img_path: Optional[str] = None,
840
543
  use_existing_imgs: bool = False,
841
544
  rank: Optional[int] = None,
842
- parallel_caching: bool = True,
843
- cache_workers: int = 0,
844
545
  ) -> None:
845
546
  """Initialize class attributes."""
846
547
  super().__init__(
@@ -858,8 +559,6 @@ class BottomUpMultiClassDataset(BaseDataset):
858
559
  cache_img_path=cache_img_path,
859
560
  use_existing_imgs=use_existing_imgs,
860
561
  rank=rank,
861
- parallel_caching=parallel_caching,
862
- cache_workers=cache_workers,
863
562
  )
864
563
  self.confmap_head_config = confmap_head_config
865
564
  self.class_maps_head_config = class_maps_head_config
@@ -916,6 +615,9 @@ class BottomUpMultiClassDataset(BaseDataset):
916
615
 
917
616
  sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
918
617
 
618
+ # apply normalization
619
+ sample["image"] = apply_normalization(sample["image"])
620
+
919
621
  if self.ensure_rgb:
920
622
  sample["image"] = convert_to_rgb(sample["image"])
921
623
  elif self.ensure_grayscale:
@@ -1054,8 +756,6 @@ class CenteredInstanceDataset(BaseDataset):
1054
756
  cache_img_path: Optional[str] = None,
1055
757
  use_existing_imgs: bool = False,
1056
758
  rank: Optional[int] = None,
1057
- parallel_caching: bool = True,
1058
- cache_workers: int = 0,
1059
759
  ) -> None:
1060
760
  """Initialize class attributes."""
1061
761
  super().__init__(
@@ -1073,8 +773,6 @@ class CenteredInstanceDataset(BaseDataset):
1073
773
  cache_img_path=cache_img_path,
1074
774
  use_existing_imgs=use_existing_imgs,
1075
775
  rank=rank,
1076
- parallel_caching=parallel_caching,
1077
- cache_workers=cache_workers,
1078
776
  )
1079
777
  self.labels = None
1080
778
  self.crop_size = crop_size
@@ -1165,6 +863,9 @@ class CenteredInstanceDataset(BaseDataset):
1165
863
 
1166
864
  instances = instances[:, inst_idx]
1167
865
 
866
+ # apply normalization
867
+ image = apply_normalization(image)
868
+
1168
869
  if self.ensure_rgb:
1169
870
  image = convert_to_rgb(image)
1170
871
  elif self.ensure_grayscale:
@@ -1333,8 +1034,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1333
1034
  cache_img_path: Optional[str] = None,
1334
1035
  use_existing_imgs: bool = False,
1335
1036
  rank: Optional[int] = None,
1336
- parallel_caching: bool = True,
1337
- cache_workers: int = 0,
1338
1037
  ) -> None:
1339
1038
  """Initialize class attributes."""
1340
1039
  super().__init__(
@@ -1355,8 +1054,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1355
1054
  cache_img_path=cache_img_path,
1356
1055
  use_existing_imgs=use_existing_imgs,
1357
1056
  rank=rank,
1358
- parallel_caching=parallel_caching,
1359
- cache_workers=cache_workers,
1360
1057
  )
1361
1058
  self.class_vectors_head_config = class_vectors_head_config
1362
1059
  self.class_names = self.class_vectors_head_config.classes
@@ -1411,6 +1108,9 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
1411
1108
 
1412
1109
  instances = instances[:, inst_idx]
1413
1110
 
1111
+ # apply normalization
1112
+ image = apply_normalization(image)
1113
+
1414
1114
  if self.ensure_rgb:
1415
1115
  image = convert_to_rgb(image)
1416
1116
  elif self.ensure_grayscale:
@@ -1592,8 +1292,6 @@ class CentroidDataset(BaseDataset):
1592
1292
  cache_img_path: Optional[str] = None,
1593
1293
  use_existing_imgs: bool = False,
1594
1294
  rank: Optional[int] = None,
1595
- parallel_caching: bool = True,
1596
- cache_workers: int = 0,
1597
1295
  ) -> None:
1598
1296
  """Initialize class attributes."""
1599
1297
  super().__init__(
@@ -1611,8 +1309,6 @@ class CentroidDataset(BaseDataset):
1611
1309
  cache_img_path=cache_img_path,
1612
1310
  use_existing_imgs=use_existing_imgs,
1613
1311
  rank=rank,
1614
- parallel_caching=parallel_caching,
1615
- cache_workers=cache_workers,
1616
1312
  )
1617
1313
  self.anchor_ind = anchor_ind
1618
1314
  self.confmap_head_config = confmap_head_config
@@ -1652,6 +1348,9 @@ class CentroidDataset(BaseDataset):
1652
1348
  user_instances_only=self.user_instances_only,
1653
1349
  )
1654
1350
 
1351
+ # apply normalization
1352
+ sample["image"] = apply_normalization(sample["image"])
1353
+
1655
1354
  if self.ensure_rgb:
1656
1355
  sample["image"] = convert_to_rgb(sample["image"])
1657
1356
  elif self.ensure_grayscale:
@@ -1776,8 +1475,6 @@ class SingleInstanceDataset(BaseDataset):
1776
1475
  cache_img_path: Optional[str] = None,
1777
1476
  use_existing_imgs: bool = False,
1778
1477
  rank: Optional[int] = None,
1779
- parallel_caching: bool = True,
1780
- cache_workers: int = 0,
1781
1478
  ) -> None:
1782
1479
  """Initialize class attributes."""
1783
1480
  super().__init__(
@@ -1795,8 +1492,6 @@ class SingleInstanceDataset(BaseDataset):
1795
1492
  cache_img_path=cache_img_path,
1796
1493
  use_existing_imgs=use_existing_imgs,
1797
1494
  rank=rank,
1798
- parallel_caching=parallel_caching,
1799
- cache_workers=cache_workers,
1800
1495
  )
1801
1496
  self.confmap_head_config = confmap_head_config
1802
1497
 
@@ -1835,6 +1530,9 @@ class SingleInstanceDataset(BaseDataset):
1835
1530
  user_instances_only=self.user_instances_only,
1836
1531
  )
1837
1532
 
1533
+ # apply normalization
1534
+ sample["image"] = apply_normalization(sample["image"])
1535
+
1838
1536
  if self.ensure_rgb:
1839
1537
  sample["image"] = convert_to_rgb(sample["image"])
1840
1538
  elif self.ensure_grayscale:
@@ -2015,10 +1713,6 @@ def get_train_val_datasets(
2015
1713
  val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
2016
1714
  use_existing_imgs = config.data_config.use_existing_imgs
2017
1715
 
2018
- # Parallel caching configuration
2019
- parallel_caching = getattr(config.data_config, "parallel_caching", True)
2020
- cache_workers = getattr(config.data_config, "cache_workers", 0)
2021
-
2022
1716
  model_type = get_model_type_from_cfg(config=config)
2023
1717
  backbone_type = get_backbone_type_from_cfg(config=config)
2024
1718
 
@@ -2072,8 +1766,6 @@ def get_train_val_datasets(
2072
1766
  cache_img_path=train_cache_img_path,
2073
1767
  use_existing_imgs=use_existing_imgs,
2074
1768
  rank=rank,
2075
- parallel_caching=parallel_caching,
2076
- cache_workers=cache_workers,
2077
1769
  )
2078
1770
  val_dataset = BottomUpDataset(
2079
1771
  labels=val_labels,
@@ -2097,8 +1789,6 @@ def get_train_val_datasets(
2097
1789
  cache_img_path=val_cache_img_path,
2098
1790
  use_existing_imgs=use_existing_imgs,
2099
1791
  rank=rank,
2100
- parallel_caching=parallel_caching,
2101
- cache_workers=cache_workers,
2102
1792
  )
2103
1793
 
2104
1794
  elif model_type == "multi_class_bottomup":
@@ -2132,8 +1822,6 @@ def get_train_val_datasets(
2132
1822
  cache_img_path=train_cache_img_path,
2133
1823
  use_existing_imgs=use_existing_imgs,
2134
1824
  rank=rank,
2135
- parallel_caching=parallel_caching,
2136
- cache_workers=cache_workers,
2137
1825
  )
2138
1826
  val_dataset = BottomUpMultiClassDataset(
2139
1827
  labels=val_labels,
@@ -2157,8 +1845,6 @@ def get_train_val_datasets(
2157
1845
  cache_img_path=val_cache_img_path,
2158
1846
  use_existing_imgs=use_existing_imgs,
2159
1847
  rank=rank,
2160
- parallel_caching=parallel_caching,
2161
- cache_workers=cache_workers,
2162
1848
  )
2163
1849
 
2164
1850
  elif model_type == "centered_instance":
@@ -2198,8 +1884,6 @@ def get_train_val_datasets(
2198
1884
  cache_img_path=train_cache_img_path,
2199
1885
  use_existing_imgs=use_existing_imgs,
2200
1886
  rank=rank,
2201
- parallel_caching=parallel_caching,
2202
- cache_workers=cache_workers,
2203
1887
  )
2204
1888
  val_dataset = CenteredInstanceDataset(
2205
1889
  labels=val_labels,
@@ -2224,8 +1908,6 @@ def get_train_val_datasets(
2224
1908
  cache_img_path=val_cache_img_path,
2225
1909
  use_existing_imgs=use_existing_imgs,
2226
1910
  rank=rank,
2227
- parallel_caching=parallel_caching,
2228
- cache_workers=cache_workers,
2229
1911
  )
2230
1912
 
2231
1913
  elif model_type == "multi_class_topdown":
@@ -2266,8 +1948,6 @@ def get_train_val_datasets(
2266
1948
  cache_img_path=train_cache_img_path,
2267
1949
  use_existing_imgs=use_existing_imgs,
2268
1950
  rank=rank,
2269
- parallel_caching=parallel_caching,
2270
- cache_workers=cache_workers,
2271
1951
  )
2272
1952
  val_dataset = TopDownCenteredInstanceMultiClassDataset(
2273
1953
  labels=val_labels,
@@ -2293,8 +1973,6 @@ def get_train_val_datasets(
2293
1973
  cache_img_path=val_cache_img_path,
2294
1974
  use_existing_imgs=use_existing_imgs,
2295
1975
  rank=rank,
2296
- parallel_caching=parallel_caching,
2297
- cache_workers=cache_workers,
2298
1976
  )
2299
1977
 
2300
1978
  elif model_type == "centroid":
@@ -2331,8 +2009,6 @@ def get_train_val_datasets(
2331
2009
  cache_img_path=train_cache_img_path,
2332
2010
  use_existing_imgs=use_existing_imgs,
2333
2011
  rank=rank,
2334
- parallel_caching=parallel_caching,
2335
- cache_workers=cache_workers,
2336
2012
  )
2337
2013
  val_dataset = CentroidDataset(
2338
2014
  labels=val_labels,
@@ -2356,8 +2032,6 @@ def get_train_val_datasets(
2356
2032
  cache_img_path=val_cache_img_path,
2357
2033
  use_existing_imgs=use_existing_imgs,
2358
2034
  rank=rank,
2359
- parallel_caching=parallel_caching,
2360
- cache_workers=cache_workers,
2361
2035
  )
2362
2036
 
2363
2037
  else:
@@ -2390,8 +2064,6 @@ def get_train_val_datasets(
2390
2064
  cache_img_path=train_cache_img_path,
2391
2065
  use_existing_imgs=use_existing_imgs,
2392
2066
  rank=rank,
2393
- parallel_caching=parallel_caching,
2394
- cache_workers=cache_workers,
2395
2067
  )
2396
2068
  val_dataset = SingleInstanceDataset(
2397
2069
  labels=val_labels,
@@ -2414,8 +2086,6 @@ def get_train_val_datasets(
2414
2086
  cache_img_path=val_cache_img_path,
2415
2087
  use_existing_imgs=use_existing_imgs,
2416
2088
  rank=rank,
2417
- parallel_caching=parallel_caching,
2418
- cache_workers=cache_workers,
2419
2089
  )
2420
2090
 
2421
2091
  # If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.
@@ -5,7 +5,7 @@ import math
5
5
  import numpy as np
6
6
  import sleap_io as sio
7
7
  import torch
8
- from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
8
+ from kornia.geometry.transform import crop_and_resize
9
9
 
10
10
 
11
11
  def compute_augmentation_padding(
sleap_nn/data/resizing.py CHANGED
@@ -63,7 +63,7 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
63
63
  image,
64
64
  (0, pad_width, 0, pad_height),
65
65
  mode="constant",
66
- )
66
+ ).to(torch.float32)
67
67
  return image
68
68
 
69
69
 
@@ -136,7 +136,7 @@ def apply_sizematcher(
136
136
  image,
137
137
  (0, pad_width, 0, pad_height),
138
138
  mode="constant",
139
- )
139
+ ).to(torch.float32)
140
140
 
141
141
  return image, eff_scale_ratio
142
142
  else: