sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +2 -4
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -10
- sleap_nn/config/trainer_config.py +0 -76
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +39 -411
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -74
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +184 -211
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/data/custom_datasets.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
"""Custom `torch.utils.data.Dataset`s for different model types."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from kornia.geometry.transform import crop_and_resize
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
import
|
|
7
|
-
|
|
8
|
-
from copy import deepcopy
|
|
5
|
+
# from concurrent.futures import ThreadPoolExecutor # TODO: implement parallel processing
|
|
6
|
+
# import concurrent.futures
|
|
7
|
+
# import os
|
|
9
8
|
from itertools import cycle
|
|
10
9
|
from pathlib import Path
|
|
11
10
|
import torch.distributed as dist
|
|
@@ -14,14 +13,6 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
14
13
|
import numpy as np
|
|
15
14
|
from PIL import Image
|
|
16
15
|
from loguru import logger
|
|
17
|
-
from rich.progress import (
|
|
18
|
-
Progress,
|
|
19
|
-
SpinnerColumn,
|
|
20
|
-
TextColumn,
|
|
21
|
-
BarColumn,
|
|
22
|
-
TimeElapsedColumn,
|
|
23
|
-
)
|
|
24
|
-
from rich.console import Console
|
|
25
16
|
import torch
|
|
26
17
|
import torchvision.transforms as T
|
|
27
18
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
|
@@ -31,6 +22,7 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
|
|
|
31
22
|
from sleap_nn.data.instance_centroids import generate_centroids
|
|
32
23
|
from sleap_nn.data.instance_cropping import generate_crops
|
|
33
24
|
from sleap_nn.data.normalization import (
|
|
25
|
+
apply_normalization,
|
|
34
26
|
convert_to_grayscale,
|
|
35
27
|
convert_to_rgb,
|
|
36
28
|
)
|
|
@@ -46,182 +38,6 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
|
|
|
46
38
|
from sleap_nn.training.utils import is_distributed_initialized
|
|
47
39
|
from sleap_nn.config.get_config import get_aug_config
|
|
48
40
|
|
|
49
|
-
# Minimum number of samples to use parallel caching (overhead not worth it for smaller)
|
|
50
|
-
MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class ParallelCacheFiller:
|
|
54
|
-
"""Parallel implementation of image caching using thread-local video copies.
|
|
55
|
-
|
|
56
|
-
This class uses ThreadPoolExecutor to parallelize I/O-bound operations when
|
|
57
|
-
caching images to disk or memory. Each worker thread gets its own copy of
|
|
58
|
-
video objects to ensure thread safety.
|
|
59
|
-
|
|
60
|
-
Attributes:
|
|
61
|
-
labels: List of sio.Labels objects containing the data.
|
|
62
|
-
lf_idx_list: List of dictionaries with labeled frame indices.
|
|
63
|
-
cache_type: Either "disk" or "memory".
|
|
64
|
-
cache_path: Path to save cached images (for disk caching).
|
|
65
|
-
num_workers: Number of worker threads.
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
def __init__(
|
|
69
|
-
self,
|
|
70
|
-
labels: List[sio.Labels],
|
|
71
|
-
lf_idx_list: List[Dict],
|
|
72
|
-
cache_type: str,
|
|
73
|
-
cache_path: Optional[Path] = None,
|
|
74
|
-
num_workers: int = 4,
|
|
75
|
-
):
|
|
76
|
-
"""Initialize the parallel cache filler.
|
|
77
|
-
|
|
78
|
-
Args:
|
|
79
|
-
labels: List of sio.Labels objects.
|
|
80
|
-
lf_idx_list: List of sample dictionaries with frame indices.
|
|
81
|
-
cache_type: Either "disk" or "memory".
|
|
82
|
-
cache_path: Path for disk caching.
|
|
83
|
-
num_workers: Number of worker threads.
|
|
84
|
-
"""
|
|
85
|
-
self.labels = labels
|
|
86
|
-
self.lf_idx_list = lf_idx_list
|
|
87
|
-
self.cache_type = cache_type
|
|
88
|
-
self.cache_path = cache_path
|
|
89
|
-
self.num_workers = num_workers
|
|
90
|
-
|
|
91
|
-
self.cache: Dict = {}
|
|
92
|
-
self._cache_lock = threading.Lock()
|
|
93
|
-
self._local = threading.local()
|
|
94
|
-
self._video_info: Dict = {}
|
|
95
|
-
|
|
96
|
-
# Prepare video copies for thread-local access
|
|
97
|
-
self._prepare_video_copies()
|
|
98
|
-
|
|
99
|
-
def _prepare_video_copies(self):
|
|
100
|
-
"""Close original videos and prepare for thread-local copies."""
|
|
101
|
-
for label in self.labels:
|
|
102
|
-
for video in label.videos:
|
|
103
|
-
vid_id = id(video)
|
|
104
|
-
if vid_id not in self._video_info:
|
|
105
|
-
# Store original state
|
|
106
|
-
original_open_backend = video.open_backend
|
|
107
|
-
|
|
108
|
-
# Close the video backend
|
|
109
|
-
video.close()
|
|
110
|
-
video.open_backend = False
|
|
111
|
-
|
|
112
|
-
self._video_info[vid_id] = {
|
|
113
|
-
"video": video,
|
|
114
|
-
"original_open_backend": original_open_backend,
|
|
115
|
-
}
|
|
116
|
-
|
|
117
|
-
def _get_thread_local_video(self, video: sio.Video) -> sio.Video:
|
|
118
|
-
"""Get or create a thread-local video copy.
|
|
119
|
-
|
|
120
|
-
Args:
|
|
121
|
-
video: The original video object.
|
|
122
|
-
|
|
123
|
-
Returns:
|
|
124
|
-
A thread-local copy of the video that is safe to use.
|
|
125
|
-
"""
|
|
126
|
-
vid_id = id(video)
|
|
127
|
-
|
|
128
|
-
if not hasattr(self._local, "videos"):
|
|
129
|
-
self._local.videos = {}
|
|
130
|
-
|
|
131
|
-
if vid_id not in self._local.videos:
|
|
132
|
-
# Create a thread-local copy
|
|
133
|
-
video_copy = deepcopy(video)
|
|
134
|
-
video_copy.open_backend = True
|
|
135
|
-
self._local.videos[vid_id] = video_copy
|
|
136
|
-
|
|
137
|
-
return self._local.videos[vid_id]
|
|
138
|
-
|
|
139
|
-
def _process_sample(
|
|
140
|
-
self, sample: Dict
|
|
141
|
-
) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
|
|
142
|
-
"""Process a single sample (read image, optionally save/cache).
|
|
143
|
-
|
|
144
|
-
Args:
|
|
145
|
-
sample: Dictionary with labels_idx, lf_idx, etc.
|
|
146
|
-
|
|
147
|
-
Returns:
|
|
148
|
-
Tuple of (labels_idx, lf_idx, image_or_none, error_or_none).
|
|
149
|
-
"""
|
|
150
|
-
labels_idx = sample["labels_idx"]
|
|
151
|
-
lf_idx = sample["lf_idx"]
|
|
152
|
-
|
|
153
|
-
try:
|
|
154
|
-
# Get the labeled frame
|
|
155
|
-
lf = self.labels[labels_idx][lf_idx]
|
|
156
|
-
|
|
157
|
-
# Get thread-local video
|
|
158
|
-
video = self._get_thread_local_video(lf.video)
|
|
159
|
-
|
|
160
|
-
# Read the image
|
|
161
|
-
img = video[lf.frame_idx]
|
|
162
|
-
|
|
163
|
-
if img.shape[-1] == 1:
|
|
164
|
-
img = np.squeeze(img)
|
|
165
|
-
|
|
166
|
-
if self.cache_type == "disk":
|
|
167
|
-
f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
|
|
168
|
-
Image.fromarray(img).save(str(f_name), format="JPEG")
|
|
169
|
-
return labels_idx, lf_idx, None, None
|
|
170
|
-
elif self.cache_type == "memory":
|
|
171
|
-
return labels_idx, lf_idx, img, None
|
|
172
|
-
|
|
173
|
-
except Exception as e:
|
|
174
|
-
return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"
|
|
175
|
-
|
|
176
|
-
def fill_cache(
|
|
177
|
-
self, progress_callback=None
|
|
178
|
-
) -> Tuple[Dict, List[Tuple[int, int, str]]]:
|
|
179
|
-
"""Fill the cache in parallel.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
progress_callback: Optional callback(completed_count) for progress updates.
|
|
183
|
-
|
|
184
|
-
Returns:
|
|
185
|
-
Tuple of (cache_dict, list_of_errors).
|
|
186
|
-
"""
|
|
187
|
-
errors = []
|
|
188
|
-
completed = 0
|
|
189
|
-
|
|
190
|
-
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
|
|
191
|
-
futures = {
|
|
192
|
-
executor.submit(self._process_sample, sample): sample
|
|
193
|
-
for sample in self.lf_idx_list
|
|
194
|
-
}
|
|
195
|
-
|
|
196
|
-
for future in as_completed(futures):
|
|
197
|
-
labels_idx, lf_idx, img, error = future.result()
|
|
198
|
-
|
|
199
|
-
if error:
|
|
200
|
-
errors.append((labels_idx, lf_idx, error))
|
|
201
|
-
elif self.cache_type == "memory" and img is not None:
|
|
202
|
-
with self._cache_lock:
|
|
203
|
-
self.cache[(labels_idx, lf_idx)] = img
|
|
204
|
-
|
|
205
|
-
completed += 1
|
|
206
|
-
if progress_callback:
|
|
207
|
-
progress_callback(completed)
|
|
208
|
-
|
|
209
|
-
# Restore original video states
|
|
210
|
-
self._restore_videos()
|
|
211
|
-
|
|
212
|
-
return self.cache, errors
|
|
213
|
-
|
|
214
|
-
def _restore_videos(self):
|
|
215
|
-
"""Restore original video states after caching is complete."""
|
|
216
|
-
for vid_info in self._video_info.values():
|
|
217
|
-
video = vid_info["video"]
|
|
218
|
-
video.open_backend = vid_info["original_open_backend"]
|
|
219
|
-
if video.open_backend:
|
|
220
|
-
try:
|
|
221
|
-
video.open()
|
|
222
|
-
except Exception:
|
|
223
|
-
pass
|
|
224
|
-
|
|
225
41
|
|
|
226
42
|
class BaseDataset(Dataset):
|
|
227
43
|
"""Base class for custom torch Datasets.
|
|
@@ -260,8 +76,6 @@ class BaseDataset(Dataset):
|
|
|
260
76
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
261
77
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
262
78
|
disk occurs only once across all workers.
|
|
263
|
-
parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
|
|
264
|
-
cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
|
|
265
79
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
266
80
|
"""
|
|
267
81
|
|
|
@@ -281,8 +95,6 @@ class BaseDataset(Dataset):
|
|
|
281
95
|
cache_img_path: Optional[str] = None,
|
|
282
96
|
use_existing_imgs: bool = False,
|
|
283
97
|
rank: Optional[int] = None,
|
|
284
|
-
parallel_caching: bool = True,
|
|
285
|
-
cache_workers: int = 0,
|
|
286
98
|
) -> None:
|
|
287
99
|
"""Initialize class attributes."""
|
|
288
100
|
super().__init__()
|
|
@@ -323,8 +135,6 @@ class BaseDataset(Dataset):
|
|
|
323
135
|
self.cache_img = cache_img
|
|
324
136
|
self.cache_img_path = cache_img_path
|
|
325
137
|
self.use_existing_imgs = use_existing_imgs
|
|
326
|
-
self.parallel_caching = parallel_caching
|
|
327
|
-
self.cache_workers = cache_workers
|
|
328
138
|
if self.cache_img is not None and "disk" in self.cache_img:
|
|
329
139
|
if self.cache_img_path is None:
|
|
330
140
|
self.cache_img_path = "."
|
|
@@ -350,18 +160,10 @@ class BaseDataset(Dataset):
|
|
|
350
160
|
|
|
351
161
|
if self.cache_img is not None:
|
|
352
162
|
if self.cache_img == "memory":
|
|
353
|
-
self._fill_cache(
|
|
354
|
-
labels,
|
|
355
|
-
parallel=self.parallel_caching,
|
|
356
|
-
num_workers=self.cache_workers,
|
|
357
|
-
)
|
|
163
|
+
self._fill_cache(labels)
|
|
358
164
|
elif self.cache_img == "disk" and not self.use_existing_imgs:
|
|
359
165
|
if self.rank is None or self.rank == -1 or self.rank == 0:
|
|
360
|
-
self._fill_cache(
|
|
361
|
-
labels,
|
|
362
|
-
parallel=self.parallel_caching,
|
|
363
|
-
num_workers=self.cache_workers,
|
|
364
|
-
)
|
|
166
|
+
self._fill_cache(labels)
|
|
365
167
|
# Synchronize all ranks after cache creation
|
|
366
168
|
if is_distributed_initialized():
|
|
367
169
|
dist.barrier()
|
|
@@ -410,160 +212,20 @@ class BaseDataset(Dataset):
|
|
|
410
212
|
"""Returns an iterator."""
|
|
411
213
|
return self
|
|
412
214
|
|
|
413
|
-
def _fill_cache(
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
"""
|
|
428
|
-
total_samples = len(self.lf_idx_list)
|
|
429
|
-
cache_type = "disk" if self.cache_img == "disk" else "memory"
|
|
430
|
-
|
|
431
|
-
# Check for NO_COLOR env var to disable progress bar
|
|
432
|
-
no_color = (
|
|
433
|
-
os.environ.get("NO_COLOR") is not None
|
|
434
|
-
or os.environ.get("FORCE_COLOR") == "0"
|
|
435
|
-
)
|
|
436
|
-
use_progress = not no_color
|
|
437
|
-
|
|
438
|
-
# Use parallel caching for larger datasets
|
|
439
|
-
use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
|
|
440
|
-
|
|
441
|
-
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
442
|
-
|
|
443
|
-
if use_parallel:
|
|
444
|
-
self._fill_cache_parallel(
|
|
445
|
-
labels, total_samples, cache_type, use_progress, num_workers
|
|
446
|
-
)
|
|
447
|
-
else:
|
|
448
|
-
self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
|
|
449
|
-
|
|
450
|
-
logger.info(f"Caching complete.")
|
|
451
|
-
|
|
452
|
-
def _fill_cache_sequential(
|
|
453
|
-
self,
|
|
454
|
-
labels: List[sio.Labels],
|
|
455
|
-
total_samples: int,
|
|
456
|
-
cache_type: str,
|
|
457
|
-
use_progress: bool,
|
|
458
|
-
):
|
|
459
|
-
"""Sequential implementation of cache filling.
|
|
460
|
-
|
|
461
|
-
Args:
|
|
462
|
-
labels: List of sio.Labels objects.
|
|
463
|
-
total_samples: Total number of samples to cache.
|
|
464
|
-
cache_type: Either "disk" or "memory".
|
|
465
|
-
use_progress: Whether to show a progress bar.
|
|
466
|
-
"""
|
|
467
|
-
|
|
468
|
-
def process_samples(progress=None, task=None):
|
|
469
|
-
for sample in self.lf_idx_list:
|
|
470
|
-
labels_idx = sample["labels_idx"]
|
|
471
|
-
lf_idx = sample["lf_idx"]
|
|
472
|
-
img = labels[labels_idx][lf_idx].image
|
|
473
|
-
if img.shape[-1] == 1:
|
|
474
|
-
img = np.squeeze(img)
|
|
475
|
-
if self.cache_img == "disk":
|
|
476
|
-
f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
|
|
477
|
-
Image.fromarray(img).save(f_name, format="JPEG")
|
|
478
|
-
if self.cache_img == "memory":
|
|
479
|
-
self.cache[(labels_idx, lf_idx)] = img
|
|
480
|
-
if progress is not None:
|
|
481
|
-
progress.update(task, advance=1)
|
|
482
|
-
|
|
483
|
-
if use_progress:
|
|
484
|
-
with Progress(
|
|
485
|
-
SpinnerColumn(),
|
|
486
|
-
TextColumn("[progress.description]{task.description}"),
|
|
487
|
-
BarColumn(),
|
|
488
|
-
TextColumn("{task.completed}/{task.total}"),
|
|
489
|
-
TimeElapsedColumn(),
|
|
490
|
-
console=Console(force_terminal=True),
|
|
491
|
-
transient=True,
|
|
492
|
-
) as progress:
|
|
493
|
-
task = progress.add_task(
|
|
494
|
-
f"Caching images to {cache_type}", total=total_samples
|
|
495
|
-
)
|
|
496
|
-
process_samples(progress, task)
|
|
497
|
-
else:
|
|
498
|
-
process_samples()
|
|
499
|
-
|
|
500
|
-
def _fill_cache_parallel(
|
|
501
|
-
self,
|
|
502
|
-
labels: List[sio.Labels],
|
|
503
|
-
total_samples: int,
|
|
504
|
-
cache_type: str,
|
|
505
|
-
use_progress: bool,
|
|
506
|
-
num_workers: int = 0,
|
|
507
|
-
):
|
|
508
|
-
"""Parallel implementation of cache filling using thread-local video copies.
|
|
509
|
-
|
|
510
|
-
Args:
|
|
511
|
-
labels: List of sio.Labels objects.
|
|
512
|
-
total_samples: Total number of samples to cache.
|
|
513
|
-
cache_type: Either "disk" or "memory".
|
|
514
|
-
use_progress: Whether to show a progress bar.
|
|
515
|
-
num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
|
|
516
|
-
"""
|
|
517
|
-
# Determine number of workers
|
|
518
|
-
if num_workers <= 0:
|
|
519
|
-
num_workers = min(4, os.cpu_count() or 1)
|
|
520
|
-
|
|
521
|
-
cache_path = Path(self.cache_img_path) if self.cache_img_path else None
|
|
522
|
-
|
|
523
|
-
filler = ParallelCacheFiller(
|
|
524
|
-
labels=labels,
|
|
525
|
-
lf_idx_list=self.lf_idx_list,
|
|
526
|
-
cache_type=cache_type,
|
|
527
|
-
cache_path=cache_path,
|
|
528
|
-
num_workers=num_workers,
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
if use_progress:
|
|
532
|
-
with Progress(
|
|
533
|
-
SpinnerColumn(),
|
|
534
|
-
TextColumn("[progress.description]{task.description}"),
|
|
535
|
-
BarColumn(),
|
|
536
|
-
TextColumn("{task.completed}/{task.total}"),
|
|
537
|
-
TimeElapsedColumn(),
|
|
538
|
-
console=Console(force_terminal=True),
|
|
539
|
-
transient=True,
|
|
540
|
-
) as progress:
|
|
541
|
-
task = progress.add_task(
|
|
542
|
-
f"Caching images to {cache_type} (parallel, {num_workers} workers)",
|
|
543
|
-
total=total_samples,
|
|
544
|
-
)
|
|
545
|
-
|
|
546
|
-
def progress_callback(completed):
|
|
547
|
-
progress.update(task, completed=completed)
|
|
548
|
-
|
|
549
|
-
cache, errors = filler.fill_cache(progress_callback)
|
|
550
|
-
else:
|
|
551
|
-
logger.info(
|
|
552
|
-
f"Caching {total_samples} images to {cache_type} "
|
|
553
|
-
f"(parallel, {num_workers} workers)..."
|
|
554
|
-
)
|
|
555
|
-
cache, errors = filler.fill_cache()
|
|
556
|
-
|
|
557
|
-
# Update instance cache
|
|
558
|
-
if cache_type == "memory":
|
|
559
|
-
self.cache.update(cache)
|
|
560
|
-
|
|
561
|
-
# Log any errors
|
|
562
|
-
if errors:
|
|
563
|
-
logger.warning(
|
|
564
|
-
f"Parallel caching completed with {len(errors)} errors. "
|
|
565
|
-
f"First error: {errors[0]}"
|
|
566
|
-
)
|
|
215
|
+
def _fill_cache(self, labels: List[sio.Labels]):
|
|
216
|
+
"""Load all samples to cache."""
|
|
217
|
+
# TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
|
|
218
|
+
for sample in self.lf_idx_list:
|
|
219
|
+
labels_idx = sample["labels_idx"]
|
|
220
|
+
lf_idx = sample["lf_idx"]
|
|
221
|
+
img = labels[labels_idx][lf_idx].image
|
|
222
|
+
if img.shape[-1] == 1:
|
|
223
|
+
img = np.squeeze(img)
|
|
224
|
+
if self.cache_img == "disk":
|
|
225
|
+
f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
|
|
226
|
+
Image.fromarray(img).save(f_name, format="JPEG")
|
|
227
|
+
if self.cache_img == "memory":
|
|
228
|
+
self.cache[(labels_idx, lf_idx)] = img
|
|
567
229
|
|
|
568
230
|
def __len__(self) -> int:
|
|
569
231
|
"""Return the number of samples in the dataset."""
|
|
@@ -639,8 +301,6 @@ class BottomUpDataset(BaseDataset):
|
|
|
639
301
|
cache_img_path: Optional[str] = None,
|
|
640
302
|
use_existing_imgs: bool = False,
|
|
641
303
|
rank: Optional[int] = None,
|
|
642
|
-
parallel_caching: bool = True,
|
|
643
|
-
cache_workers: int = 0,
|
|
644
304
|
) -> None:
|
|
645
305
|
"""Initialize class attributes."""
|
|
646
306
|
super().__init__(
|
|
@@ -658,8 +318,6 @@ class BottomUpDataset(BaseDataset):
|
|
|
658
318
|
cache_img_path=cache_img_path,
|
|
659
319
|
use_existing_imgs=use_existing_imgs,
|
|
660
320
|
rank=rank,
|
|
661
|
-
parallel_caching=parallel_caching,
|
|
662
|
-
cache_workers=cache_workers,
|
|
663
321
|
)
|
|
664
322
|
self.confmap_head_config = confmap_head_config
|
|
665
323
|
self.pafs_head_config = pafs_head_config
|
|
@@ -702,6 +360,9 @@ class BottomUpDataset(BaseDataset):
|
|
|
702
360
|
user_instances_only=self.user_instances_only,
|
|
703
361
|
)
|
|
704
362
|
|
|
363
|
+
# apply normalization
|
|
364
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
365
|
+
|
|
705
366
|
if self.ensure_rgb:
|
|
706
367
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
707
368
|
elif self.ensure_grayscale:
|
|
@@ -839,8 +500,6 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
839
500
|
cache_img_path: Optional[str] = None,
|
|
840
501
|
use_existing_imgs: bool = False,
|
|
841
502
|
rank: Optional[int] = None,
|
|
842
|
-
parallel_caching: bool = True,
|
|
843
|
-
cache_workers: int = 0,
|
|
844
503
|
) -> None:
|
|
845
504
|
"""Initialize class attributes."""
|
|
846
505
|
super().__init__(
|
|
@@ -858,8 +517,6 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
858
517
|
cache_img_path=cache_img_path,
|
|
859
518
|
use_existing_imgs=use_existing_imgs,
|
|
860
519
|
rank=rank,
|
|
861
|
-
parallel_caching=parallel_caching,
|
|
862
|
-
cache_workers=cache_workers,
|
|
863
520
|
)
|
|
864
521
|
self.confmap_head_config = confmap_head_config
|
|
865
522
|
self.class_maps_head_config = class_maps_head_config
|
|
@@ -916,6 +573,9 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
916
573
|
|
|
917
574
|
sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
|
|
918
575
|
|
|
576
|
+
# apply normalization
|
|
577
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
578
|
+
|
|
919
579
|
if self.ensure_rgb:
|
|
920
580
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
921
581
|
elif self.ensure_grayscale:
|
|
@@ -1054,8 +714,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1054
714
|
cache_img_path: Optional[str] = None,
|
|
1055
715
|
use_existing_imgs: bool = False,
|
|
1056
716
|
rank: Optional[int] = None,
|
|
1057
|
-
parallel_caching: bool = True,
|
|
1058
|
-
cache_workers: int = 0,
|
|
1059
717
|
) -> None:
|
|
1060
718
|
"""Initialize class attributes."""
|
|
1061
719
|
super().__init__(
|
|
@@ -1073,8 +731,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1073
731
|
cache_img_path=cache_img_path,
|
|
1074
732
|
use_existing_imgs=use_existing_imgs,
|
|
1075
733
|
rank=rank,
|
|
1076
|
-
parallel_caching=parallel_caching,
|
|
1077
|
-
cache_workers=cache_workers,
|
|
1078
734
|
)
|
|
1079
735
|
self.labels = None
|
|
1080
736
|
self.crop_size = crop_size
|
|
@@ -1165,6 +821,9 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
1165
821
|
|
|
1166
822
|
instances = instances[:, inst_idx]
|
|
1167
823
|
|
|
824
|
+
# apply normalization
|
|
825
|
+
image = apply_normalization(image)
|
|
826
|
+
|
|
1168
827
|
if self.ensure_rgb:
|
|
1169
828
|
image = convert_to_rgb(image)
|
|
1170
829
|
elif self.ensure_grayscale:
|
|
@@ -1333,8 +992,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1333
992
|
cache_img_path: Optional[str] = None,
|
|
1334
993
|
use_existing_imgs: bool = False,
|
|
1335
994
|
rank: Optional[int] = None,
|
|
1336
|
-
parallel_caching: bool = True,
|
|
1337
|
-
cache_workers: int = 0,
|
|
1338
995
|
) -> None:
|
|
1339
996
|
"""Initialize class attributes."""
|
|
1340
997
|
super().__init__(
|
|
@@ -1355,8 +1012,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1355
1012
|
cache_img_path=cache_img_path,
|
|
1356
1013
|
use_existing_imgs=use_existing_imgs,
|
|
1357
1014
|
rank=rank,
|
|
1358
|
-
parallel_caching=parallel_caching,
|
|
1359
|
-
cache_workers=cache_workers,
|
|
1360
1015
|
)
|
|
1361
1016
|
self.class_vectors_head_config = class_vectors_head_config
|
|
1362
1017
|
self.class_names = self.class_vectors_head_config.classes
|
|
@@ -1411,6 +1066,9 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1411
1066
|
|
|
1412
1067
|
instances = instances[:, inst_idx]
|
|
1413
1068
|
|
|
1069
|
+
# apply normalization
|
|
1070
|
+
image = apply_normalization(image)
|
|
1071
|
+
|
|
1414
1072
|
if self.ensure_rgb:
|
|
1415
1073
|
image = convert_to_rgb(image)
|
|
1416
1074
|
elif self.ensure_grayscale:
|
|
@@ -1592,8 +1250,6 @@ class CentroidDataset(BaseDataset):
|
|
|
1592
1250
|
cache_img_path: Optional[str] = None,
|
|
1593
1251
|
use_existing_imgs: bool = False,
|
|
1594
1252
|
rank: Optional[int] = None,
|
|
1595
|
-
parallel_caching: bool = True,
|
|
1596
|
-
cache_workers: int = 0,
|
|
1597
1253
|
) -> None:
|
|
1598
1254
|
"""Initialize class attributes."""
|
|
1599
1255
|
super().__init__(
|
|
@@ -1611,8 +1267,6 @@ class CentroidDataset(BaseDataset):
|
|
|
1611
1267
|
cache_img_path=cache_img_path,
|
|
1612
1268
|
use_existing_imgs=use_existing_imgs,
|
|
1613
1269
|
rank=rank,
|
|
1614
|
-
parallel_caching=parallel_caching,
|
|
1615
|
-
cache_workers=cache_workers,
|
|
1616
1270
|
)
|
|
1617
1271
|
self.anchor_ind = anchor_ind
|
|
1618
1272
|
self.confmap_head_config = confmap_head_config
|
|
@@ -1652,6 +1306,9 @@ class CentroidDataset(BaseDataset):
|
|
|
1652
1306
|
user_instances_only=self.user_instances_only,
|
|
1653
1307
|
)
|
|
1654
1308
|
|
|
1309
|
+
# apply normalization
|
|
1310
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
1311
|
+
|
|
1655
1312
|
if self.ensure_rgb:
|
|
1656
1313
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1657
1314
|
elif self.ensure_grayscale:
|
|
@@ -1776,8 +1433,6 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1776
1433
|
cache_img_path: Optional[str] = None,
|
|
1777
1434
|
use_existing_imgs: bool = False,
|
|
1778
1435
|
rank: Optional[int] = None,
|
|
1779
|
-
parallel_caching: bool = True,
|
|
1780
|
-
cache_workers: int = 0,
|
|
1781
1436
|
) -> None:
|
|
1782
1437
|
"""Initialize class attributes."""
|
|
1783
1438
|
super().__init__(
|
|
@@ -1795,8 +1450,6 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1795
1450
|
cache_img_path=cache_img_path,
|
|
1796
1451
|
use_existing_imgs=use_existing_imgs,
|
|
1797
1452
|
rank=rank,
|
|
1798
|
-
parallel_caching=parallel_caching,
|
|
1799
|
-
cache_workers=cache_workers,
|
|
1800
1453
|
)
|
|
1801
1454
|
self.confmap_head_config = confmap_head_config
|
|
1802
1455
|
|
|
@@ -1835,6 +1488,9 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1835
1488
|
user_instances_only=self.user_instances_only,
|
|
1836
1489
|
)
|
|
1837
1490
|
|
|
1491
|
+
# apply normalization
|
|
1492
|
+
sample["image"] = apply_normalization(sample["image"])
|
|
1493
|
+
|
|
1838
1494
|
if self.ensure_rgb:
|
|
1839
1495
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1840
1496
|
elif self.ensure_grayscale:
|
|
@@ -2015,10 +1671,6 @@ def get_train_val_datasets(
|
|
|
2015
1671
|
val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
|
|
2016
1672
|
use_existing_imgs = config.data_config.use_existing_imgs
|
|
2017
1673
|
|
|
2018
|
-
# Parallel caching configuration
|
|
2019
|
-
parallel_caching = getattr(config.data_config, "parallel_caching", True)
|
|
2020
|
-
cache_workers = getattr(config.data_config, "cache_workers", 0)
|
|
2021
|
-
|
|
2022
1674
|
model_type = get_model_type_from_cfg(config=config)
|
|
2023
1675
|
backbone_type = get_backbone_type_from_cfg(config=config)
|
|
2024
1676
|
|
|
@@ -2072,8 +1724,6 @@ def get_train_val_datasets(
|
|
|
2072
1724
|
cache_img_path=train_cache_img_path,
|
|
2073
1725
|
use_existing_imgs=use_existing_imgs,
|
|
2074
1726
|
rank=rank,
|
|
2075
|
-
parallel_caching=parallel_caching,
|
|
2076
|
-
cache_workers=cache_workers,
|
|
2077
1727
|
)
|
|
2078
1728
|
val_dataset = BottomUpDataset(
|
|
2079
1729
|
labels=val_labels,
|
|
@@ -2097,8 +1747,6 @@ def get_train_val_datasets(
|
|
|
2097
1747
|
cache_img_path=val_cache_img_path,
|
|
2098
1748
|
use_existing_imgs=use_existing_imgs,
|
|
2099
1749
|
rank=rank,
|
|
2100
|
-
parallel_caching=parallel_caching,
|
|
2101
|
-
cache_workers=cache_workers,
|
|
2102
1750
|
)
|
|
2103
1751
|
|
|
2104
1752
|
elif model_type == "multi_class_bottomup":
|
|
@@ -2132,8 +1780,6 @@ def get_train_val_datasets(
|
|
|
2132
1780
|
cache_img_path=train_cache_img_path,
|
|
2133
1781
|
use_existing_imgs=use_existing_imgs,
|
|
2134
1782
|
rank=rank,
|
|
2135
|
-
parallel_caching=parallel_caching,
|
|
2136
|
-
cache_workers=cache_workers,
|
|
2137
1783
|
)
|
|
2138
1784
|
val_dataset = BottomUpMultiClassDataset(
|
|
2139
1785
|
labels=val_labels,
|
|
@@ -2157,8 +1803,6 @@ def get_train_val_datasets(
|
|
|
2157
1803
|
cache_img_path=val_cache_img_path,
|
|
2158
1804
|
use_existing_imgs=use_existing_imgs,
|
|
2159
1805
|
rank=rank,
|
|
2160
|
-
parallel_caching=parallel_caching,
|
|
2161
|
-
cache_workers=cache_workers,
|
|
2162
1806
|
)
|
|
2163
1807
|
|
|
2164
1808
|
elif model_type == "centered_instance":
|
|
@@ -2198,8 +1842,6 @@ def get_train_val_datasets(
|
|
|
2198
1842
|
cache_img_path=train_cache_img_path,
|
|
2199
1843
|
use_existing_imgs=use_existing_imgs,
|
|
2200
1844
|
rank=rank,
|
|
2201
|
-
parallel_caching=parallel_caching,
|
|
2202
|
-
cache_workers=cache_workers,
|
|
2203
1845
|
)
|
|
2204
1846
|
val_dataset = CenteredInstanceDataset(
|
|
2205
1847
|
labels=val_labels,
|
|
@@ -2224,8 +1866,6 @@ def get_train_val_datasets(
|
|
|
2224
1866
|
cache_img_path=val_cache_img_path,
|
|
2225
1867
|
use_existing_imgs=use_existing_imgs,
|
|
2226
1868
|
rank=rank,
|
|
2227
|
-
parallel_caching=parallel_caching,
|
|
2228
|
-
cache_workers=cache_workers,
|
|
2229
1869
|
)
|
|
2230
1870
|
|
|
2231
1871
|
elif model_type == "multi_class_topdown":
|
|
@@ -2266,8 +1906,6 @@ def get_train_val_datasets(
|
|
|
2266
1906
|
cache_img_path=train_cache_img_path,
|
|
2267
1907
|
use_existing_imgs=use_existing_imgs,
|
|
2268
1908
|
rank=rank,
|
|
2269
|
-
parallel_caching=parallel_caching,
|
|
2270
|
-
cache_workers=cache_workers,
|
|
2271
1909
|
)
|
|
2272
1910
|
val_dataset = TopDownCenteredInstanceMultiClassDataset(
|
|
2273
1911
|
labels=val_labels,
|
|
@@ -2293,8 +1931,6 @@ def get_train_val_datasets(
|
|
|
2293
1931
|
cache_img_path=val_cache_img_path,
|
|
2294
1932
|
use_existing_imgs=use_existing_imgs,
|
|
2295
1933
|
rank=rank,
|
|
2296
|
-
parallel_caching=parallel_caching,
|
|
2297
|
-
cache_workers=cache_workers,
|
|
2298
1934
|
)
|
|
2299
1935
|
|
|
2300
1936
|
elif model_type == "centroid":
|
|
@@ -2331,8 +1967,6 @@ def get_train_val_datasets(
|
|
|
2331
1967
|
cache_img_path=train_cache_img_path,
|
|
2332
1968
|
use_existing_imgs=use_existing_imgs,
|
|
2333
1969
|
rank=rank,
|
|
2334
|
-
parallel_caching=parallel_caching,
|
|
2335
|
-
cache_workers=cache_workers,
|
|
2336
1970
|
)
|
|
2337
1971
|
val_dataset = CentroidDataset(
|
|
2338
1972
|
labels=val_labels,
|
|
@@ -2356,8 +1990,6 @@ def get_train_val_datasets(
|
|
|
2356
1990
|
cache_img_path=val_cache_img_path,
|
|
2357
1991
|
use_existing_imgs=use_existing_imgs,
|
|
2358
1992
|
rank=rank,
|
|
2359
|
-
parallel_caching=parallel_caching,
|
|
2360
|
-
cache_workers=cache_workers,
|
|
2361
1993
|
)
|
|
2362
1994
|
|
|
2363
1995
|
else:
|
|
@@ -2390,8 +2022,6 @@ def get_train_val_datasets(
|
|
|
2390
2022
|
cache_img_path=train_cache_img_path,
|
|
2391
2023
|
use_existing_imgs=use_existing_imgs,
|
|
2392
2024
|
rank=rank,
|
|
2393
|
-
parallel_caching=parallel_caching,
|
|
2394
|
-
cache_workers=cache_workers,
|
|
2395
2025
|
)
|
|
2396
2026
|
val_dataset = SingleInstanceDataset(
|
|
2397
2027
|
labels=val_labels,
|
|
@@ -2414,8 +2044,6 @@ def get_train_val_datasets(
|
|
|
2414
2044
|
cache_img_path=val_cache_img_path,
|
|
2415
2045
|
use_existing_imgs=use_existing_imgs,
|
|
2416
2046
|
rank=rank,
|
|
2417
|
-
parallel_caching=parallel_caching,
|
|
2418
|
-
cache_workers=cache_workers,
|
|
2419
2047
|
)
|
|
2420
2048
|
|
|
2421
2049
|
# If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.
|