sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/data/custom_datasets.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
"""Custom `torch.utils.data.Dataset`s for different model types."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
5
|
+
import os
|
|
6
|
+
import threading
|
|
7
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
8
|
+
from copy import deepcopy
|
|
8
9
|
from itertools import cycle
|
|
9
10
|
from pathlib import Path
|
|
10
11
|
import torch.distributed as dist
|
|
@@ -13,6 +14,14 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
13
14
|
import numpy as np
|
|
14
15
|
from PIL import Image
|
|
15
16
|
from loguru import logger
|
|
17
|
+
from rich.progress import (
|
|
18
|
+
Progress,
|
|
19
|
+
SpinnerColumn,
|
|
20
|
+
TextColumn,
|
|
21
|
+
BarColumn,
|
|
22
|
+
TimeElapsedColumn,
|
|
23
|
+
)
|
|
24
|
+
from rich.console import Console
|
|
16
25
|
import torch
|
|
17
26
|
import torchvision.transforms as T
|
|
18
27
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
|
@@ -22,7 +31,6 @@ from sleap_nn.data.identity import generate_class_maps, make_class_vectors
|
|
|
22
31
|
from sleap_nn.data.instance_centroids import generate_centroids
|
|
23
32
|
from sleap_nn.data.instance_cropping import generate_crops
|
|
24
33
|
from sleap_nn.data.normalization import (
|
|
25
|
-
apply_normalization,
|
|
26
34
|
convert_to_grayscale,
|
|
27
35
|
convert_to_rgb,
|
|
28
36
|
)
|
|
@@ -38,6 +46,182 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
|
|
|
38
46
|
from sleap_nn.training.utils import is_distributed_initialized
|
|
39
47
|
from sleap_nn.config.get_config import get_aug_config
|
|
40
48
|
|
|
49
|
+
# Minimum number of samples before `_fill_cache` switches to the thread-pool
# (parallel) path; for smaller datasets the pool/video-copy overhead is not
# worth it and caching runs sequentially instead.
MIN_SAMPLES_FOR_PARALLEL_CACHING = 20
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ParallelCacheFiller:
    """Parallel implementation of image caching using thread-local video copies.

    This class uses ``ThreadPoolExecutor`` to parallelize the I/O-bound work of
    reading frames and, for disk caching, writing them out as JPEG files. Each
    worker thread reads through its own ``deepcopy`` of a video object so that
    threads never share a video backend.

    While caching runs, the original videos are closed and their
    ``open_backend`` flag is cleared. Afterwards the original state is restored
    and every thread-local copy is closed, even if caching is interrupted by an
    exception.

    Attributes:
        labels: List of sio.Labels objects containing the data.
        lf_idx_list: List of dictionaries with labeled frame indices.
        cache_type: Either "disk" or "memory".
        cache_path: Path to save cached images (for disk caching).
        num_workers: Number of worker threads.
        cache: Mapping of ``(labels_idx, lf_idx)`` to image arrays; only
            populated for memory caching.
    """

    def __init__(
        self,
        labels: List[sio.Labels],
        lf_idx_list: List[Dict],
        cache_type: str,
        cache_path: Optional[Path] = None,
        num_workers: int = 4,
    ):
        """Initialize the parallel cache filler.

        Args:
            labels: List of sio.Labels objects.
            lf_idx_list: List of sample dictionaries with frame indices
                (each must contain ``labels_idx`` and ``lf_idx``).
            cache_type: Either "disk" or "memory".
            cache_path: Path for disk caching. Required when ``cache_type`` is
                "disk".
            num_workers: Number of worker threads.

        Raises:
            ValueError: If ``cache_type`` is not "disk" or "memory", or if disk
                caching is requested without a ``cache_path``.
        """
        # Validate before touching any video so a bad configuration leaves the
        # caller's videos in their original state.
        if cache_type not in ("disk", "memory"):
            raise ValueError(
                f"cache_type must be 'disk' or 'memory', got {cache_type!r}."
            )
        if cache_type == "disk" and cache_path is None:
            raise ValueError("cache_path is required when cache_type is 'disk'.")

        self.labels = labels
        self.lf_idx_list = lf_idx_list
        self.cache_type = cache_type
        self.cache_path = cache_path
        self.num_workers = num_workers

        self.cache: Dict = {}
        self._cache_lock = threading.Lock()
        self._local = threading.local()
        self._video_info: Dict = {}
        # Every thread-local copy handed out to a worker, tracked so that the
        # copies can be closed deterministically once caching finishes.
        self._thread_videos: List = []
        self._thread_videos_lock = threading.Lock()

        # Prepare video copies for thread-local access
        self._prepare_video_copies()

    def _prepare_video_copies(self):
        """Close original videos and record their state for later restoration."""
        for label in self.labels:
            for video in label.videos:
                vid_id = id(video)
                if vid_id not in self._video_info:
                    # Store original state
                    original_open_backend = video.open_backend

                    # Close the video backend
                    video.close()
                    video.open_backend = False

                    self._video_info[vid_id] = {
                        "video": video,
                        "original_open_backend": original_open_backend,
                    }

    def _get_thread_local_video(self, video: sio.Video) -> sio.Video:
        """Get or create a thread-local video copy.

        Args:
            video: The original video object.

        Returns:
            A copy of the video owned by the calling thread.
        """
        vid_id = id(video)

        local_videos = getattr(self._local, "videos", None)
        if local_videos is None:
            local_videos = self._local.videos = {}

        if vid_id not in local_videos:
            # Create a thread-local copy; the backend is opened on demand.
            video_copy = deepcopy(video)
            video_copy.open_backend = True
            local_videos[vid_id] = video_copy
            with self._thread_videos_lock:
                self._thread_videos.append(video_copy)

        return local_videos[vid_id]

    def _process_sample(
        self, sample: Dict
    ) -> Tuple[int, int, Optional[np.ndarray], Optional[str]]:
        """Process a single sample (read image, optionally save/cache).

        Args:
            sample: Dictionary with labels_idx, lf_idx, etc.

        Returns:
            Tuple of (labels_idx, lf_idx, image_or_none, error_or_none). The image
            is only returned for memory caching; per-sample failures are reported
            as an error string rather than raised.
        """
        labels_idx = sample["labels_idx"]
        lf_idx = sample["lf_idx"]

        try:
            # Get the labeled frame
            lf = self.labels[labels_idx][lf_idx]

            # Read the image through this thread's own video copy
            video = self._get_thread_local_video(lf.video)
            img = video[lf.frame_idx]

            # Drop a trailing singleton channel axis.
            if img.shape[-1] == 1:
                img = np.squeeze(img)

            if self.cache_type == "disk":
                f_name = self.cache_path / f"sample_{labels_idx}_{lf_idx}.jpg"
                Image.fromarray(img).save(str(f_name), format="JPEG")
                return labels_idx, lf_idx, None, None

            # cache_type was validated in __init__, so this is "memory".
            return labels_idx, lf_idx, img, None

        except Exception as e:
            return labels_idx, lf_idx, None, f"{type(e).__name__}: {str(e)}"

    def fill_cache(
        self, progress_callback=None
    ) -> Tuple[Dict, List[Tuple[int, int, str]]]:
        """Fill the cache in parallel.

        Original video states are restored (and thread-local copies closed) in a
        ``finally`` block, so the caller's videos are usable again even if this
        method is interrupted.

        Args:
            progress_callback: Optional callback(completed_count) for progress updates.

        Returns:
            Tuple of (cache_dict, list_of_errors).
        """
        errors = []
        completed = 0

        try:
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = [
                    executor.submit(self._process_sample, sample)
                    for sample in self.lf_idx_list
                ]

                for future in as_completed(futures):
                    labels_idx, lf_idx, img, error = future.result()

                    if error:
                        errors.append((labels_idx, lf_idx, error))
                    elif self.cache_type == "memory" and img is not None:
                        with self._cache_lock:
                            self.cache[(labels_idx, lf_idx)] = img

                    completed += 1
                    if progress_callback:
                        progress_callback(completed)
        finally:
            # The executor context has joined all workers by now, so no thread
            # is still using a copy when they are closed.
            self._close_thread_local_videos()
            self._restore_videos()

        return self.cache, errors

    def _close_thread_local_videos(self):
        """Close every thread-local video copy handed out during caching."""
        with self._thread_videos_lock:
            copies, self._thread_videos = self._thread_videos, []
        # Drop per-thread references to the (now closed) copies.
        self._local = threading.local()
        for video_copy in copies:
            try:
                video_copy.close()
            except Exception as e:
                logger.debug(f"Failed to close thread-local video copy: {e}")

    def _restore_videos(self):
        """Restore original video states after caching is complete."""
        for vid_info in self._video_info.values():
            video = vid_info["video"]
            video.open_backend = vid_info["original_open_backend"]
            if video.open_backend:
                try:
                    video.open()
                except Exception as e:
                    # Best-effort: the video can still be reopened lazily later,
                    # but surface the failure instead of hiding it.
                    logger.warning(
                        f"Failed to reopen video after parallel caching: "
                        f"{type(e).__name__}: {e}"
                    )
|
|
224
|
+
|
|
41
225
|
|
|
42
226
|
class BaseDataset(Dataset):
|
|
43
227
|
"""Base class for custom torch Datasets.
|
|
@@ -76,6 +260,8 @@ class BaseDataset(Dataset):
|
|
|
76
260
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
77
261
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
78
262
|
disk occurs only once across all workers.
|
|
263
|
+
parallel_caching: If True, use parallel processing for caching (faster for large datasets). Default: True.
|
|
264
|
+
cache_workers: Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). Default: 0.
|
|
79
265
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
80
266
|
"""
|
|
81
267
|
|
|
@@ -95,6 +281,8 @@ class BaseDataset(Dataset):
|
|
|
95
281
|
cache_img_path: Optional[str] = None,
|
|
96
282
|
use_existing_imgs: bool = False,
|
|
97
283
|
rank: Optional[int] = None,
|
|
284
|
+
parallel_caching: bool = True,
|
|
285
|
+
cache_workers: int = 0,
|
|
98
286
|
) -> None:
|
|
99
287
|
"""Initialize class attributes."""
|
|
100
288
|
super().__init__()
|
|
@@ -135,6 +323,8 @@ class BaseDataset(Dataset):
|
|
|
135
323
|
self.cache_img = cache_img
|
|
136
324
|
self.cache_img_path = cache_img_path
|
|
137
325
|
self.use_existing_imgs = use_existing_imgs
|
|
326
|
+
self.parallel_caching = parallel_caching
|
|
327
|
+
self.cache_workers = cache_workers
|
|
138
328
|
if self.cache_img is not None and "disk" in self.cache_img:
|
|
139
329
|
if self.cache_img_path is None:
|
|
140
330
|
self.cache_img_path = "."
|
|
@@ -160,10 +350,18 @@ class BaseDataset(Dataset):
|
|
|
160
350
|
|
|
161
351
|
if self.cache_img is not None:
|
|
162
352
|
if self.cache_img == "memory":
|
|
163
|
-
self._fill_cache(
|
|
353
|
+
self._fill_cache(
|
|
354
|
+
labels,
|
|
355
|
+
parallel=self.parallel_caching,
|
|
356
|
+
num_workers=self.cache_workers,
|
|
357
|
+
)
|
|
164
358
|
elif self.cache_img == "disk" and not self.use_existing_imgs:
|
|
165
359
|
if self.rank is None or self.rank == -1 or self.rank == 0:
|
|
166
|
-
self._fill_cache(
|
|
360
|
+
self._fill_cache(
|
|
361
|
+
labels,
|
|
362
|
+
parallel=self.parallel_caching,
|
|
363
|
+
num_workers=self.cache_workers,
|
|
364
|
+
)
|
|
167
365
|
# Synchronize all ranks after cache creation
|
|
168
366
|
if is_distributed_initialized():
|
|
169
367
|
dist.barrier()
|
|
@@ -177,6 +375,9 @@ class BaseDataset(Dataset):
|
|
|
177
375
|
if self.user_instances_only:
|
|
178
376
|
if lf.user_instances is not None and len(lf.user_instances) > 0:
|
|
179
377
|
lf.instances = lf.user_instances
|
|
378
|
+
else:
|
|
379
|
+
# Skip frames without user instances
|
|
380
|
+
continue
|
|
180
381
|
is_empty = True
|
|
181
382
|
for _, inst in enumerate(lf.instances):
|
|
182
383
|
if not inst.is_empty: # filter all NaN instances.
|
|
@@ -209,20 +410,160 @@ class BaseDataset(Dataset):
|
|
|
209
410
|
"""Returns an iterator."""
|
|
210
411
|
return self
|
|
211
412
|
|
|
212
|
-
def _fill_cache(
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
413
|
+
def _fill_cache(
|
|
414
|
+
self,
|
|
415
|
+
labels: List[sio.Labels],
|
|
416
|
+
parallel: bool = True,
|
|
417
|
+
num_workers: int = 0,
|
|
418
|
+
):
|
|
419
|
+
"""Load all samples to cache.
|
|
420
|
+
|
|
421
|
+
Args:
|
|
422
|
+
labels: List of sio.Labels objects containing the data.
|
|
423
|
+
parallel: If True, use parallel processing for caching (faster for large
|
|
424
|
+
datasets). Default: True.
|
|
425
|
+
num_workers: Number of worker threads for parallel caching. If 0, uses
|
|
426
|
+
min(4, cpu_count). Default: 0.
|
|
427
|
+
"""
|
|
428
|
+
total_samples = len(self.lf_idx_list)
|
|
429
|
+
cache_type = "disk" if self.cache_img == "disk" else "memory"
|
|
430
|
+
|
|
431
|
+
# Check for NO_COLOR env var to disable progress bar
|
|
432
|
+
no_color = (
|
|
433
|
+
os.environ.get("NO_COLOR") is not None
|
|
434
|
+
or os.environ.get("FORCE_COLOR") == "0"
|
|
435
|
+
)
|
|
436
|
+
use_progress = not no_color
|
|
437
|
+
|
|
438
|
+
# Use parallel caching for larger datasets
|
|
439
|
+
use_parallel = parallel and total_samples >= MIN_SAMPLES_FOR_PARALLEL_CACHING
|
|
440
|
+
|
|
441
|
+
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
442
|
+
|
|
443
|
+
if use_parallel:
|
|
444
|
+
self._fill_cache_parallel(
|
|
445
|
+
labels, total_samples, cache_type, use_progress, num_workers
|
|
446
|
+
)
|
|
447
|
+
else:
|
|
448
|
+
self._fill_cache_sequential(labels, total_samples, cache_type, use_progress)
|
|
449
|
+
|
|
450
|
+
logger.info(f"Caching complete.")
|
|
451
|
+
|
|
452
|
+
def _fill_cache_sequential(
|
|
453
|
+
self,
|
|
454
|
+
labels: List[sio.Labels],
|
|
455
|
+
total_samples: int,
|
|
456
|
+
cache_type: str,
|
|
457
|
+
use_progress: bool,
|
|
458
|
+
):
|
|
459
|
+
"""Sequential implementation of cache filling.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
labels: List of sio.Labels objects.
|
|
463
|
+
total_samples: Total number of samples to cache.
|
|
464
|
+
cache_type: Either "disk" or "memory".
|
|
465
|
+
use_progress: Whether to show a progress bar.
|
|
466
|
+
"""
|
|
467
|
+
|
|
468
|
+
def process_samples(progress=None, task=None):
|
|
469
|
+
for sample in self.lf_idx_list:
|
|
470
|
+
labels_idx = sample["labels_idx"]
|
|
471
|
+
lf_idx = sample["lf_idx"]
|
|
472
|
+
img = labels[labels_idx][lf_idx].image
|
|
473
|
+
if img.shape[-1] == 1:
|
|
474
|
+
img = np.squeeze(img)
|
|
475
|
+
if self.cache_img == "disk":
|
|
476
|
+
f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
|
|
477
|
+
Image.fromarray(img).save(f_name, format="JPEG")
|
|
478
|
+
if self.cache_img == "memory":
|
|
479
|
+
self.cache[(labels_idx, lf_idx)] = img
|
|
480
|
+
if progress is not None:
|
|
481
|
+
progress.update(task, advance=1)
|
|
482
|
+
|
|
483
|
+
if use_progress:
|
|
484
|
+
with Progress(
|
|
485
|
+
SpinnerColumn(),
|
|
486
|
+
TextColumn("[progress.description]{task.description}"),
|
|
487
|
+
BarColumn(),
|
|
488
|
+
TextColumn("{task.completed}/{task.total}"),
|
|
489
|
+
TimeElapsedColumn(),
|
|
490
|
+
console=Console(force_terminal=True),
|
|
491
|
+
transient=True,
|
|
492
|
+
) as progress:
|
|
493
|
+
task = progress.add_task(
|
|
494
|
+
f"Caching images to {cache_type}", total=total_samples
|
|
495
|
+
)
|
|
496
|
+
process_samples(progress, task)
|
|
497
|
+
else:
|
|
498
|
+
process_samples()
|
|
499
|
+
|
|
500
|
+
def _fill_cache_parallel(
|
|
501
|
+
self,
|
|
502
|
+
labels: List[sio.Labels],
|
|
503
|
+
total_samples: int,
|
|
504
|
+
cache_type: str,
|
|
505
|
+
use_progress: bool,
|
|
506
|
+
num_workers: int = 0,
|
|
507
|
+
):
|
|
508
|
+
"""Parallel implementation of cache filling using thread-local video copies.
|
|
509
|
+
|
|
510
|
+
Args:
|
|
511
|
+
labels: List of sio.Labels objects.
|
|
512
|
+
total_samples: Total number of samples to cache.
|
|
513
|
+
cache_type: Either "disk" or "memory".
|
|
514
|
+
use_progress: Whether to show a progress bar.
|
|
515
|
+
num_workers: Number of worker threads. If 0, uses min(4, cpu_count).
|
|
516
|
+
"""
|
|
517
|
+
# Determine number of workers
|
|
518
|
+
if num_workers <= 0:
|
|
519
|
+
num_workers = min(4, os.cpu_count() or 1)
|
|
520
|
+
|
|
521
|
+
cache_path = Path(self.cache_img_path) if self.cache_img_path else None
|
|
522
|
+
|
|
523
|
+
filler = ParallelCacheFiller(
|
|
524
|
+
labels=labels,
|
|
525
|
+
lf_idx_list=self.lf_idx_list,
|
|
526
|
+
cache_type=cache_type,
|
|
527
|
+
cache_path=cache_path,
|
|
528
|
+
num_workers=num_workers,
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
if use_progress:
|
|
532
|
+
with Progress(
|
|
533
|
+
SpinnerColumn(),
|
|
534
|
+
TextColumn("[progress.description]{task.description}"),
|
|
535
|
+
BarColumn(),
|
|
536
|
+
TextColumn("{task.completed}/{task.total}"),
|
|
537
|
+
TimeElapsedColumn(),
|
|
538
|
+
console=Console(force_terminal=True),
|
|
539
|
+
transient=True,
|
|
540
|
+
) as progress:
|
|
541
|
+
task = progress.add_task(
|
|
542
|
+
f"Caching images to {cache_type} (parallel, {num_workers} workers)",
|
|
543
|
+
total=total_samples,
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
def progress_callback(completed):
|
|
547
|
+
progress.update(task, completed=completed)
|
|
548
|
+
|
|
549
|
+
cache, errors = filler.fill_cache(progress_callback)
|
|
550
|
+
else:
|
|
551
|
+
logger.info(
|
|
552
|
+
f"Caching {total_samples} images to {cache_type} "
|
|
553
|
+
f"(parallel, {num_workers} workers)..."
|
|
554
|
+
)
|
|
555
|
+
cache, errors = filler.fill_cache()
|
|
556
|
+
|
|
557
|
+
# Update instance cache
|
|
558
|
+
if cache_type == "memory":
|
|
559
|
+
self.cache.update(cache)
|
|
560
|
+
|
|
561
|
+
# Log any errors
|
|
562
|
+
if errors:
|
|
563
|
+
logger.warning(
|
|
564
|
+
f"Parallel caching completed with {len(errors)} errors. "
|
|
565
|
+
f"First error: {errors[0]}"
|
|
566
|
+
)
|
|
226
567
|
|
|
227
568
|
def __len__(self) -> int:
|
|
228
569
|
"""Return the number of samples in the dataset."""
|
|
@@ -298,6 +639,8 @@ class BottomUpDataset(BaseDataset):
|
|
|
298
639
|
cache_img_path: Optional[str] = None,
|
|
299
640
|
use_existing_imgs: bool = False,
|
|
300
641
|
rank: Optional[int] = None,
|
|
642
|
+
parallel_caching: bool = True,
|
|
643
|
+
cache_workers: int = 0,
|
|
301
644
|
) -> None:
|
|
302
645
|
"""Initialize class attributes."""
|
|
303
646
|
super().__init__(
|
|
@@ -315,6 +658,8 @@ class BottomUpDataset(BaseDataset):
|
|
|
315
658
|
cache_img_path=cache_img_path,
|
|
316
659
|
use_existing_imgs=use_existing_imgs,
|
|
317
660
|
rank=rank,
|
|
661
|
+
parallel_caching=parallel_caching,
|
|
662
|
+
cache_workers=cache_workers,
|
|
318
663
|
)
|
|
319
664
|
self.confmap_head_config = confmap_head_config
|
|
320
665
|
self.pafs_head_config = pafs_head_config
|
|
@@ -357,9 +702,6 @@ class BottomUpDataset(BaseDataset):
|
|
|
357
702
|
user_instances_only=self.user_instances_only,
|
|
358
703
|
)
|
|
359
704
|
|
|
360
|
-
# apply normalization
|
|
361
|
-
sample["image"] = apply_normalization(sample["image"])
|
|
362
|
-
|
|
363
705
|
if self.ensure_rgb:
|
|
364
706
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
365
707
|
elif self.ensure_grayscale:
|
|
@@ -497,6 +839,8 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
497
839
|
cache_img_path: Optional[str] = None,
|
|
498
840
|
use_existing_imgs: bool = False,
|
|
499
841
|
rank: Optional[int] = None,
|
|
842
|
+
parallel_caching: bool = True,
|
|
843
|
+
cache_workers: int = 0,
|
|
500
844
|
) -> None:
|
|
501
845
|
"""Initialize class attributes."""
|
|
502
846
|
super().__init__(
|
|
@@ -514,6 +858,8 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
514
858
|
cache_img_path=cache_img_path,
|
|
515
859
|
use_existing_imgs=use_existing_imgs,
|
|
516
860
|
rank=rank,
|
|
861
|
+
parallel_caching=parallel_caching,
|
|
862
|
+
cache_workers=cache_workers,
|
|
517
863
|
)
|
|
518
864
|
self.confmap_head_config = confmap_head_config
|
|
519
865
|
self.class_maps_head_config = class_maps_head_config
|
|
@@ -570,9 +916,6 @@ class BottomUpMultiClassDataset(BaseDataset):
|
|
|
570
916
|
|
|
571
917
|
sample["num_tracks"] = torch.tensor(len(self.class_names), dtype=torch.int32)
|
|
572
918
|
|
|
573
|
-
# apply normalization
|
|
574
|
-
sample["image"] = apply_normalization(sample["image"])
|
|
575
|
-
|
|
576
919
|
if self.ensure_rgb:
|
|
577
920
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
578
921
|
elif self.ensure_grayscale:
|
|
@@ -684,15 +1027,12 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
684
1027
|
the images aren't cached and loaded from the `.slp` file on each access.
|
|
685
1028
|
cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
|
|
686
1029
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
687
|
-
crop_size: Crop size of each instance for centered-instance model.
|
|
1030
|
+
crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
|
|
688
1031
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
689
1032
|
disk occurs only once across all workers.
|
|
690
1033
|
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
691
1034
|
(required keys: `sigma`, `output_stride`, `part_names` and `anchor_part` depending on the model type ).
|
|
692
1035
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
693
|
-
|
|
694
|
-
Note: If scale is provided for centered-instance model, the images are cropped out
|
|
695
|
-
from the scaled image with the given crop size.
|
|
696
1036
|
"""
|
|
697
1037
|
|
|
698
1038
|
def __init__(
|
|
@@ -714,6 +1054,8 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
714
1054
|
cache_img_path: Optional[str] = None,
|
|
715
1055
|
use_existing_imgs: bool = False,
|
|
716
1056
|
rank: Optional[int] = None,
|
|
1057
|
+
parallel_caching: bool = True,
|
|
1058
|
+
cache_workers: int = 0,
|
|
717
1059
|
) -> None:
|
|
718
1060
|
"""Initialize class attributes."""
|
|
719
1061
|
super().__init__(
|
|
@@ -731,6 +1073,8 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
731
1073
|
cache_img_path=cache_img_path,
|
|
732
1074
|
use_existing_imgs=use_existing_imgs,
|
|
733
1075
|
rank=rank,
|
|
1076
|
+
parallel_caching=parallel_caching,
|
|
1077
|
+
cache_workers=cache_workers,
|
|
734
1078
|
)
|
|
735
1079
|
self.labels = None
|
|
736
1080
|
self.crop_size = crop_size
|
|
@@ -748,6 +1092,9 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
748
1092
|
if self.user_instances_only:
|
|
749
1093
|
if lf.user_instances is not None and len(lf.user_instances) > 0:
|
|
750
1094
|
lf.instances = lf.user_instances
|
|
1095
|
+
else:
|
|
1096
|
+
# Skip frames without user instances
|
|
1097
|
+
continue
|
|
751
1098
|
for inst_idx, inst in enumerate(lf.instances):
|
|
752
1099
|
if not inst.is_empty: # filter all NaN instances.
|
|
753
1100
|
video_idx = labels[labels_idx].videos.index(lf.video)
|
|
@@ -818,9 +1165,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
818
1165
|
|
|
819
1166
|
instances = instances[:, inst_idx]
|
|
820
1167
|
|
|
821
|
-
# apply normalization
|
|
822
|
-
image = apply_normalization(image)
|
|
823
|
-
|
|
824
1168
|
if self.ensure_rgb:
|
|
825
1169
|
image = convert_to_rgb(image)
|
|
826
1170
|
elif self.ensure_grayscale:
|
|
@@ -834,13 +1178,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
834
1178
|
)
|
|
835
1179
|
instances = instances * eff_scale
|
|
836
1180
|
|
|
837
|
-
# resize image
|
|
838
|
-
image, instances = apply_resizer(
|
|
839
|
-
image,
|
|
840
|
-
instances,
|
|
841
|
-
scale=self.scale,
|
|
842
|
-
)
|
|
843
|
-
|
|
844
1181
|
# get the centroids based on the anchor idx
|
|
845
1182
|
centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
|
|
846
1183
|
|
|
@@ -901,6 +1238,13 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
901
1238
|
sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
|
|
902
1239
|
sample["centroid"] = centered_centroid # (n_samples=1, 2)
|
|
903
1240
|
|
|
1241
|
+
# resize the cropped image
|
|
1242
|
+
sample["instance_image"], sample["instance"] = apply_resizer(
|
|
1243
|
+
sample["instance_image"],
|
|
1244
|
+
sample["instance"],
|
|
1245
|
+
scale=self.scale,
|
|
1246
|
+
)
|
|
1247
|
+
|
|
904
1248
|
# Pad the image (if needed) according max stride
|
|
905
1249
|
sample["instance_image"] = apply_pad_to_stride(
|
|
906
1250
|
sample["instance_image"], max_stride=self.max_stride
|
|
@@ -959,7 +1303,7 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
959
1303
|
the images aren't cached and loaded from the `.slp` file on each access.
|
|
960
1304
|
cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
|
|
961
1305
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
962
|
-
crop_size: Crop size of each instance for centered-instance model.
|
|
1306
|
+
crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
|
|
963
1307
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
964
1308
|
disk occurs only once across all workers.
|
|
965
1309
|
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
@@ -967,9 +1311,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
967
1311
|
class_vectors_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
968
1312
|
(required keys: `classes`, `num_fc_layers`, `num_fc_units`, `output_stride`, `loss_weight`).
|
|
969
1313
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
970
|
-
|
|
971
|
-
Note: If scale is provided for centered-instance model, the images are cropped out
|
|
972
|
-
from the scaled image with the given crop size.
|
|
973
1314
|
"""
|
|
974
1315
|
|
|
975
1316
|
def __init__(
|
|
@@ -992,6 +1333,8 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
992
1333
|
cache_img_path: Optional[str] = None,
|
|
993
1334
|
use_existing_imgs: bool = False,
|
|
994
1335
|
rank: Optional[int] = None,
|
|
1336
|
+
parallel_caching: bool = True,
|
|
1337
|
+
cache_workers: int = 0,
|
|
995
1338
|
) -> None:
|
|
996
1339
|
"""Initialize class attributes."""
|
|
997
1340
|
super().__init__(
|
|
@@ -1012,6 +1355,8 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1012
1355
|
cache_img_path=cache_img_path,
|
|
1013
1356
|
use_existing_imgs=use_existing_imgs,
|
|
1014
1357
|
rank=rank,
|
|
1358
|
+
parallel_caching=parallel_caching,
|
|
1359
|
+
cache_workers=cache_workers,
|
|
1015
1360
|
)
|
|
1016
1361
|
self.class_vectors_head_config = class_vectors_head_config
|
|
1017
1362
|
self.class_names = self.class_vectors_head_config.classes
|
|
@@ -1066,9 +1411,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1066
1411
|
|
|
1067
1412
|
instances = instances[:, inst_idx]
|
|
1068
1413
|
|
|
1069
|
-
# apply normalization
|
|
1070
|
-
image = apply_normalization(image)
|
|
1071
|
-
|
|
1072
1414
|
if self.ensure_rgb:
|
|
1073
1415
|
image = convert_to_rgb(image)
|
|
1074
1416
|
elif self.ensure_grayscale:
|
|
@@ -1082,13 +1424,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1082
1424
|
)
|
|
1083
1425
|
instances = instances * eff_scale
|
|
1084
1426
|
|
|
1085
|
-
# resize image
|
|
1086
|
-
image, instances = apply_resizer(
|
|
1087
|
-
image,
|
|
1088
|
-
instances,
|
|
1089
|
-
scale=self.scale,
|
|
1090
|
-
)
|
|
1091
|
-
|
|
1092
1427
|
# get class vectors
|
|
1093
1428
|
track_ids = torch.Tensor(
|
|
1094
1429
|
[
|
|
@@ -1165,6 +1500,13 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1165
1500
|
sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
|
|
1166
1501
|
sample["centroid"] = centered_centroid # (n_samples=1, 2)
|
|
1167
1502
|
|
|
1503
|
+
# resize image
|
|
1504
|
+
sample["instance_image"], sample["instance"] = apply_resizer(
|
|
1505
|
+
sample["instance_image"],
|
|
1506
|
+
sample["instance"],
|
|
1507
|
+
scale=self.scale,
|
|
1508
|
+
)
|
|
1509
|
+
|
|
1168
1510
|
# Pad the image (if needed) according max stride
|
|
1169
1511
|
sample["instance_image"] = apply_pad_to_stride(
|
|
1170
1512
|
sample["instance_image"], max_stride=self.max_stride
|
|
@@ -1250,6 +1592,8 @@ class CentroidDataset(BaseDataset):
|
|
|
1250
1592
|
cache_img_path: Optional[str] = None,
|
|
1251
1593
|
use_existing_imgs: bool = False,
|
|
1252
1594
|
rank: Optional[int] = None,
|
|
1595
|
+
parallel_caching: bool = True,
|
|
1596
|
+
cache_workers: int = 0,
|
|
1253
1597
|
) -> None:
|
|
1254
1598
|
"""Initialize class attributes."""
|
|
1255
1599
|
super().__init__(
|
|
@@ -1267,6 +1611,8 @@ class CentroidDataset(BaseDataset):
|
|
|
1267
1611
|
cache_img_path=cache_img_path,
|
|
1268
1612
|
use_existing_imgs=use_existing_imgs,
|
|
1269
1613
|
rank=rank,
|
|
1614
|
+
parallel_caching=parallel_caching,
|
|
1615
|
+
cache_workers=cache_workers,
|
|
1270
1616
|
)
|
|
1271
1617
|
self.anchor_ind = anchor_ind
|
|
1272
1618
|
self.confmap_head_config = confmap_head_config
|
|
@@ -1306,9 +1652,6 @@ class CentroidDataset(BaseDataset):
|
|
|
1306
1652
|
user_instances_only=self.user_instances_only,
|
|
1307
1653
|
)
|
|
1308
1654
|
|
|
1309
|
-
# apply normalization
|
|
1310
|
-
sample["image"] = apply_normalization(sample["image"])
|
|
1311
|
-
|
|
1312
1655
|
if self.ensure_rgb:
|
|
1313
1656
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1314
1657
|
elif self.ensure_grayscale:
|
|
@@ -1433,6 +1776,8 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1433
1776
|
cache_img_path: Optional[str] = None,
|
|
1434
1777
|
use_existing_imgs: bool = False,
|
|
1435
1778
|
rank: Optional[int] = None,
|
|
1779
|
+
parallel_caching: bool = True,
|
|
1780
|
+
cache_workers: int = 0,
|
|
1436
1781
|
) -> None:
|
|
1437
1782
|
"""Initialize class attributes."""
|
|
1438
1783
|
super().__init__(
|
|
@@ -1450,6 +1795,8 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1450
1795
|
cache_img_path=cache_img_path,
|
|
1451
1796
|
use_existing_imgs=use_existing_imgs,
|
|
1452
1797
|
rank=rank,
|
|
1798
|
+
parallel_caching=parallel_caching,
|
|
1799
|
+
cache_workers=cache_workers,
|
|
1453
1800
|
)
|
|
1454
1801
|
self.confmap_head_config = confmap_head_config
|
|
1455
1802
|
|
|
@@ -1488,9 +1835,6 @@ class SingleInstanceDataset(BaseDataset):
|
|
|
1488
1835
|
user_instances_only=self.user_instances_only,
|
|
1489
1836
|
)
|
|
1490
1837
|
|
|
1491
|
-
# apply normalization
|
|
1492
|
-
sample["image"] = apply_normalization(sample["image"])
|
|
1493
|
-
|
|
1494
1838
|
if self.ensure_rgb:
|
|
1495
1839
|
sample["image"] = convert_to_rgb(sample["image"])
|
|
1496
1840
|
elif self.ensure_grayscale:
|
|
@@ -1671,6 +2015,10 @@ def get_train_val_datasets(
|
|
|
1671
2015
|
val_cache_img_path = Path(base_cache_img_path) / "val_imgs"
|
|
1672
2016
|
use_existing_imgs = config.data_config.use_existing_imgs
|
|
1673
2017
|
|
|
2018
|
+
# Parallel caching configuration
|
|
2019
|
+
parallel_caching = getattr(config.data_config, "parallel_caching", True)
|
|
2020
|
+
cache_workers = getattr(config.data_config, "cache_workers", 0)
|
|
2021
|
+
|
|
1674
2022
|
model_type = get_model_type_from_cfg(config=config)
|
|
1675
2023
|
backbone_type = get_backbone_type_from_cfg(config=config)
|
|
1676
2024
|
|
|
@@ -1724,6 +2072,8 @@ def get_train_val_datasets(
|
|
|
1724
2072
|
cache_img_path=train_cache_img_path,
|
|
1725
2073
|
use_existing_imgs=use_existing_imgs,
|
|
1726
2074
|
rank=rank,
|
|
2075
|
+
parallel_caching=parallel_caching,
|
|
2076
|
+
cache_workers=cache_workers,
|
|
1727
2077
|
)
|
|
1728
2078
|
val_dataset = BottomUpDataset(
|
|
1729
2079
|
labels=val_labels,
|
|
@@ -1747,6 +2097,8 @@ def get_train_val_datasets(
|
|
|
1747
2097
|
cache_img_path=val_cache_img_path,
|
|
1748
2098
|
use_existing_imgs=use_existing_imgs,
|
|
1749
2099
|
rank=rank,
|
|
2100
|
+
parallel_caching=parallel_caching,
|
|
2101
|
+
cache_workers=cache_workers,
|
|
1750
2102
|
)
|
|
1751
2103
|
|
|
1752
2104
|
elif model_type == "multi_class_bottomup":
|
|
@@ -1780,6 +2132,8 @@ def get_train_val_datasets(
|
|
|
1780
2132
|
cache_img_path=train_cache_img_path,
|
|
1781
2133
|
use_existing_imgs=use_existing_imgs,
|
|
1782
2134
|
rank=rank,
|
|
2135
|
+
parallel_caching=parallel_caching,
|
|
2136
|
+
cache_workers=cache_workers,
|
|
1783
2137
|
)
|
|
1784
2138
|
val_dataset = BottomUpMultiClassDataset(
|
|
1785
2139
|
labels=val_labels,
|
|
@@ -1803,6 +2157,8 @@ def get_train_val_datasets(
|
|
|
1803
2157
|
cache_img_path=val_cache_img_path,
|
|
1804
2158
|
use_existing_imgs=use_existing_imgs,
|
|
1805
2159
|
rank=rank,
|
|
2160
|
+
parallel_caching=parallel_caching,
|
|
2161
|
+
cache_workers=cache_workers,
|
|
1806
2162
|
)
|
|
1807
2163
|
|
|
1808
2164
|
elif model_type == "centered_instance":
|
|
@@ -1842,6 +2198,8 @@ def get_train_val_datasets(
|
|
|
1842
2198
|
cache_img_path=train_cache_img_path,
|
|
1843
2199
|
use_existing_imgs=use_existing_imgs,
|
|
1844
2200
|
rank=rank,
|
|
2201
|
+
parallel_caching=parallel_caching,
|
|
2202
|
+
cache_workers=cache_workers,
|
|
1845
2203
|
)
|
|
1846
2204
|
val_dataset = CenteredInstanceDataset(
|
|
1847
2205
|
labels=val_labels,
|
|
@@ -1866,6 +2224,8 @@ def get_train_val_datasets(
|
|
|
1866
2224
|
cache_img_path=val_cache_img_path,
|
|
1867
2225
|
use_existing_imgs=use_existing_imgs,
|
|
1868
2226
|
rank=rank,
|
|
2227
|
+
parallel_caching=parallel_caching,
|
|
2228
|
+
cache_workers=cache_workers,
|
|
1869
2229
|
)
|
|
1870
2230
|
|
|
1871
2231
|
elif model_type == "multi_class_topdown":
|
|
@@ -1906,6 +2266,8 @@ def get_train_val_datasets(
|
|
|
1906
2266
|
cache_img_path=train_cache_img_path,
|
|
1907
2267
|
use_existing_imgs=use_existing_imgs,
|
|
1908
2268
|
rank=rank,
|
|
2269
|
+
parallel_caching=parallel_caching,
|
|
2270
|
+
cache_workers=cache_workers,
|
|
1909
2271
|
)
|
|
1910
2272
|
val_dataset = TopDownCenteredInstanceMultiClassDataset(
|
|
1911
2273
|
labels=val_labels,
|
|
@@ -1931,6 +2293,8 @@ def get_train_val_datasets(
|
|
|
1931
2293
|
cache_img_path=val_cache_img_path,
|
|
1932
2294
|
use_existing_imgs=use_existing_imgs,
|
|
1933
2295
|
rank=rank,
|
|
2296
|
+
parallel_caching=parallel_caching,
|
|
2297
|
+
cache_workers=cache_workers,
|
|
1934
2298
|
)
|
|
1935
2299
|
|
|
1936
2300
|
elif model_type == "centroid":
|
|
@@ -1967,6 +2331,8 @@ def get_train_val_datasets(
|
|
|
1967
2331
|
cache_img_path=train_cache_img_path,
|
|
1968
2332
|
use_existing_imgs=use_existing_imgs,
|
|
1969
2333
|
rank=rank,
|
|
2334
|
+
parallel_caching=parallel_caching,
|
|
2335
|
+
cache_workers=cache_workers,
|
|
1970
2336
|
)
|
|
1971
2337
|
val_dataset = CentroidDataset(
|
|
1972
2338
|
labels=val_labels,
|
|
@@ -1990,6 +2356,8 @@ def get_train_val_datasets(
|
|
|
1990
2356
|
cache_img_path=val_cache_img_path,
|
|
1991
2357
|
use_existing_imgs=use_existing_imgs,
|
|
1992
2358
|
rank=rank,
|
|
2359
|
+
parallel_caching=parallel_caching,
|
|
2360
|
+
cache_workers=cache_workers,
|
|
1993
2361
|
)
|
|
1994
2362
|
|
|
1995
2363
|
else:
|
|
@@ -2022,6 +2390,8 @@ def get_train_val_datasets(
|
|
|
2022
2390
|
cache_img_path=train_cache_img_path,
|
|
2023
2391
|
use_existing_imgs=use_existing_imgs,
|
|
2024
2392
|
rank=rank,
|
|
2393
|
+
parallel_caching=parallel_caching,
|
|
2394
|
+
cache_workers=cache_workers,
|
|
2025
2395
|
)
|
|
2026
2396
|
val_dataset = SingleInstanceDataset(
|
|
2027
2397
|
labels=val_labels,
|
|
@@ -2044,6 +2414,8 @@ def get_train_val_datasets(
|
|
|
2044
2414
|
cache_img_path=val_cache_img_path,
|
|
2045
2415
|
use_existing_imgs=use_existing_imgs,
|
|
2046
2416
|
rank=rank,
|
|
2417
|
+
parallel_caching=parallel_caching,
|
|
2418
|
+
cache_workers=cache_workers,
|
|
2047
2419
|
)
|
|
2048
2420
|
|
|
2049
2421
|
# If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0.
|