sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/data/custom_datasets.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
"""Custom `torch.utils.data.Dataset`s for different model types."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from kornia.geometry.transform import crop_and_resize
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
import
|
|
7
|
-
|
|
8
|
-
from copy import deepcopy
|
|
5
|
+
# from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
|
|
6
|
+
# import concurrent.futures
|
|
7
|
+
# import os
|
|
9
8
|
from itertools import cycle
|
|
10
9
|
from pathlib import Path
|
|
11
10
|
import torch.distributed as dist
|
|
@@ -31,6 +30,7 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
|
|
|
31
30
|
from sleap_nn.data.instance_centroids import generate_centroids
|
|
32
31
|
from sleap_nn.data.instance_cropping import generate_crops
|
|
33
32
|
from sleap_nn.data.normalization import (
|
|
33
|
+
apply_normalization,
|
|
34
34
|
convert_to_grayscale,
|
|
35
35
|
convert_to_rgb,
|
|
36
36
|
)
|
|
@@ -46,182 +46,6 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
|
|
|
46
46
|
from sleap_nn.training.utils import is_distributed_initialized
|
|
47
47
|
from sleap_nn.config.get_config import get_aug_config
|
|
48
48
|
|
|
49
|
-
# Minimum number of samples to use parallel caching (overhead not worth it for smaller)
|
|
50
|
-
MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class ParallelCacheFiller:
|
|
54
|
-
"""Parallel implementation of image caching using thread-local video copies.
|
|
55
|
-
|
|
56
|
-
This class uses ThreadPoolExecutor to parallelize I/O-bound operations when
|
|
57
|
-
caching images to disk or memory. Each worker thread gets its own copy of
|
|
58
|
-
video objects to ensure thread safety.
|
|
59
|
-
|
|
60
|
-
Attributes:
|
|
61
|
-
labels: List of sio.Labels objects containing the data.
|
|
62
|
-
lf_idx_list: List of dictionaries with labeled frame indices.
|
|
63
|
-
cache_type: Either "disk" or "memory".
|
|
64
|
-
cache_path: Path to save cached images (for disk caching).
|
|
65
|
-
num_workers: Number of worker threads.
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
def __init__(
|
|
69
|
-
self,
|
|
70
|
-
labels: List[sio.Labels],
|
|
71
|
-
lf_idx_list: List[Dict],
|
|
72
|
-
cache_type: str,
|
|
73
|
-
cache_path: Optional[Path] = None,
|
|
74
|
-
num_workers: int = 4,
|
|
75
|
-
):
|
|
76
|
-
"""Initialize the parallel cache filler.
|
|
77
|
-
|
|
78
|
-
Args:
|
|
79
|
-
labels: List of sio.Labels objects.
|
|
80
|
-
lf_idx_list: List of sample dictionaries with frame indices.
|
|
81
|
-
cache_type: Either "disk" or "memory".
|
|
82
|
-
cache_path: Path for disk caching.
|
|
83
|
-
num_workers: Number of worker threads.
|
|
84
|
-
"""
|
|
85
|
-
self.labels = labels
|
|
86
|
-
self.lf_idx_list = lf_idx_list
|
|
87
|
-
self.cache_type = cache_type
|
|
88
|
-
self.cache_path = cache_path
|
|
89
|
-
self.num_workers = num_workers
|
|
90
|
-
|
|
91
|
-
self.cache: Dict = {}
|
|
92
|
-
self._cache_lock = threading.Lock()
|
|
93
|
-
self._local = threading.local()
|
|
94
|
-
self._video_info: Dict = {}
|
|
95
|
-
|
|
96
|
-
# Prepare video copies for thread-local access
|
|
97
|
-
self._prepare_video_copies()
|
|
98
|
-
|
|
99
|
-
def _prepare_video_copies(self):
|
|
100
|
-
"""Close original videos and prepare for thread-local copies."""
|
|
101
|
-
for label in self.labels:
|
|
102
|
-
for video in label.videos:
|
|
103
|
-
vid_id = id(video)
|
|
104
|
-
if vid_id not in self._video_info:
|
|
105
|
-
# Store original state
|
|
106
|
-
original_open_backend = video.open_backend
|
|
107
|
-
|
|
108
|
-
# Close the video backend
|
|
109
|
-
video.close()
|
|
110
|
-
video.open_backend = False
|
|
111
|
-
|
|
112
|
-
self._video_info[vid_id] = {
|
|
113
|
-
"video": video,
|
|
114
|
-
"original_open_backend": original_open_backend,
|
|
115
|
-
}
|
|
116
|
-
|
|
117
|
-
def _get_thread_local_video(self, video: sio.Video) -> sio.Video:
|
|
118
|
-
"""Get or create a thread-local video copy.
|
|
119
|
-
|
|
120
|
-
Args:
|
|
121
|
-
video: The original video object.
|
|
122
|
-
|
|
123
|
-
Returns:
|
|
124
|
-
A thread-local copy of the video that is safe to use.
|
|
125
|
-
"""
|
|
126
|
-
vid_id = id(video)
|
|
127
|
-
|
|
128
|
-
if not hasattr(self._local, "videos"):
|
|
129
|
-
self._local.videos = {}
|
|
130
|
-
|
|
131
|
-
if vid_id not in self._local.videos:
|
|
132
|
-
# Create a thread-local copy
|
|
133
|
-
video_copy = deepcopy(video)
|
|
134
|
-
video_copy.open_backend = True
|
|
135
|
-
self._local.videos[vid_id] = video_copy
|
|
136
|
-
|
|
137
|
-
return self._local.videos[vid_id]
|
|
138
|
-
|
|
139
|
-
def _process_sample(
|
|
140
|
-
self, sample: Dict
|
|
141
|
-
) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
|
|
142
|
-
"""Process a single sample (read image, optionally save/cache).
|
|
143
|
-
|
|
144
|
-
Args:
|
|
145
|
-
sample: Dictionary with labels_idx, lf_idx, etc.
|
|
146
|
-
|
|
147
|
-
Returns:
|
|
148
|
-
Tuple of (labels_idx, lf_idx, image_or_none, error_or_none).
|
|
149
|
-
"""
|
|
150
|
-
labels_idx = sample["labels_idx"]
|
|
151
|
-
lf_idx = sample["lf_idx"]
|
|
152
|
-
|
|
153
|
-
try:
|
|
154
|
-
# Get the labeled frame
|
|
155
|
-
lf = self.labels[labels_idx][lf_idx]
|
|
156
|
-
|
|
157
|
-
# Get thread-local video
|
|
158
|
-
video = self._get_thread_local_video(lf.video)
|
|
159
|
-
|
|
160
|
-
# Read the image
|
|
161
|
-
img = video[lf.frame_idx]
|
|
162
|
-
|
|
163
|
-
if img.shape[-1] == 1:
|
|
164
|
-
img = np.squeeze(img)
|
|
165
|
-
|
|
166
|
-
if self.cache_type == "disk":
|
|
167
|
-
f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
|
|
168
|
-
Image.fromarray(img).save(str(f_name), format="JPEG")
|
|
169
|
-
return labels_idx, lf_idx, None, None
|
|
170
|
-
elif self.cache_type == "memory":
|
|
171
|
-
return labels_idx, lf_idx, img, None
|
|
172
|
-
|
|
173
|
-
except Exception as e:
|
|
174
|
-
return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"
|
|
175
|
-
|
|
176
|
-
def fill_cache(
|
|
177
|
-
self, progress_callback=None
|
|
178
|
-
) -> Tuple[Dict, List[Tuple[int, int, str]]]:
|
|
179
|
-
"""Fill the cache in parallel.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
progress_callback: Optional callback(completed_count) for progress updates.
|
|
183
|
-
|
|
184
|
-
Returns:
|
|
185
|
-
Tuple of (cache_dict, list_of_errors).
|
|
186
|
-
"""
|
|
187
|
-
errors = []
|
|
188
|
-
completed = 0
|
|
189
|
-
|
|
190
|
-
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
|
|
191
|
-
futures = {
|
|
192
|
-
executor.submit(self._process_sample, sample): sample
|
|
193
|
-
for sample in self.lf_idx_list
|
|
194
|
-
}
|
|
195
|
-
|
|
196
|
-
for future in as_completed(futures):
|
|
197
|
-
labels_idx, lf_idx, img, error = future.result()
|
|
198
|
-
|
|
199
|
-
if error:
|
|
200
|
-
errors.append((labels_idx, lf_idx, error))
|
|
201
|
-
elif self.cache_type == "memory" and img is not None:
|
|
202
|
-
with self._cache_lock:
|
|
203
|
-
self.cache[(labels_idx, lf_idx)] = img
|
|
204
|
-
|
|
205
|
-
completed += 1
|
|
206
|
-
if progress_callback:
|
|
207
|
-
progress_callback(completed)
|
|
208
|
-
|
|
209
|
-
# Restore original video states
|
|
210
|
-
self._restore_videos()
|
|
211
|
-
|
|
212
|
-
return self.cache, errors
|
|
213
|
-
|
|
214
|
-
def _restore_videos(self):
|
|
215
|
-
"""Restore original video states after caching is complete."""
|
|
216
|
-
for vid_info in self._video_info.values():
|
|
217
|
-
video = vid_info["video"]
|
|
218
|
-
video.open_backend = vid_info["original_open_backend"]
|
|
219
|
-
if video.open_backend:
|
|
220
|
-
try:
|
|
221
|
-
video.open()
|
|
222
|
-
except Exception:
|
|
223
|
-
pass
|
|
224
|
-
|
|
225
49
|
|
|
226
50
|
class BaseDataset(Dataset):
|
|
227
51
|
"""Base class for custom torch Datasets.
|
|
@@ -260,8 +84,6 @@ class BaseDataset(Dataset):
|
|
|
260
84
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
261
85
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
262
86
|
disk occurs only once across all workers.
|
|
263
|
-
parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
|
|
264
|
-
cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
|
|
265
87
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
266
88
|
"""
|
|
267
89
|
|
|
@@ -281,8 +103,6 @@ class BaseDataset(Dataset):
|
|
|
281
103
|
cache_img_path: Optional[str] = None,
|
|
282
104
|
use_existing_imgs: bool = False,
|
|
283
105
|
rank: Optional[int] = None,
|
|
284
|
-
parallel_caching: bool = True,
|
|
285
|
-
cache_workers: int = 0,
|
|
286
106
|
) -> None:
|
|
287
107
|
"""Initialize class attributes."""
|
|
288
108
|
super().__init__()
|
|
@@ -323,8 +143,6 @@ class BaseDataset(Dataset):
|
|
|
323
143
|
self.cache_img = cache_img
|
|
324
144
|
self.cache_img_path = cache_img_path
|
|
325
145
|
self.use_existing_imgs = use_existing_imgs
|
|
326
|
-
self.parallel_caching = parallel_caching
|
|
327
|
-
self.cache_workers = cache_workers
|
|
328
146
|
if self.cache_img is not None and "disk" in self.cache_img:
|
|
329
147
|
if self.cache_img_path is None:
|
|
330
148
|
self.cache_img_path = "."
|
|
@@ -350,18 +168,10 @@ class BaseDataset(Dataset):
|
|
|
350
168
|
|
|
351
169
|
if self.cache_img is not None:
|
|
352
170
|
if self.cache_img == "memory":
|
|
353
|
-
self._fill_cache(
|
|
354
|
-
labels,
|
|
355
|
-
parallel=self.parallel_caching,
|
|
356
|
-
num_workers=self.cache_workers,
|
|
357
|
-
)
|
|
171
|
+
self._fill_cache(labels)
|
|
358
172
|
elif self.cache_img == "disk" and not self.use_existing_imgs:
|
|
359
173
|
if self.rank is None or self.rank == -1 or self.rank == 0:
|
|
360
|
-
self._fill_cache(
|
|
361
|
-
labels,
|
|
362
|
-
parallel=self.parallel_caching,
|
|
363
|
-
num_workers=self.cache_workers,
|
|
364
|
-
)
|
|
174
|
+
self._fill_cache(labels)
|
|
365
175
|
# Synchronize all ranks after cache creation
|
|
366
176
|
if is_distributed_initialized():
|
|
367
177
|
dist.barrier()
|
|
@@ -410,60 +220,21 @@ class BaseDataset(Dataset):
|
|
|
410
220
|
"""Returns an iterator."""
|
|
411
221
|
return self
|
|
412
222
|
|
|
413
|
-
def _fill_cache(
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
"""Load all samples to cache.
|
|
420
|
-
|
|
421
|
-
Args:
|
|
422
|
-
labels: List of sio.Labels objects containing the data.
|
|
423
|
-
parallel: If True, use parallel processing for caching (faster for large
|
|
424
|
-
datasets). Default: True.
|
|
425
|
-
num_workers: Number of worker threads for parallel caching. If 0, uses
|
|
426
|
-
min(4, cpu_count). Default: 0.
|
|
427
|
-
"""
|
|
223
|
+
def _fill_cache(self, labels: List[sio.Labels]):
|
|
224
|
+
"""Load all samples to cache."""
|
|
225
|
+
# TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
|
|
226
|
+
import os
|
|
227
|
+
import sys
|
|
228
|
+
|
|
428
229
|
total_samples = len(self.lf_idx_list)
|
|
429
230
|
cache_type = "disk" if self.cache_img == "disk" else "memory"
|
|
430
231
|
|
|
431
|
-
# Check for NO_COLOR env var
|
|
232
|
+
# Check for NO_COLOR env var or non-interactive terminal
|
|
432
233
|
no_color = (
|
|
433
234
|
os.environ.get("NO_COLOR") is not None
|
|
434
235
|
or os.environ.get("FORCE_COLOR") == "0"
|
|
435
236
|
)
|
|
436
|
-
use_progress = not no_color
|
|
437
|
-
|
|
438
|
-
# Use parallel caching for larger datasets
|
|
439
|
-
use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
|
|
440
|
-
|
|
441
|
-
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
442
|
-
|
|
443
|
-
if use_parallel:
|
|
444
|
-
self._fill_cache_parallel(
|
|
445
|
-
labels, total_samples, cache_type, use_progress, num_workers
|
|
446
|
-
)
|
|
447
|
-
else:
|
|
448
|
-
self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
|
|
449
|
-
|
|
450
|
-
logger.info(f"Caching complete.")
|
|
451
|
-
|
|
452
|
-
def _fill_cache_sequential(
|
|
453
|
-
self,
|
|
454
|
-
labels: List[sio.Labels],
|
|
455
|
-
total_samples: int,
|
|
456
|
-
cache_type: str,
|
|
457
|
-
use_progress: bool,
|
|
458
|
-
):
|
|
459
|
-
"""Sequential implementation of cache filling.
|
|
460
|
-
|
|
461
|
-
Args:
|
|
462
|
-
labels: List of sio.Labels objects.
|
|
463
|
-
total_samples: Total number of samples to cache.
|
|
464
|
-
cache_type: Either "disk" or "memory".
|
|
465
|
-
use_progress: Whether to show a progress bar.
|
|
466
|
-
"""
|
|
237
|
+
use_progress = sys.stdout.isatty() and not no_color
|
|
467
238
|
|
|
468
239
|
def process_samples(progress=None, task=None):
|
|
469
240
|
for sample in self.lf_idx_list:
|
|
@@ -495,76 +266,9 @@ class BaseDataset(Dataset):
|
|
|
495
266
|
)
|
|
496
267
|
process_samples(progress, task)
|
|
497
268
|
else:
|
|
269
|
+
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
498
270
|
process_samples()
|
|
499
271
|
|
|
500
|
-
def _fill_cache_parallel(
|
|
501
|
-
self,
|
|
502
|
-
labels: List[sio.Labels],
|
|
503
|
-
total_samples: int,
|
|
504
|
-
cache_type: str,
|
|
505
|
-
use_progress: bool,
|
|
506
|
-
num_workers: int = 0,
|
|
507
|
-
):
|
|
508
|
-
"""Parallel implementation of cache filling using thread-local video copies.
|
|
509
|
-
|
|
510
|
-
Args:
|
|
511
|
-
labels: List of sio.Labels objects.
|
|
512
|
-
total_samples: Total number of samples to cache.
|
|
513
|
-
cache_type: Either "disk" or "memory".
|
|
514
|
-
use_progress: Whether to show a progress bar.
|
|
515
|
-
num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
|
|
516
|
-
"""
|
|
517
|
-
# Determine number of workers
|
|
518
|
-
if num_workers <= 0:
|
|
519
|
-
num_workers = min(4, os.cpu_count() or 1)
|
|
520
|
-
|
|
521
|
-
cache_path = Path(self.cache_img_path) if self.cache_img_path else None
|
|
522
|
-
|
|
523
|
-
filler = ParallelCacheFiller(
|
|
524
|
-
labels=labels,
|
|
525
|
-
lf_idx_list=self.lf_idx_list,
|
|
526
|
-
cache_type=cache_type,
|
|
527
|
-
cache_path=cache_path,
|
|
528
|
-
num_workers=num_workers,
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
if use_progress:
|
|
532
|
-
with Progress(
|
|
533
|
-
SpinnerColumn(),
|
|
534
|
-
TextColumn("[progress.description]{task.description}"),
|
|
535
|
-
BarColumn(),
|
|
536
|
-
TextColumn("{task.completed}/{task.total}"),
|
|
537
|
-
TimeElapsedColumn(),
|
|
538
|
-
console=Console(force_terminal=True),
|
|
539
|
-
transient=True,
|
|
540
|
-
) as progress:
|
|
541
|
-
task = progress.add_task(
|
|
542
|
-
f"Caching images to {cache_type} (parallel, {num_workers} workers)",
|
|
543
|
-
total=total_samples,
|
|
544
|
-
)
|
|
545
|
-
|
|
546
|
-
def progress_callback(completed):
|
|
547
|
-
progress.update(task, completed=completed)
|
|
548
|
-
|
|
549
|
-
cache, errors = filler.fill_cache(progress_callback)
|
|
550
|
-
else:
|
|
551
|
-
logger.info(
|
|
552
|
-
f"Caching {total_samples} images to {cache_type} "
|
|
553
|
-
f"(parallel, {num_workers} workers)..."
|
|
554
|
-
)
|
|
555
|
-
cache, errors = filler.fill_cache()
|
|
556
|
-
|
|
557
|
-
# Update instance cache
|
|
558
|
-
if cache_type == "memory":
|
|
559
|
-
self.cache.update(cache)
|
|
560
|
-
|
|
561
|
-
# Log any errors
|
|
562
|
-
if errors:
|
|
563
|
-
logger.warning(
|
|
564
|
-
f"Parallel caching completed with {len(errors)} errors. "
|
|
565
|
-
f"First error: {errors[0]}"
|
|
566
|
-
)
|
|
567
|
-
|
|
568
272
|
def __len__(self) -> int:
|
|
569
273
|
"""Return the number of samples in the dataset."""
|
|
570
274
|
return len(self.lf_idx_list)
|
|
@@ -639,8 +343,6 @@ class BottomUpDataset(BaseDataset):
|
|
|
639
343
|
cache_img_path: Optional[str] = None,
|
|
640
344
|
use_existing_imgs: bool = False,
|
|
641
345
|
rank: Optional[int] = None,
|
|
642
|
-
parallel_caching: bool = True,
|
|
643
|
-
cache_workers: int = 0,
|
|
644
346
|
) -> None:
|
|
645
347
|
"""Initialize class attributes."""
|
|
646
348
|
super().__init__(
|
|
@@ -658,8 +360,6 @@ class BottomUpDataset(BaseDataset):
|
|
|
658
360
|
cache_img_path=cache_img_path,
|
|
659
361
|
use_existing_imgs=use_existing_imgs,
|
|
660
362
|
rank=rank,
|
|
661
|
-
parallel_caching=parallel_caching,
|
|
662
|
-
cache_workers=cache_workers,
|
|
663
363
|
)
|
|
664
364
|
self.confmap_head_config = confmap_head_config
|
|
665
365
|
self.pafs_head_config = pafs_head_config
|
|
@@ -702,6 +402,9 @@ class BottomUpDataset(BaseDataset):
|
|
|
702
402
|
user_instances_only=self.user_instances_only,
|
|
703
403
|
)
|
|
704
404
|
|
|
405
|
+
# apply normalization
|
|
406
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
407
|
+
|
|
705
408
|
if self.ensure_rgb:
|
|
706
409
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
707
410
|
elif self.ensure_grayscale:
|
|
@@ -839,8 +542,6 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
839
542
|
cache_img_path: Optional[str] = None,
|
|
840
543
|
use_existing_imgs: bool = False,
|
|
841
544
|
rank: Optional[int] = None,
|
|
842
|
-
parallel_caching: bool = True,
|
|
843
|
-
cache_workers: int = 0,
|
|
844
545
|
) -> None:
|
|
845
546
|
"""Initialize class attributes."""
|
|
846
547
|
super().__init__(
|
|
@@ -858,8 +559,6 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
858
559
|
cache_img_path=cache_img_path,
|
|
859
560
|
use_existing_imgs=use_existing_imgs,
|
|
860
561
|
rank=rank,
|
|
861
|
-
parallel_caching=parallel_caching,
|
|
862
|
-
cache_workers=cache_workers,
|
|
863
562
|
)
|
|
864
563
|
self.confmap_head_config = confmap_head_config
|
|
865
564
|
self.class_maps_head_config = class_maps_head_config
|
|
@@ -916,6 +615,9 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
916
615
|
|
|
917
616
|
sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
|
|
918
617
|
|
|
618
|
+
# apply normalization
|
|
619
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
620
|
+
|
|
919
621
|
if self.ensure_rgb:
|
|
920
622
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
921
623
|
elif self.ensure_grayscale:
|
|
@@ -1054,8 +756,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1054
756
|
cache_img_path: Optional[str] = None,
|
|
1055
757
|
use_existing_imgs: bool = False,
|
|
1056
758
|
rank: Optional[int] = None,
|
|
1057
|
-
parallel_caching: bool = True,
|
|
1058
|
-
cache_workers: int = 0,
|
|
1059
759
|
) -> None:
|
|
1060
760
|
"""Initialize class attributes."""
|
|
1061
761
|
super().__init__(
|
|
@@ -1073,8 +773,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1073
773
|
cache_img_path=cache_img_path,
|
|
1074
774
|
use_existing_imgs=use_existing_imgs,
|
|
1075
775
|
rank=rank,
|
|
1076
|
-
parallel_caching=parallel_caching,
|
|
1077
|
-
cache_workers=cache_workers,
|
|
1078
776
|
)
|
|
1079
777
|
self.labels = None
|
|
1080
778
|
self.crop_size = crop_size
|
|
@@ -1165,6 +863,9 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1165
863
|
|
|
1166
864
|
instances = instances[:, inst_idx]
|
|
1167
865
|
|
|
866
|
+
# apply normalization
|
|
867
|
+
image = apply_normalization(image)
|
|
868
|
+
|
|
1168
869
|
if self.ensure_rgb:
|
|
1169
870
|
image = convert_to_rgb(image)
|
|
1170
871
|
elif self.ensure_grayscale:
|
|
@@ -1333,8 +1034,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1333
1034
|
cache_img_path: Optional[str] = None,
|
|
1334
1035
|
use_existing_imgs: bool = False,
|
|
1335
1036
|
rank: Optional[int] = None,
|
|
1336
|
-
parallel_caching: bool = True,
|
|
1337
|
-
cache_workers: int = 0,
|
|
1338
1037
|
) -> None:
|
|
1339
1038
|
"""Initialize class attributes."""
|
|
1340
1039
|
super().__init__(
|
|
@@ -1355,8 +1054,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1355
1054
|
cache_img_path=cache_img_path,
|
|
1356
1055
|
use_existing_imgs=use_existing_imgs,
|
|
1357
1056
|
rank=rank,
|
|
1358
|
-
parallel_caching=parallel_caching,
|
|
1359
|
-
cache_workers=cache_workers,
|
|
1360
1057
|
)
|
|
1361
1058
|
self.class_vectors_head_config = class_vectors_head_config
|
|
1362
1059
|
self.class_names = self.class_vectors_head_config.classes
|
|
@@ -1411,6 +1108,9 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1411
1108
|
|
|
1412
1109
|
instances = instances[:, inst_idx]
|
|
1413
1110
|
|
|
1111
|
+
# apply normalization
|
|
1112
|
+
image = apply_normalization(image)
|
|
1113
|
+
|
|
1414
1114
|
if self.ensure_rgb:
|
|
1415
1115
|
image = convert_to_rgb(image)
|
|
1416
1116
|
elif self.ensure_grayscale:
|
|
@@ -1592,8 +1292,6 @@ class CentroidDataset(BaseDataset):
|
|
|
1592
1292
|
cache_img_path: Optional[str] = None,
|
|
1593
1293
|
use_existing_imgs: bool = False,
|
|
1594
1294
|
rank: Optional[int] = None,
|
|
1595
|
-
parallel_caching: bool = True,
|
|
1596
|
-
cache_workers: int = 0,
|
|
1597
1295
|
) -> None:
|
|
1598
1296
|
"""Initialize class attributes."""
|
|
1599
1297
|
super().__init__(
|
|
@@ -1611,8 +1309,6 @@ class CentroidDataset(BaseDataset):
|
|
|
1611
1309
|
cache_img_path=cache_img_path,
|
|
1612
1310
|
use_existing_imgs=use_existing_imgs,
|
|
1613
1311
|
rank=rank,
|
|
1614
|
-
parallel_caching=parallel_caching,
|
|
1615
|
-
cache_workers=cache_workers,
|
|
1616
1312
|
)
|
|
1617
1313
|
self.anchor_ind = anchor_ind
|
|
1618
1314
|
self.confmap_head_config = confmap_head_config
|
|
@@ -1652,6 +1348,9 @@ class CentroidDataset(BaseDataset):
|
|
|
1652
1348
|
user_instances_only=self.user_instances_only,
|
|
1653
1349
|
)
|
|
1654
1350
|
|
|
1351
|
+
# apply normalization
|
|
1352
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
1353
|
+
|
|
1655
1354
|
if self.ensure_rgb:
|
|
1656
1355
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1657
1356
|
elif self.ensure_grayscale:
|
|
@@ -1776,8 +1475,6 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1776
1475
|
cache_img_path: Optional[str] = None,
|
|
1777
1476
|
use_existing_imgs: bool = False,
|
|
1778
1477
|
rank: Optional[int] = None,
|
|
1779
|
-
parallel_caching: bool = True,
|
|
1780
|
-
cache_workers: int = 0,
|
|
1781
1478
|
) -> None:
|
|
1782
1479
|
"""Initialize class attributes."""
|
|
1783
1480
|
super().__init__(
|
|
@@ -1795,8 +1492,6 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1795
1492
|
cache_img_path=cache_img_path,
|
|
1796
1493
|
use_existing_imgs=use_existing_imgs,
|
|
1797
1494
|
rank=rank,
|
|
1798
|
-
parallel_caching=parallel_caching,
|
|
1799
|
-
cache_workers=cache_workers,
|
|
1800
1495
|
)
|
|
1801
1496
|
self.confmap_head_config = confmap_head_config
|
|
1802
1497
|
|
|
@@ -1835,6 +1530,9 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1835
1530
|
user_instances_only=self.user_instances_only,
|
|
1836
1531
|
)
|
|
1837
1532
|
|
|
1533
|
+
# apply normalization
|
|
1534
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
1535
|
+
|
|
1838
1536
|
if self.ensure_rgb:
|
|
1839
1537
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1840
1538
|
elif self.ensure_grayscale:
|
|
@@ -2015,10 +1713,6 @@ def get_train_val_datasets(
|
|
|
2015
1713
|
val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
|
|
2016
1714
|
use_existing_imgs = config.data_config.use_existing_imgs
|
|
2017
1715
|
|
|
2018
|
-
# Parallel caching configuration
|
|
2019
|
-
parallel_caching = getattr(config.data_config, "parallel_caching", True)
|
|
2020
|
-
cache_workers = getattr(config.data_config, "cache_workers", 0)
|
|
2021
|
-
|
|
2022
1716
|
model_type = get_model_type_from_cfg(config=config)
|
|
2023
1717
|
backbone_type = get_backbone_type_from_cfg(config=config)
|
|
2024
1718
|
|
|
@@ -2072,8 +1766,6 @@ def get_train_val_datasets(
|
|
|
2072
1766
|
cache_img_path=train_cache_img_path,
|
|
2073
1767
|
use_existing_imgs=use_existing_imgs,
|
|
2074
1768
|
rank=rank,
|
|
2075
|
-
parallel_caching=parallel_caching,
|
|
2076
|
-
cache_workers=cache_workers,
|
|
2077
1769
|
)
|
|
2078
1770
|
val_dataset = BottomUpDataset(
|
|
2079
1771
|
labels=val_labels,
|
|
@@ -2097,8 +1789,6 @@ def get_train_val_datasets(
|
|
|
2097
1789
|
cache_img_path=val_cache_img_path,
|
|
2098
1790
|
use_existing_imgs=use_existing_imgs,
|
|
2099
1791
|
rank=rank,
|
|
2100
|
-
parallel_caching=parallel_caching,
|
|
2101
|
-
cache_workers=cache_workers,
|
|
2102
1792
|
)
|
|
2103
1793
|
|
|
2104
1794
|
elif model_type == "multi_class_bottomup":
|
|
@@ -2132,8 +1822,6 @@ def get_train_val_datasets(
|
|
|
2132
1822
|
cache_img_path=train_cache_img_path,
|
|
2133
1823
|
use_existing_imgs=use_existing_imgs,
|
|
2134
1824
|
rank=rank,
|
|
2135
|
-
parallel_caching=parallel_caching,
|
|
2136
|
-
cache_workers=cache_workers,
|
|
2137
1825
|
)
|
|
2138
1826
|
val_dataset = BottomUpMultiClassDataset(
|
|
2139
1827
|
labels=val_labels,
|
|
@@ -2157,8 +1845,6 @@ def get_train_val_datasets(
|
|
|
2157
1845
|
cache_img_path=val_cache_img_path,
|
|
2158
1846
|
use_existing_imgs=use_existing_imgs,
|
|
2159
1847
|
rank=rank,
|
|
2160
|
-
parallel_caching=parallel_caching,
|
|
2161
|
-
cache_workers=cache_workers,
|
|
2162
1848
|
)
|
|
2163
1849
|
|
|
2164
1850
|
elif model_type == "centered_instance":
|
|
@@ -2198,8 +1884,6 @@ def get_train_val_datasets(
|
|
|
2198
1884
|
cache_img_path=train_cache_img_path,
|
|
2199
1885
|
use_existing_imgs=use_existing_imgs,
|
|
2200
1886
|
rank=rank,
|
|
2201
|
-
parallel_caching=parallel_caching,
|
|
2202
|
-
cache_workers=cache_workers,
|
|
2203
1887
|
)
|
|
2204
1888
|
val_dataset = CenteredInstanceDataset(
|
|
2205
1889
|
labels=val_labels,
|
|
@@ -2224,8 +1908,6 @@ def get_train_val_datasets(
|
|
|
2224
1908
|
cache_img_path=val_cache_img_path,
|
|
2225
1909
|
use_existing_imgs=use_existing_imgs,
|
|
2226
1910
|
rank=rank,
|
|
2227
|
-
parallel_caching=parallel_caching,
|
|
2228
|
-
cache_workers=cache_workers,
|
|
2229
1911
|
)
|
|
2230
1912
|
|
|
2231
1913
|
elif model_type == "multi_class_topdown":
|
|
@@ -2266,8 +1948,6 @@ def get_train_val_datasets(
|
|
|
2266
1948
|
cache_img_path=train_cache_img_path,
|
|
2267
1949
|
use_existing_imgs=use_existing_imgs,
|
|
2268
1950
|
rank=rank,
|
|
2269
|
-
parallel_caching=parallel_caching,
|
|
2270
|
-
cache_workers=cache_workers,
|
|
2271
1951
|
)
|
|
2272
1952
|
val_dataset = TopDownCenteredInstanceMultiClassDataset(
|
|
2273
1953
|
labels=val_labels,
|
|
@@ -2293,8 +1973,6 @@ def get_train_val_datasets(
|
|
|
2293
1973
|
cache_img_path=val_cache_img_path,
|
|
2294
1974
|
use_existing_imgs=use_existing_imgs,
|
|
2295
1975
|
rank=rank,
|
|
2296
|
-
parallel_caching=parallel_caching,
|
|
2297
|
-
cache_workers=cache_workers,
|
|
2298
1976
|
)
|
|
2299
1977
|
|
|
2300
1978
|
elif model_type == "centroid":
|
|
@@ -2331,8 +2009,6 @@ def get_train_val_datasets(
|
|
|
2331
2009
|
cache_img_path=train_cache_img_path,
|
|
2332
2010
|
use_existing_imgs=use_existing_imgs,
|
|
2333
2011
|
rank=rank,
|
|
2334
|
-
parallel_caching=parallel_caching,
|
|
2335
|
-
cache_workers=cache_workers,
|
|
2336
2012
|
)
|
|
2337
2013
|
val_dataset = CentroidDataset(
|
|
2338
2014
|
labels=val_labels,
|
|
@@ -2356,8 +2032,6 @@ def get_train_val_datasets(
|
|
|
2356
2032
|
cache_img_path=val_cache_img_path,
|
|
2357
2033
|
use_existing_imgs=use_existing_imgs,
|
|
2358
2034
|
rank=rank,
|
|
2359
|
-
parallel_caching=parallel_caching,
|
|
2360
|
-
cache_workers=cache_workers,
|
|
2361
2035
|
)
|
|
2362
2036
|
|
|
2363
2037
|
else:
|
|
@@ -2390,8 +2064,6 @@ def get_train_val_datasets(
|
|
|
2390
2064
|
cache_img_path=train_cache_img_path,
|
|
2391
2065
|
use_existing_imgs=use_existing_imgs,
|
|
2392
2066
|
rank=rank,
|
|
2393
|
-
parallel_caching=parallel_caching,
|
|
2394
|
-
cache_workers=cache_workers,
|
|
2395
2067
|
)
|
|
2396
2068
|
val_dataset = SingleInstanceDataset(
|
|
2397
2069
|
labels=val_labels,
|
|
@@ -2414,8 +2086,6 @@ def get_train_val_datasets(
|
|
|
2414
2086
|
cache_img_path=val_cache_img_path,
|
|
2415
2087
|
use_existing_imgs=use_existing_imgs,
|
|
2416
2088
|
rank=rank,
|
|
2417
|
-
parallel_caching=parallel_caching,
|
|
2418
|
-
cache_workers=cache_workers,
|
|
2419
2089
|
)
|
|
2420
2090
|
|
|
2421
2091
|
# If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.
|
sleap_nn/data/resizing.py
CHANGED
|
@@ -63,7 +63,7 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
|
|
|
63
63
|
image,
|
|
64
64
|
(0, pad_width, 0, pad_height),
|
|
65
65
|
mode="constant",
|
|
66
|
-
)
|
|
66
|
+
).to(torch.float32)
|
|
67
67
|
return image
|
|
68
68
|
|
|
69
69
|
|
|
@@ -136,7 +136,7 @@ def apply_sizematcher(
|
|
|
136
136
|
image,
|
|
137
137
|
(0, pad_width, 0, pad_height),
|
|
138
138
|
mode="constant",
|
|
139
|
-
)
|
|
139
|
+
).to(torch.float32)
|
|
140
140
|
|
|
141
141
|
return image, eff_scale_ratio
|
|
142
142
|
else:
|